rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Reconnection timeout of 30 seconds.
/// NOTE(review): the code that consumes this constant is outside this section — confirm usage.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// Delay that `Server::start` waits before it clears stale rooms, channel buffers and servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. NOTE(review): presumably used by the chat/notification
// handlers defined later in the file — confirm against those handlers.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It takes the boxed, not-yet-downcast envelope together with the `Session` it
/// arrived on and returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    // Behind a mutex because the test-only `reset` replaces the id through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message type, keyed by the message's `TypeId`; filled by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` publishes `true` on this channel; the test-only `reset` publishes `false`.
    teardown: watch::Sender<bool>,
}
 244
/// Wrapper around a locked `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the wrapper `!Send` as well.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server: the peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the value the guard dereferences to
    // is what gets written. NOTE(review): the guard's `Deref` impl is outside this
    // section — presumably it targets `ConnectionPool`; confirm.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 348            .add_message_handler(create_buffer_for_peer)
 349            .add_request_handler(update_buffer)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 356            .add_request_handler(get_users)
 357            .add_request_handler(fuzzy_search_users)
 358            .add_request_handler(request_contact)
 359            .add_request_handler(remove_contact)
 360            .add_request_handler(respond_to_contact_request)
 361            .add_message_handler(subscribe_to_channels)
 362            .add_request_handler(create_channel)
 363            .add_request_handler(delete_channel)
 364            .add_request_handler(invite_channel_member)
 365            .add_request_handler(remove_channel_member)
 366            .add_request_handler(set_channel_member_role)
 367            .add_request_handler(set_channel_visibility)
 368            .add_request_handler(rename_channel)
 369            .add_request_handler(join_channel_buffer)
 370            .add_request_handler(leave_channel_buffer)
 371            .add_message_handler(update_channel_buffer)
 372            .add_request_handler(rejoin_channel_buffers)
 373            .add_request_handler(get_channel_members)
 374            .add_request_handler(respond_to_channel_invite)
 375            .add_request_handler(join_channel)
 376            .add_request_handler(join_channel_chat)
 377            .add_message_handler(leave_channel_chat)
 378            .add_request_handler(send_channel_message)
 379            .add_request_handler(remove_channel_message)
 380            .add_request_handler(update_channel_message)
 381            .add_request_handler(get_channel_messages)
 382            .add_request_handler(get_channel_messages_by_id)
 383            .add_request_handler(get_notifications)
 384            .add_request_handler(mark_notification_as_read)
 385            .add_request_handler(move_channel)
 386            .add_request_handler(follow)
 387            .add_message_handler(unfollow)
 388            .add_message_handler(update_followers)
 389            .add_request_handler(get_private_user_info)
 390            .add_request_handler(get_llm_api_token)
 391            .add_request_handler(accept_terms_of_service)
 392            .add_message_handler(acknowledge_channel_message)
 393            .add_message_handler(acknowledge_buffer_version)
 394            .add_request_handler(get_supermaven_api_key)
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 396            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 401            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 406            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 407            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 409            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 410            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 411            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 412            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 413            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 414            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 415            .add_message_handler(update_context)
 416            .add_request_handler({
 417                let app_state = app_state.clone();
 418                move |request, response, session| {
 419                    let app_state = app_state.clone();
 420                    async move {
 421                        count_language_model_tokens(request, response, session, &app_state.config)
 422                            .await
 423                    }
 424                }
 425            })
 426            .add_request_handler(get_cached_embeddings)
 427            .add_request_handler({
 428                let app_state = app_state.clone();
 429                move |request, response, session| {
 430                    compute_embeddings(
 431                        request,
 432                        response,
 433                        session,
 434                        app_state.config.openai_api_key.clone(),
 435                    )
 436                }
 437            });
 438
 439        Arc::new(server)
 440    }
 441
 442    pub async fn start(&self) -> Result<()> {
 443        let server_id = *self.id.lock();
 444        let app_state = self.app_state.clone();
 445        let peer = self.peer.clone();
 446        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 447        let pool = self.connection_pool.clone();
 448        let livekit_client = self.app_state.livekit_client.clone();
 449
 450        let span = info_span!("start server");
 451        self.app_state.executor.spawn_detached(
 452            async move {
 453                tracing::info!("waiting for cleanup timeout");
 454                timeout.await;
 455                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 456                if let Some((room_ids, channel_ids)) = app_state
 457                    .db
 458                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 459                    .await
 460                    .trace_err()
 461                {
 462                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 463                    tracing::info!(
 464                        stale_channel_buffer_count = channel_ids.len(),
 465                        "retrieved stale channel buffers"
 466                    );
 467
 468                    for channel_id in channel_ids {
 469                        if let Some(refreshed_channel_buffer) = app_state
 470                            .db
 471                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 472                            .await
 473                            .trace_err()
 474                        {
 475                            for connection_id in refreshed_channel_buffer.connection_ids {
 476                                peer.send(
 477                                    connection_id,
 478                                    proto::UpdateChannelBufferCollaborators {
 479                                        channel_id: channel_id.to_proto(),
 480                                        collaborators: refreshed_channel_buffer
 481                                            .collaborators
 482                                            .clone(),
 483                                    },
 484                                )
 485                                .trace_err();
 486                            }
 487                        }
 488                    }
 489
 490                    for room_id in room_ids {
 491                        let mut contacts_to_update = HashSet::default();
 492                        let mut canceled_calls_to_user_ids = Vec::new();
 493                        let mut livekit_room = String::new();
 494                        let mut delete_livekit_room = false;
 495
 496                        if let Some(mut refreshed_room) = app_state
 497                            .db
 498                            .clear_stale_room_participants(room_id, server_id)
 499                            .await
 500                            .trace_err()
 501                        {
 502                            tracing::info!(
 503                                room_id = room_id.0,
 504                                new_participant_count = refreshed_room.room.participants.len(),
 505                                "refreshed room"
 506                            );
 507                            room_updated(&refreshed_room.room, &peer);
 508                            if let Some(channel) = refreshed_room.channel.as_ref() {
 509                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 510                            }
 511                            contacts_to_update
 512                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 513                            contacts_to_update
 514                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 515                            canceled_calls_to_user_ids =
 516                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 517                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 518                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 519                        }
 520
 521                        {
 522                            let pool = pool.lock();
 523                            for canceled_user_id in canceled_calls_to_user_ids {
 524                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 525                                    peer.send(
 526                                        connection_id,
 527                                        proto::CallCanceled {
 528                                            room_id: room_id.to_proto(),
 529                                        },
 530                                    )
 531                                    .trace_err();
 532                                }
 533                            }
 534                        }
 535
 536                        for user_id in contacts_to_update {
 537                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 538                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 539                            if let Some((busy, contacts)) = busy.zip(contacts) {
 540                                let pool = pool.lock();
 541                                let updated_contact = contact_for_user(user_id, busy, &pool);
 542                                for contact in contacts {
 543                                    if let db::Contact::Accepted {
 544                                        user_id: contact_user_id,
 545                                        ..
 546                                    } = contact
 547                                    {
 548                                        for contact_conn_id in
 549                                            pool.user_connection_ids(contact_user_id)
 550                                        {
 551                                            peer.send(
 552                                                contact_conn_id,
 553                                                proto::UpdateContacts {
 554                                                    contacts: vec![updated_contact.clone()],
 555                                                    remove_contacts: Default::default(),
 556                                                    incoming_requests: Default::default(),
 557                                                    remove_incoming_requests: Default::default(),
 558                                                    outgoing_requests: Default::default(),
 559                                                    remove_outgoing_requests: Default::default(),
 560                                                },
 561                                            )
 562                                            .trace_err();
 563                                        }
 564                                    }
 565                                }
 566                            }
 567                        }
 568
 569                        if let Some(live_kit) = livekit_client.as_ref() {
 570                            if delete_livekit_room {
 571                                live_kit.delete_room(livekit_room).await.trace_err();
 572                            }
 573                        }
 574                    }
 575                }
 576
 577                app_state
 578                    .db
 579                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 580                    .await
 581                    .trace_err();
 582            }
 583            .instrument(span),
 584        );
 585        Ok(())
 586    }
 587
 588    pub fn teardown(&self) {
 589        self.peer.teardown();
 590        self.connection_pool.lock().reset();
 591        let _ = self.teardown.send(true);
 592    }
 593
 594    #[cfg(test)]
 595    pub fn reset(&self, id: ServerId) {
 596        self.teardown();
 597        *self.id.lock() = id;
 598        self.peer.reset(id.0 as u32);
 599        let _ = self.teardown.send(false);
 600    }
 601
 602    #[cfg(test)]
 603    pub fn id(&self) -> ServerId {
 604        *self.id.lock()
 605    }
 606
 607    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 608    where
 609        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 610        Fut: 'static + Send + Future<Output = Result<()>>,
 611        M: EnvelopedMessage,
 612    {
 613        let prev_handler = self.handlers.insert(
 614            TypeId::of::<M>(),
 615            Box::new(move |envelope, session| {
 616                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 617                let received_at = envelope.received_at;
 618                tracing::info!("message received");
 619                let start_time = Instant::now();
 620                let future = (handler)(*envelope, session);
 621                async move {
 622                    let result = future.await;
 623                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 624                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 625                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 626                    let payload_type = M::NAME;
 627
 628                    match result {
 629                        Err(error) => {
 630                            tracing::error!(
 631                                ?error,
 632                                total_duration_ms,
 633                                processing_duration_ms,
 634                                queue_duration_ms,
 635                                payload_type,
 636                                "error handling message"
 637                            )
 638                        }
 639                        Ok(()) => tracing::info!(
 640                            total_duration_ms,
 641                            processing_duration_ms,
 642                            queue_duration_ms,
 643                            "finished handling message"
 644                        ),
 645                    }
 646                }
 647                .boxed()
 648            }),
 649        );
 650        if prev_handler.is_some() {
 651            panic!("registered a handler for the same message twice");
 652        }
 653        self
 654    }
 655
 656    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 657    where
 658        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 659        Fut: 'static + Send + Future<Output = Result<()>>,
 660        M: EnvelopedMessage,
 661    {
 662        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 663        self
 664    }
 665
 666    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 667    where
 668        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 669        Fut: Send + Future<Output = Result<()>>,
 670        M: RequestMessage,
 671    {
 672        let handler = Arc::new(handler);
 673        self.add_handler(move |envelope, session| {
 674            let receipt = envelope.receipt();
 675            let handler = handler.clone();
 676            async move {
 677                let peer = session.peer.clone();
 678                let responded = Arc::new(AtomicBool::default());
 679                let response = Response {
 680                    peer: peer.clone(),
 681                    responded: responded.clone(),
 682                    receipt,
 683                };
 684                match (handler)(envelope.payload, response, session).await {
 685                    Ok(()) => {
 686                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 687                            Ok(())
 688                        } else {
 689                            Err(anyhow!("handler did not send a response"))?
 690                        }
 691                    }
 692                    Err(error) => {
 693                        let proto_err = match &error {
 694                            Error::Internal(err) => err.to_proto(),
 695                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 696                        };
 697                        peer.respond_with_error(receipt, proto_err)?;
 698                        Err(error)
 699                    }
 700                }
 701            }
 702        })
 703    }
 704
    /// Drives a single client connection from registration until sign-out.
    ///
    /// The returned future registers `connection` with the peer, builds a `Session`
    /// for it, sends the initial client update, and then loops dispatching incoming
    /// messages to the registered handlers. The loop ends when the I/O future
    /// finishes, the incoming stream ends, or the teardown signal changes.
    ///
    /// `connection_lost` is awaited only when the loop is left via `break`; the
    /// early returns (server already tearing down, HTTP client creation failure,
    /// initial update failure, teardown during the loop) skip it.
    ///
    /// `send_connection_id`, when provided, is forwarded to
    /// `send_initial_client_update`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields are declared empty up front and recorded later, once their values
        // are known (`connection_id` below; NOTE(review): `update_span` presumably
        // fills the principal-related fields — confirm).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once the teardown flag is already set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure gives the peer a way to create timers through our executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, they would never get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of outstanding handler permits at 256: a permit is acquired
            // before the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written, so teardown
                // and I/O completion are checked before handlers and new messages.
                futures::select_biased! {
                    // Teardown exits immediately, without the sign-out step below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive pending foreground handlers; their output is not used.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while looking up and invoking the handler;
                            // the guard is dropped before the span is moved into `instrument`.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the permit until the handler future has completed.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream yielded `None`.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Discard any foreground handlers that have not finished yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 853
 854    async fn send_initial_client_update(
 855        &self,
 856        connection_id: ConnectionId,
 857        principal: &Principal,
 858        zed_version: ZedVersion,
 859        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 860        session: &Session,
 861    ) -> Result<()> {
 862        self.peer.send(
 863            connection_id,
 864            proto::Hello {
 865                peer_id: Some(connection_id.into()),
 866            },
 867        )?;
 868        tracing::info!("sent hello message");
 869        if let Some(send_connection_id) = send_connection_id.take() {
 870            let _ = send_connection_id.send(connection_id);
 871        }
 872
 873        match principal {
 874            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 875                if !user.connected_once {
 876                    self.peer.send(connection_id, proto::ShowContacts {})?;
 877                    self.app_state
 878                        .db
 879                        .set_user_connected_once(user.id, true)
 880                        .await?;
 881                }
 882
 883                update_user_plan(user.id, session).await?;
 884
 885                let contacts = self.app_state.db.get_contacts(user.id).await?;
 886
 887                {
 888                    let mut pool = self.connection_pool.lock();
 889                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 890                    self.peer.send(
 891                        connection_id,
 892                        build_initial_contacts_update(contacts, &pool),
 893                    )?;
 894                }
 895
 896                if should_auto_subscribe_to_channels(zed_version) {
 897                    subscribe_user_to_channels(user.id, session).await?;
 898                }
 899
 900                if let Some(incoming_call) =
 901                    self.app_state.db.incoming_call_for_user(user.id).await?
 902                {
 903                    self.peer.send(connection_id, incoming_call)?;
 904                }
 905
 906                update_user_contacts(user.id, session).await?;
 907            }
 908        }
 909
 910        Ok(())
 911    }
 912
 913    pub async fn invite_code_redeemed(
 914        self: &Arc<Self>,
 915        inviter_id: UserId,
 916        invitee_id: UserId,
 917    ) -> Result<()> {
 918        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 919            if let Some(code) = &user.invite_code {
 920                let pool = self.connection_pool.lock();
 921                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 922                for connection_id in pool.user_connection_ids(inviter_id) {
 923                    self.peer.send(
 924                        connection_id,
 925                        proto::UpdateContacts {
 926                            contacts: vec![invitee_contact.clone()],
 927                            ..Default::default()
 928                        },
 929                    )?;
 930                    self.peer.send(
 931                        connection_id,
 932                        proto::UpdateInviteInfo {
 933                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 934                            count: user.invite_count as u32,
 935                        },
 936                    )?;
 937                }
 938            }
 939        }
 940        Ok(())
 941    }
 942
 943    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 944        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 945            if let Some(invite_code) = &user.invite_code {
 946                let pool = self.connection_pool.lock();
 947                for connection_id in pool.user_connection_ids(user_id) {
 948                    self.peer.send(
 949                        connection_id,
 950                        proto::UpdateInviteInfo {
 951                            url: format!(
 952                                "{}{}",
 953                                self.app_state.config.invite_link_prefix, invite_code
 954                            ),
 955                            count: user.invite_count as u32,
 956                        },
 957                    )?;
 958                }
 959            }
 960        }
 961        Ok(())
 962    }
 963
 964    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 965        let pool = self.connection_pool.lock();
 966        for connection_id in pool.user_connection_ids(user_id) {
 967            self.peer
 968                .send(connection_id, proto::RefreshLlmToken {})
 969                .trace_err();
 970        }
 971    }
 972
 973    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 974        ServerSnapshot {
 975            connection_pool: ConnectionPoolGuard {
 976                guard: self.connection_pool.lock(),
 977                _not_send: PhantomData,
 978            },
 979            peer: &self.peer,
 980        }
 981    }
 982}
 983
 984impl Deref for ConnectionPoolGuard<'_> {
 985    type Target = ConnectionPool;
 986
 987    fn deref(&self) -> &Self::Target {
 988        &self.guard
 989    }
 990}
 991
 992impl DerefMut for ConnectionPoolGuard<'_> {
 993    fn deref_mut(&mut self) -> &mut Self::Target {
 994        &mut self.guard
 995    }
 996}
 997
/// In test builds, checks the pool's invariants every time a guard is released.
/// In other builds the body is empty.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call below is compiled only under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1004
1005fn broadcast<F>(
1006    sender_id: Option<ConnectionId>,
1007    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1008    mut f: F,
1009) where
1010    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1011{
1012    for receiver_id in receiver_ids {
1013        if Some(receiver_id) != sender_id {
1014            if let Err(error) = f(receiver_id) {
1015                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1016            }
1017        }
1018    }
1019}
1020
1021pub struct ProtocolVersion(u32);
1022
1023impl Header for ProtocolVersion {
1024    fn name() -> &'static HeaderName {
1025        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1026        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1027    }
1028
1029    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1030    where
1031        Self: Sized,
1032        I: Iterator<Item = &'i axum::http::HeaderValue>,
1033    {
1034        let version = values
1035            .next()
1036            .ok_or_else(axum::headers::Error::invalid)?
1037            .to_str()
1038            .map_err(|_| axum::headers::Error::invalid())?
1039            .parse()
1040            .map_err(|_| axum::headers::Error::invalid())?;
1041        Ok(Self(version))
1042    }
1043
1044    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1045        values.extend([self.0.to_string().parse().unwrap()]);
1046    }
1047}
1048
1049pub struct AppVersionHeader(SemanticVersion);
1050impl Header for AppVersionHeader {
1051    fn name() -> &'static HeaderName {
1052        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1053        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1054    }
1055
1056    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1057    where
1058        Self: Sized,
1059        I: Iterator<Item = &'i axum::http::HeaderValue>,
1060    {
1061        let version = values
1062            .next()
1063            .ok_or_else(axum::headers::Error::invalid)?
1064            .to_str()
1065            .map_err(|_| axum::headers::Error::invalid())?
1066            .parse()
1067            .map_err(|_| axum::headers::Error::invalid())?;
1068        Ok(Self(version))
1069    }
1070
1071    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1072        values.extend([self.0.to_string().parse().unwrap()]);
1073    }
1074}
1075
1076pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1077    Router::new()
1078        .route("/rpc", get(handle_websocket_request))
1079        .layer(
1080            ServiceBuilder::new()
1081                .layer(Extension(server.app_state.clone()))
1082                .layer(middleware::from_fn(auth::validate_header)),
1083        )
1084        .route("/metrics", get(handle_metrics))
1085        .layer(Extension(server))
1086}
1087
1088pub async fn handle_websocket_request(
1089    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1090    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1091    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1092    Extension(server): Extension<Arc<Server>>,
1093    Extension(principal): Extension<Principal>,
1094    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1095    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1096    ws: WebSocketUpgrade,
1097) -> axum::response::Response {
1098    if protocol_version != rpc::PROTOCOL_VERSION {
1099        return (
1100            StatusCode::UPGRADE_REQUIRED,
1101            "client must be upgraded".to_string(),
1102        )
1103            .into_response();
1104    }
1105
1106    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1107        return (
1108            StatusCode::UPGRADE_REQUIRED,
1109            "no version header found".to_string(),
1110        )
1111            .into_response();
1112    };
1113
1114    if !version.can_collaborate() {
1115        return (
1116            StatusCode::UPGRADE_REQUIRED,
1117            "client must be upgraded".to_string(),
1118        )
1119            .into_response();
1120    }
1121
1122    let socket_address = socket_address.to_string();
1123    ws.on_upgrade(move |socket| {
1124        let socket = socket
1125            .map_ok(to_tungstenite_message)
1126            .err_into()
1127            .with(|message| async move { to_axum_message(message) });
1128        let connection = Connection::new(Box::pin(socket));
1129        async move {
1130            server
1131                .handle_connection(
1132                    connection,
1133                    socket_address,
1134                    principal,
1135                    version,
1136                    country_code_header.map(|header| header.to_string()),
1137                    system_id_header.map(|header| header.to_string()),
1138                    None,
1139                    Executor::Production,
1140                )
1141                .await;
1142        }
1143    })
1144}
1145
1146pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1147    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148    let connections_metric = CONNECTIONS_METRIC
1149        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1150
1151    let connections = server
1152        .connection_pool
1153        .lock()
1154        .connections()
1155        .filter(|connection| !connection.admin)
1156        .count();
1157    connections_metric.set(connections as _);
1158
1159    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1160    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1161        register_int_gauge!(
1162            "shared_projects",
1163            "number of open projects with one or more guests"
1164        )
1165        .unwrap()
1166    });
1167
1168    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1169    shared_projects_metric.set(shared_projects as _);
1170
1171    let encoder = prometheus::TextEncoder::new();
1172    let metric_families = prometheus::gather();
1173    let encoded_metrics = encoder
1174        .encode_to_string(&metric_families)
1175        .map_err(|err| anyhow!("{}", err))?;
1176    Ok(encoded_metrics)
1177}
1178
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the loss in the database (a database error there is handed to
/// `trace_err` rather than propagated).
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers, a call is declined if the pool reports the user as no longer
///   online, and the user's contacts are updated;
/// - the teardown signal changes: no further cleanup is performed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout branch
    // is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Failures in the next two steps are traced and do not abort the cleanup.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a call when the pool no longer reports the user as online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1225
1226/// Acknowledges a ping from a client, used to keep the connection alive.
1227async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1228    response.send(proto::Ack {})?;
1229    Ok(())
1230}
1231
1232/// Creates a new room for calling (outside of channels)
1233async fn create_room(
1234    _request: proto::CreateRoom,
1235    response: Response<proto::CreateRoom>,
1236    session: Session,
1237) -> Result<()> {
1238    let livekit_room = nanoid::nanoid!(30);
1239
1240    let live_kit_connection_info = util::maybe!(async {
1241        let live_kit = session.app_state.livekit_client.as_ref();
1242        let live_kit = live_kit?;
1243        let user_id = session.user_id().to_string();
1244
1245        let token = live_kit
1246            .room_token(&livekit_room, &user_id.to_string())
1247            .trace_err()?;
1248
1249        Some(proto::LiveKitConnectionInfo {
1250            server_url: live_kit.url().into(),
1251            token,
1252            can_publish: true,
1253        })
1254    })
1255    .await;
1256
1257    let room = session
1258        .db()
1259        .await
1260        .create_room(session.user_id(), session.connection_id, &livekit_room)
1261        .await?;
1262
1263    response.send(proto::CreateRoomResponse {
1264        room: Some(room.clone()),
1265        live_kit_connection_info,
1266    })?;
1267
1268    update_user_contacts(session.user_id(), &session).await?;
1269    Ok(())
1270}
1271
1272/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1273async fn join_room(
1274    request: proto::JoinRoom,
1275    response: Response<proto::JoinRoom>,
1276    session: Session,
1277) -> Result<()> {
1278    let room_id = RoomId::from_proto(request.id);
1279
1280    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1281
1282    if let Some(channel_id) = channel_id {
1283        return join_channel_internal(channel_id, Box::new(response), session).await;
1284    }
1285
1286    let joined_room = {
1287        let room = session
1288            .db()
1289            .await
1290            .join_room(room_id, session.user_id(), session.connection_id)
1291            .await?;
1292        room_updated(&room.room, &session.peer);
1293        room.into_inner()
1294    };
1295
1296    for connection_id in session
1297        .connection_pool()
1298        .await
1299        .user_connection_ids(session.user_id())
1300    {
1301        session
1302            .peer
1303            .send(
1304                connection_id,
1305                proto::CallCanceled {
1306                    room_id: room_id.to_proto(),
1307                },
1308            )
1309            .trace_err();
1310    }
1311
1312    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1313    {
1314        live_kit
1315            .room_token(
1316                &joined_room.room.livekit_room,
1317                &session.user_id().to_string(),
1318            )
1319            .trace_err()
1320            .map(|token| proto::LiveKitConnectionInfo {
1321                server_url: live_kit.url().into(),
1322                token,
1323                can_publish: true,
1324            })
1325    } else {
1326        None
1327    };
1328
1329    response.send(proto::JoinRoomResponse {
1330        room: Some(joined_room.room),
1331        channel_id: None,
1332        live_kit_connection_info,
1333    })?;
1334
1335    update_user_contacts(session.user_id(), &session).await?;
1336    Ok(())
1337}
1338
/// Rejoins a room after a connection error.
///
/// Replies with the rejoined room together with its reshared and rejoined
/// projects, announces the room update, tells collaborators of each reshared
/// project about the new peer id, forwards an `UpdateProject` for each reshared
/// project, and streams rejoined-project state back to the caller. If the room
/// has a channel, the channel update is announced as well. Finally the caller's
/// contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they outlive `rejoined_room`, which is
    // consumed by `into_inner` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the request first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the project's old peer id is now this
            // connection; send failures are traced, not propagated.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to every collaborator except
            // this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1430
1431fn notify_rejoined_projects(
1432    rejoined_projects: &mut Vec<RejoinedProject>,
1433    session: &Session,
1434) -> Result<()> {
1435    for project in rejoined_projects.iter() {
1436        for collaborator in &project.collaborators {
1437            session
1438                .peer
1439                .send(
1440                    collaborator.connection_id,
1441                    proto::UpdateProjectCollaborator {
1442                        project_id: project.id.to_proto(),
1443                        old_peer_id: Some(project.old_connection_id.into()),
1444                        new_peer_id: Some(session.connection_id.into()),
1445                    },
1446                )
1447                .trace_err();
1448        }
1449    }
1450
1451    for project in rejoined_projects {
1452        for worktree in mem::take(&mut project.worktrees) {
1453            // Stream this worktree's entries.
1454            let message = proto::UpdateWorktree {
1455                project_id: project.id.to_proto(),
1456                worktree_id: worktree.id,
1457                abs_path: worktree.abs_path.clone(),
1458                root_name: worktree.root_name,
1459                updated_entries: worktree.updated_entries,
1460                removed_entries: worktree.removed_entries,
1461                scan_id: worktree.scan_id,
1462                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1463                updated_repositories: worktree.updated_repositories,
1464                removed_repositories: worktree.removed_repositories,
1465            };
1466            for update in proto::split_worktree_update(message) {
1467                session.peer.send(session.connection_id, update.clone())?;
1468            }
1469
1470            // Stream this worktree's diagnostics.
1471            for summary in worktree.diagnostic_summaries {
1472                session.peer.send(
1473                    session.connection_id,
1474                    proto::UpdateDiagnosticSummary {
1475                        project_id: project.id.to_proto(),
1476                        worktree_id: worktree.id,
1477                        summary: Some(summary),
1478                    },
1479                )?;
1480            }
1481
1482            for settings_file in worktree.settings_files {
1483                session.peer.send(
1484                    session.connection_id,
1485                    proto::UpdateWorktreeSettings {
1486                        project_id: project.id.to_proto(),
1487                        worktree_id: worktree.id,
1488                        path: settings_file.path,
1489                        content: Some(settings_file.content),
1490                        kind: Some(settings_file.kind.to_proto().into()),
1491                    },
1492                )?;
1493            }
1494        }
1495
1496        for language_server in &project.language_servers {
1497            session.peer.send(
1498                session.connection_id,
1499                proto::UpdateLanguageServer {
1500                    project_id: project.id.to_proto(),
1501                    language_server_id: language_server.id,
1502                    variant: Some(
1503                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1504                            proto::LspDiskBasedDiagnosticsUpdated {},
1505                        ),
1506                    ),
1507                },
1508            )?;
1509        }
1510    }
1511    Ok(())
1512}
1513
1514/// leave room disconnects from the room.
1515async fn leave_room(
1516    _: proto::LeaveRoom,
1517    response: Response<proto::LeaveRoom>,
1518    session: Session,
1519) -> Result<()> {
1520    leave_room_for_session(&session, session.connection_id).await?;
1521    response.send(proto::Ack {})?;
1522    Ok(())
1523}
1524
1525/// Updates the permissions of someone else in the room.
1526async fn set_room_participant_role(
1527    request: proto::SetRoomParticipantRole,
1528    response: Response<proto::SetRoomParticipantRole>,
1529    session: Session,
1530) -> Result<()> {
1531    let user_id = UserId::from_proto(request.user_id);
1532    let role = ChannelRole::from(request.role());
1533
1534    let (livekit_room, can_publish) = {
1535        let room = session
1536            .db()
1537            .await
1538            .set_room_participant_role(
1539                session.user_id(),
1540                RoomId::from_proto(request.room_id),
1541                user_id,
1542                role,
1543            )
1544            .await?;
1545
1546        let livekit_room = room.livekit_room.clone();
1547        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1548        room_updated(&room, &session.peer);
1549        (livekit_room, can_publish)
1550    };
1551
1552    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1553        live_kit
1554            .update_participant(
1555                livekit_room.clone(),
1556                request.user_id.to_string(),
1557                livekit_api::proto::ParticipantPermission {
1558                    can_subscribe: true,
1559                    can_publish,
1560                    can_publish_data: can_publish,
1561                    hidden: false,
1562                    recorder: false,
1563                },
1564            )
1565            .await
1566            .trace_err();
1567    }
1568
1569    response.send(proto::Ack {})?;
1570    Ok(())
1571}
1572
1573/// Call someone else into the current room
1574async fn call(
1575    request: proto::Call,
1576    response: Response<proto::Call>,
1577    session: Session,
1578) -> Result<()> {
1579    let room_id = RoomId::from_proto(request.room_id);
1580    let calling_user_id = session.user_id();
1581    let calling_connection_id = session.connection_id;
1582    let called_user_id = UserId::from_proto(request.called_user_id);
1583    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1584    if !session
1585        .db()
1586        .await
1587        .has_contact(calling_user_id, called_user_id)
1588        .await?
1589    {
1590        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1591    }
1592
1593    let incoming_call = {
1594        let (room, incoming_call) = &mut *session
1595            .db()
1596            .await
1597            .call(
1598                room_id,
1599                calling_user_id,
1600                calling_connection_id,
1601                called_user_id,
1602                initial_project_id,
1603            )
1604            .await?;
1605        room_updated(room, &session.peer);
1606        mem::take(incoming_call)
1607    };
1608    update_user_contacts(called_user_id, &session).await?;
1609
1610    let mut calls = session
1611        .connection_pool()
1612        .await
1613        .user_connection_ids(called_user_id)
1614        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1615        .collect::<FuturesUnordered<_>>();
1616
1617    while let Some(call_response) = calls.next().await {
1618        match call_response.as_ref() {
1619            Ok(_) => {
1620                response.send(proto::Ack {})?;
1621                return Ok(());
1622            }
1623            Err(_) => {
1624                call_response.trace_err();
1625            }
1626        }
1627    }
1628
1629    {
1630        let room = session
1631            .db()
1632            .await
1633            .call_failed(room_id, called_user_id)
1634            .await?;
1635        room_updated(&room, &session.peer);
1636    }
1637    update_user_contacts(called_user_id, &session).await?;
1638
1639    Err(anyhow!("failed to ring user"))?
1640}
1641
1642/// Cancel an outgoing call.
1643async fn cancel_call(
1644    request: proto::CancelCall,
1645    response: Response<proto::CancelCall>,
1646    session: Session,
1647) -> Result<()> {
1648    let called_user_id = UserId::from_proto(request.called_user_id);
1649    let room_id = RoomId::from_proto(request.room_id);
1650    {
1651        let room = session
1652            .db()
1653            .await
1654            .cancel_call(room_id, session.connection_id, called_user_id)
1655            .await?;
1656        room_updated(&room, &session.peer);
1657    }
1658
1659    for connection_id in session
1660        .connection_pool()
1661        .await
1662        .user_connection_ids(called_user_id)
1663    {
1664        session
1665            .peer
1666            .send(
1667                connection_id,
1668                proto::CallCanceled {
1669                    room_id: room_id.to_proto(),
1670                },
1671            )
1672            .trace_err();
1673    }
1674    response.send(proto::Ack {})?;
1675
1676    update_user_contacts(called_user_id, &session).await?;
1677    Ok(())
1678}
1679
1680/// Decline an incoming call.
1681async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1682    let room_id = RoomId::from_proto(message.room_id);
1683    {
1684        let room = session
1685            .db()
1686            .await
1687            .decline_call(Some(room_id), session.user_id())
1688            .await?
1689            .ok_or_else(|| anyhow!("failed to decline call"))?;
1690        room_updated(&room, &session.peer);
1691    }
1692
1693    for connection_id in session
1694        .connection_pool()
1695        .await
1696        .user_connection_ids(session.user_id())
1697    {
1698        session
1699            .peer
1700            .send(
1701                connection_id,
1702                proto::CallCanceled {
1703                    room_id: room_id.to_proto(),
1704                },
1705            )
1706            .trace_err();
1707    }
1708    update_user_contacts(session.user_id(), &session).await?;
1709    Ok(())
1710}
1711
1712/// Updates other participants in the room with your current location.
1713async fn update_participant_location(
1714    request: proto::UpdateParticipantLocation,
1715    response: Response<proto::UpdateParticipantLocation>,
1716    session: Session,
1717) -> Result<()> {
1718    let room_id = RoomId::from_proto(request.room_id);
1719    let location = request
1720        .location
1721        .ok_or_else(|| anyhow!("invalid location"))?;
1722
1723    let db = session.db().await;
1724    let room = db
1725        .update_room_participant_location(room_id, session.connection_id, location)
1726        .await?;
1727
1728    room_updated(&room, &session.peer);
1729    response.send(proto::Ack {})?;
1730    Ok(())
1731}
1732
1733/// Share a project into the room.
1734async fn share_project(
1735    request: proto::ShareProject,
1736    response: Response<proto::ShareProject>,
1737    session: Session,
1738) -> Result<()> {
1739    let (project_id, room) = &*session
1740        .db()
1741        .await
1742        .share_project(
1743            RoomId::from_proto(request.room_id),
1744            session.connection_id,
1745            &request.worktrees,
1746            request.is_ssh_project,
1747        )
1748        .await?;
1749    response.send(proto::ShareProjectResponse {
1750        project_id: project_id.to_proto(),
1751    })?;
1752    room_updated(room, &session.peer);
1753
1754    Ok(())
1755}
1756
1757/// Unshare a project from the room.
1758async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1759    let project_id = ProjectId::from_proto(message.project_id);
1760    unshare_project_internal(project_id, session.connection_id, &session).await
1761}
1762
1763async fn unshare_project_internal(
1764    project_id: ProjectId,
1765    connection_id: ConnectionId,
1766    session: &Session,
1767) -> Result<()> {
1768    let delete = {
1769        let room_guard = session
1770            .db()
1771            .await
1772            .unshare_project(project_id, connection_id)
1773            .await?;
1774
1775        let (delete, room, guest_connection_ids) = &*room_guard;
1776
1777        let message = proto::UnshareProject {
1778            project_id: project_id.to_proto(),
1779        };
1780
1781        broadcast(
1782            Some(connection_id),
1783            guest_connection_ids.iter().copied(),
1784            |conn_id| session.peer.send(conn_id, message.clone()),
1785        );
1786        if let Some(room) = room {
1787            room_updated(room, &session.peer);
1788        }
1789
1790        *delete
1791    };
1792
1793    if delete {
1794        let db = session.db().await;
1795        db.delete_project(project_id).await?;
1796    }
1797
1798    Ok(())
1799}
1800
1801/// Join someone elses shared project.
1802async fn join_project(
1803    request: proto::JoinProject,
1804    response: Response<proto::JoinProject>,
1805    session: Session,
1806) -> Result<()> {
1807    let project_id = ProjectId::from_proto(request.project_id);
1808
1809    tracing::info!(%project_id, "join project");
1810
1811    let db = session.db().await;
1812    let (project, replica_id) = &mut *db
1813        .join_project(project_id, session.connection_id, session.user_id())
1814        .await?;
1815    drop(db);
1816    tracing::info!(%project_id, "join remote project");
1817    join_project_internal(response, session, project, replica_id)
1818}
1819
1820trait JoinProjectInternalResponse {
1821    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1822}
1823impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1824    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1825        Response::<proto::JoinProject>::send(self, result)
1826    }
1827}
1828
1829fn join_project_internal(
1830    response: impl JoinProjectInternalResponse,
1831    session: Session,
1832    project: &mut Project,
1833    replica_id: &ReplicaId,
1834) -> Result<()> {
1835    let collaborators = project
1836        .collaborators
1837        .iter()
1838        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1839        .map(|collaborator| collaborator.to_proto())
1840        .collect::<Vec<_>>();
1841    let project_id = project.id;
1842    let guest_user_id = session.user_id();
1843
1844    let worktrees = project
1845        .worktrees
1846        .iter()
1847        .map(|(id, worktree)| proto::WorktreeMetadata {
1848            id: *id,
1849            root_name: worktree.root_name.clone(),
1850            visible: worktree.visible,
1851            abs_path: worktree.abs_path.clone(),
1852        })
1853        .collect::<Vec<_>>();
1854
1855    let add_project_collaborator = proto::AddProjectCollaborator {
1856        project_id: project_id.to_proto(),
1857        collaborator: Some(proto::Collaborator {
1858            peer_id: Some(session.connection_id.into()),
1859            replica_id: replica_id.0 as u32,
1860            user_id: guest_user_id.to_proto(),
1861            is_host: false,
1862        }),
1863    };
1864
1865    for collaborator in &collaborators {
1866        session
1867            .peer
1868            .send(
1869                collaborator.peer_id.unwrap().into(),
1870                add_project_collaborator.clone(),
1871            )
1872            .trace_err();
1873    }
1874
1875    // First, we send the metadata associated with each worktree.
1876    response.send(proto::JoinProjectResponse {
1877        project_id: project.id.0 as u64,
1878        worktrees: worktrees.clone(),
1879        replica_id: replica_id.0 as u32,
1880        collaborators: collaborators.clone(),
1881        language_servers: project.language_servers.clone(),
1882        role: project.role.into(),
1883    })?;
1884
1885    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1886        // Stream this worktree's entries.
1887        let message = proto::UpdateWorktree {
1888            project_id: project_id.to_proto(),
1889            worktree_id,
1890            abs_path: worktree.abs_path.clone(),
1891            root_name: worktree.root_name,
1892            updated_entries: worktree.entries,
1893            removed_entries: Default::default(),
1894            scan_id: worktree.scan_id,
1895            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1896            updated_repositories: worktree.repository_entries.into_values().collect(),
1897            removed_repositories: Default::default(),
1898        };
1899        for update in proto::split_worktree_update(message) {
1900            session.peer.send(session.connection_id, update.clone())?;
1901        }
1902
1903        // Stream this worktree's diagnostics.
1904        for summary in worktree.diagnostic_summaries {
1905            session.peer.send(
1906                session.connection_id,
1907                proto::UpdateDiagnosticSummary {
1908                    project_id: project_id.to_proto(),
1909                    worktree_id: worktree.id,
1910                    summary: Some(summary),
1911                },
1912            )?;
1913        }
1914
1915        for settings_file in worktree.settings_files {
1916            session.peer.send(
1917                session.connection_id,
1918                proto::UpdateWorktreeSettings {
1919                    project_id: project_id.to_proto(),
1920                    worktree_id: worktree.id,
1921                    path: settings_file.path,
1922                    content: Some(settings_file.content),
1923                    kind: Some(settings_file.kind.to_proto() as i32),
1924                },
1925            )?;
1926        }
1927    }
1928
1929    for language_server in &project.language_servers {
1930        session.peer.send(
1931            session.connection_id,
1932            proto::UpdateLanguageServer {
1933                project_id: project_id.to_proto(),
1934                language_server_id: language_server.id,
1935                variant: Some(
1936                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1937                        proto::LspDiskBasedDiagnosticsUpdated {},
1938                    ),
1939                ),
1940            },
1941        )?;
1942    }
1943
1944    Ok(())
1945}
1946
1947/// Leave someone elses shared project.
1948async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1949    let sender_id = session.connection_id;
1950    let project_id = ProjectId::from_proto(request.project_id);
1951    let db = session.db().await;
1952
1953    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1954    tracing::info!(
1955        %project_id,
1956        "leave project"
1957    );
1958
1959    project_left(project, &session);
1960    if let Some(room) = room {
1961        room_updated(room, &session.peer);
1962    }
1963
1964    Ok(())
1965}
1966
1967/// Updates other participants with changes to the project
1968async fn update_project(
1969    request: proto::UpdateProject,
1970    response: Response<proto::UpdateProject>,
1971    session: Session,
1972) -> Result<()> {
1973    let project_id = ProjectId::from_proto(request.project_id);
1974    let (room, guest_connection_ids) = &*session
1975        .db()
1976        .await
1977        .update_project(project_id, session.connection_id, &request.worktrees)
1978        .await?;
1979    broadcast(
1980        Some(session.connection_id),
1981        guest_connection_ids.iter().copied(),
1982        |connection_id| {
1983            session
1984                .peer
1985                .forward_send(session.connection_id, connection_id, request.clone())
1986        },
1987    );
1988    if let Some(room) = room {
1989        room_updated(room, &session.peer);
1990    }
1991    response.send(proto::Ack {})?;
1992
1993    Ok(())
1994}
1995
1996/// Updates other participants with changes to the worktree
1997async fn update_worktree(
1998    request: proto::UpdateWorktree,
1999    response: Response<proto::UpdateWorktree>,
2000    session: Session,
2001) -> Result<()> {
2002    let guest_connection_ids = session
2003        .db()
2004        .await
2005        .update_worktree(&request, session.connection_id)
2006        .await?;
2007
2008    broadcast(
2009        Some(session.connection_id),
2010        guest_connection_ids.iter().copied(),
2011        |connection_id| {
2012            session
2013                .peer
2014                .forward_send(session.connection_id, connection_id, request.clone())
2015        },
2016    );
2017    response.send(proto::Ack {})?;
2018    Ok(())
2019}
2020
2021/// Updates other participants with changes to the diagnostics
2022async fn update_diagnostic_summary(
2023    message: proto::UpdateDiagnosticSummary,
2024    session: Session,
2025) -> Result<()> {
2026    let guest_connection_ids = session
2027        .db()
2028        .await
2029        .update_diagnostic_summary(&message, session.connection_id)
2030        .await?;
2031
2032    broadcast(
2033        Some(session.connection_id),
2034        guest_connection_ids.iter().copied(),
2035        |connection_id| {
2036            session
2037                .peer
2038                .forward_send(session.connection_id, connection_id, message.clone())
2039        },
2040    );
2041
2042    Ok(())
2043}
2044
2045/// Updates other participants with changes to the worktree settings
2046async fn update_worktree_settings(
2047    message: proto::UpdateWorktreeSettings,
2048    session: Session,
2049) -> Result<()> {
2050    let guest_connection_ids = session
2051        .db()
2052        .await
2053        .update_worktree_settings(&message, session.connection_id)
2054        .await?;
2055
2056    broadcast(
2057        Some(session.connection_id),
2058        guest_connection_ids.iter().copied(),
2059        |connection_id| {
2060            session
2061                .peer
2062                .forward_send(session.connection_id, connection_id, message.clone())
2063        },
2064    );
2065
2066    Ok(())
2067}
2068
2069/// Notify other participants that a language server has started.
2070async fn start_language_server(
2071    request: proto::StartLanguageServer,
2072    session: Session,
2073) -> Result<()> {
2074    let guest_connection_ids = session
2075        .db()
2076        .await
2077        .start_language_server(&request, session.connection_id)
2078        .await?;
2079
2080    broadcast(
2081        Some(session.connection_id),
2082        guest_connection_ids.iter().copied(),
2083        |connection_id| {
2084            session
2085                .peer
2086                .forward_send(session.connection_id, connection_id, request.clone())
2087        },
2088    );
2089    Ok(())
2090}
2091
2092/// Notify other participants that a language server has changed.
2093async fn update_language_server(
2094    request: proto::UpdateLanguageServer,
2095    session: Session,
2096) -> Result<()> {
2097    let project_id = ProjectId::from_proto(request.project_id);
2098    let project_connection_ids = session
2099        .db()
2100        .await
2101        .project_connection_ids(project_id, session.connection_id, true)
2102        .await?;
2103    broadcast(
2104        Some(session.connection_id),
2105        project_connection_ids.iter().copied(),
2106        |connection_id| {
2107            session
2108                .peer
2109                .forward_send(session.connection_id, connection_id, request.clone())
2110        },
2111    );
2112    Ok(())
2113}
2114
2115/// forward a project request to the host. These requests should be read only
2116/// as guests are allowed to send them.
2117async fn forward_read_only_project_request<T>(
2118    request: T,
2119    response: Response<T>,
2120    session: Session,
2121) -> Result<()>
2122where
2123    T: EntityMessage + RequestMessage,
2124{
2125    let project_id = ProjectId::from_proto(request.remote_entity_id());
2126    let host_connection_id = session
2127        .db()
2128        .await
2129        .host_for_read_only_project_request(project_id, session.connection_id)
2130        .await?;
2131    let payload = session
2132        .peer
2133        .forward_request(session.connection_id, host_connection_id, request)
2134        .await?;
2135    response.send(payload)?;
2136    Ok(())
2137}
2138
2139async fn forward_find_search_candidates_request(
2140    request: proto::FindSearchCandidates,
2141    response: Response<proto::FindSearchCandidates>,
2142    session: Session,
2143) -> Result<()> {
2144    let project_id = ProjectId::from_proto(request.remote_entity_id());
2145    let host_connection_id = session
2146        .db()
2147        .await
2148        .host_for_read_only_project_request(project_id, session.connection_id)
2149        .await?;
2150    let payload = session
2151        .peer
2152        .forward_request(session.connection_id, host_connection_id, request)
2153        .await?;
2154    response.send(payload)?;
2155    Ok(())
2156}
2157
2158/// forward a project request to the host. These requests are disallowed
2159/// for guests.
2160async fn forward_mutating_project_request<T>(
2161    request: T,
2162    response: Response<T>,
2163    session: Session,
2164) -> Result<()>
2165where
2166    T: EntityMessage + RequestMessage,
2167{
2168    let project_id = ProjectId::from_proto(request.remote_entity_id());
2169
2170    let host_connection_id = session
2171        .db()
2172        .await
2173        .host_for_mutating_project_request(project_id, session.connection_id)
2174        .await?;
2175    let payload = session
2176        .peer
2177        .forward_request(session.connection_id, host_connection_id, request)
2178        .await?;
2179    response.send(payload)?;
2180    Ok(())
2181}
2182
2183/// Notify other participants that a new buffer has been created
2184async fn create_buffer_for_peer(
2185    request: proto::CreateBufferForPeer,
2186    session: Session,
2187) -> Result<()> {
2188    session
2189        .db()
2190        .await
2191        .check_user_is_project_host(
2192            ProjectId::from_proto(request.project_id),
2193            session.connection_id,
2194        )
2195        .await?;
2196    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2197    session
2198        .peer
2199        .forward_send(session.connection_id, peer_id.into(), request)?;
2200    Ok(())
2201}
2202
2203/// Notify other participants that a buffer has been updated. This is
2204/// allowed for guests as long as the update is limited to selections.
2205async fn update_buffer(
2206    request: proto::UpdateBuffer,
2207    response: Response<proto::UpdateBuffer>,
2208    session: Session,
2209) -> Result<()> {
2210    let project_id = ProjectId::from_proto(request.project_id);
2211    let mut capability = Capability::ReadOnly;
2212
2213    for op in request.operations.iter() {
2214        match op.variant {
2215            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2216            Some(_) => capability = Capability::ReadWrite,
2217        }
2218    }
2219
2220    let host = {
2221        let guard = session
2222            .db()
2223            .await
2224            .connections_for_buffer_update(project_id, session.connection_id, capability)
2225            .await?;
2226
2227        let (host, guests) = &*guard;
2228
2229        broadcast(
2230            Some(session.connection_id),
2231            guests.clone(),
2232            |connection_id| {
2233                session
2234                    .peer
2235                    .forward_send(session.connection_id, connection_id, request.clone())
2236            },
2237        );
2238
2239        *host
2240    };
2241
2242    if host != session.connection_id {
2243        session
2244            .peer
2245            .forward_request(session.connection_id, host, request.clone())
2246            .await?;
2247    }
2248
2249    response.send(proto::Ack {})?;
2250    Ok(())
2251}
2252
2253async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2254    let project_id = ProjectId::from_proto(message.project_id);
2255
2256    let operation = message.operation.as_ref().context("invalid operation")?;
2257    let capability = match operation.variant.as_ref() {
2258        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2259            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2260                match buffer_op.variant {
2261                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2262                        Capability::ReadOnly
2263                    }
2264                    _ => Capability::ReadWrite,
2265                }
2266            } else {
2267                Capability::ReadWrite
2268            }
2269        }
2270        Some(_) => Capability::ReadWrite,
2271        None => Capability::ReadOnly,
2272    };
2273
2274    let guard = session
2275        .db()
2276        .await
2277        .connections_for_buffer_update(project_id, session.connection_id, capability)
2278        .await?;
2279
2280    let (host, guests) = &*guard;
2281
2282    broadcast(
2283        Some(session.connection_id),
2284        guests.iter().chain([host]).copied(),
2285        |connection_id| {
2286            session
2287                .peer
2288                .forward_send(session.connection_id, connection_id, message.clone())
2289        },
2290    );
2291
2292    Ok(())
2293}
2294
2295/// Notify other participants that a project has been updated.
2296async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2297    request: T,
2298    session: Session,
2299) -> Result<()> {
2300    let project_id = ProjectId::from_proto(request.remote_entity_id());
2301    let project_connection_ids = session
2302        .db()
2303        .await
2304        .project_connection_ids(project_id, session.connection_id, false)
2305        .await?;
2306
2307    broadcast(
2308        Some(session.connection_id),
2309        project_connection_ids.iter().copied(),
2310        |connection_id| {
2311            session
2312                .peer
2313                .forward_send(session.connection_id, connection_id, request.clone())
2314        },
2315    );
2316    Ok(())
2317}
2318
2319/// Start following another user in a call.
2320async fn follow(
2321    request: proto::Follow,
2322    response: Response<proto::Follow>,
2323    session: Session,
2324) -> Result<()> {
2325    let room_id = RoomId::from_proto(request.room_id);
2326    let project_id = request.project_id.map(ProjectId::from_proto);
2327    let leader_id = request
2328        .leader_id
2329        .ok_or_else(|| anyhow!("invalid leader id"))?
2330        .into();
2331    let follower_id = session.connection_id;
2332
2333    session
2334        .db()
2335        .await
2336        .check_room_participants(room_id, leader_id, session.connection_id)
2337        .await?;
2338
2339    let response_payload = session
2340        .peer
2341        .forward_request(session.connection_id, leader_id, request)
2342        .await?;
2343    response.send(response_payload)?;
2344
2345    if let Some(project_id) = project_id {
2346        let room = session
2347            .db()
2348            .await
2349            .follow(room_id, project_id, leader_id, follower_id)
2350            .await?;
2351        room_updated(&room, &session.peer);
2352    }
2353
2354    Ok(())
2355}
2356
2357/// Stop following another user in a call.
2358async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2359    let room_id = RoomId::from_proto(request.room_id);
2360    let project_id = request.project_id.map(ProjectId::from_proto);
2361    let leader_id = request
2362        .leader_id
2363        .ok_or_else(|| anyhow!("invalid leader id"))?
2364        .into();
2365    let follower_id = session.connection_id;
2366
2367    session
2368        .db()
2369        .await
2370        .check_room_participants(room_id, leader_id, session.connection_id)
2371        .await?;
2372
2373    session
2374        .peer
2375        .forward_send(session.connection_id, leader_id, request)?;
2376
2377    if let Some(project_id) = project_id {
2378        let room = session
2379            .db()
2380            .await
2381            .unfollow(room_id, project_id, leader_id, follower_id)
2382            .await?;
2383        room_updated(&room, &session.peer);
2384    }
2385
2386    Ok(())
2387}
2388
2389/// Notify everyone following you of your current location.
2390async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2391    let room_id = RoomId::from_proto(request.room_id);
2392    let database = session.db.lock().await;
2393
2394    let connection_ids = if let Some(project_id) = request.project_id {
2395        let project_id = ProjectId::from_proto(project_id);
2396        database
2397            .project_connection_ids(project_id, session.connection_id, true)
2398            .await?
2399    } else {
2400        database
2401            .room_connection_ids(room_id, session.connection_id)
2402            .await?
2403    };
2404
2405    // For now, don't send view update messages back to that view's current leader.
2406    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2407        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2408        _ => None,
2409    });
2410
2411    for connection_id in connection_ids.iter().cloned() {
2412        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2413            session
2414                .peer
2415                .forward_send(session.connection_id, connection_id, request.clone())?;
2416        }
2417    }
2418    Ok(())
2419}
2420
2421/// Get public data about users.
2422async fn get_users(
2423    request: proto::GetUsers,
2424    response: Response<proto::GetUsers>,
2425    session: Session,
2426) -> Result<()> {
2427    let user_ids = request
2428        .user_ids
2429        .into_iter()
2430        .map(UserId::from_proto)
2431        .collect();
2432    let users = session
2433        .db()
2434        .await
2435        .get_users_by_ids(user_ids)
2436        .await?
2437        .into_iter()
2438        .map(|user| proto::User {
2439            id: user.id.to_proto(),
2440            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2441            github_login: user.github_login,
2442            email: user.email_address,
2443            name: user.name,
2444        })
2445        .collect();
2446    response.send(proto::UsersResponse { users })?;
2447    Ok(())
2448}
2449
2450/// Search for users (to invite) buy Github login
2451async fn fuzzy_search_users(
2452    request: proto::FuzzySearchUsers,
2453    response: Response<proto::FuzzySearchUsers>,
2454    session: Session,
2455) -> Result<()> {
2456    let query = request.query;
2457    let users = match query.len() {
2458        0 => vec![],
2459        1 | 2 => session
2460            .db()
2461            .await
2462            .get_user_by_github_login(&query)
2463            .await?
2464            .into_iter()
2465            .collect(),
2466        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2467    };
2468    let users = users
2469        .into_iter()
2470        .filter(|user| user.id != session.user_id())
2471        .map(|user| proto::User {
2472            id: user.id.to_proto(),
2473            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2474            github_login: user.github_login,
2475            name: user.name,
2476            email: user.email_address,
2477        })
2478        .collect();
2479    response.send(proto::UsersResponse { users })?;
2480    Ok(())
2481}
2482
2483/// Send a contact request to another user.
2484async fn request_contact(
2485    request: proto::RequestContact,
2486    response: Response<proto::RequestContact>,
2487    session: Session,
2488) -> Result<()> {
2489    let requester_id = session.user_id();
2490    let responder_id = UserId::from_proto(request.responder_id);
2491    if requester_id == responder_id {
2492        return Err(anyhow!("cannot add yourself as a contact"))?;
2493    }
2494
2495    let notifications = session
2496        .db()
2497        .await
2498        .send_contact_request(requester_id, responder_id)
2499        .await?;
2500
2501    // Update outgoing contact requests of requester
2502    let mut update = proto::UpdateContacts::default();
2503    update.outgoing_requests.push(responder_id.to_proto());
2504    for connection_id in session
2505        .connection_pool()
2506        .await
2507        .user_connection_ids(requester_id)
2508    {
2509        session.peer.send(connection_id, update.clone())?;
2510    }
2511
2512    // Update incoming contact requests of responder
2513    let mut update = proto::UpdateContacts::default();
2514    update
2515        .incoming_requests
2516        .push(proto::IncomingContactRequest {
2517            requester_id: requester_id.to_proto(),
2518        });
2519    let connection_pool = session.connection_pool().await;
2520    for connection_id in connection_pool.user_connection_ids(responder_id) {
2521        session.peer.send(connection_id, update.clone())?;
2522    }
2523
2524    send_notifications(&connection_pool, &session.peer, notifications);
2525
2526    response.send(proto::Ack {})?;
2527    Ok(())
2528}
2529
2530/// Accept or decline a contact request
2531async fn respond_to_contact_request(
2532    request: proto::RespondToContactRequest,
2533    response: Response<proto::RespondToContactRequest>,
2534    session: Session,
2535) -> Result<()> {
2536    let responder_id = session.user_id();
2537    let requester_id = UserId::from_proto(request.requester_id);
2538    let db = session.db().await;
2539    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2540        db.dismiss_contact_notification(responder_id, requester_id)
2541            .await?;
2542    } else {
2543        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2544
2545        let notifications = db
2546            .respond_to_contact_request(responder_id, requester_id, accept)
2547            .await?;
2548        let requester_busy = db.is_user_busy(requester_id).await?;
2549        let responder_busy = db.is_user_busy(responder_id).await?;
2550
2551        let pool = session.connection_pool().await;
2552        // Update responder with new contact
2553        let mut update = proto::UpdateContacts::default();
2554        if accept {
2555            update
2556                .contacts
2557                .push(contact_for_user(requester_id, requester_busy, &pool));
2558        }
2559        update
2560            .remove_incoming_requests
2561            .push(requester_id.to_proto());
2562        for connection_id in pool.user_connection_ids(responder_id) {
2563            session.peer.send(connection_id, update.clone())?;
2564        }
2565
2566        // Update requester with new contact
2567        let mut update = proto::UpdateContacts::default();
2568        if accept {
2569            update
2570                .contacts
2571                .push(contact_for_user(responder_id, responder_busy, &pool));
2572        }
2573        update
2574            .remove_outgoing_requests
2575            .push(responder_id.to_proto());
2576
2577        for connection_id in pool.user_connection_ids(requester_id) {
2578            session.peer.send(connection_id, update.clone())?;
2579        }
2580
2581        send_notifications(&pool, &session.peer, notifications);
2582    }
2583
2584    response.send(proto::Ack {})?;
2585    Ok(())
2586}
2587
2588/// Remove a contact.
2589async fn remove_contact(
2590    request: proto::RemoveContact,
2591    response: Response<proto::RemoveContact>,
2592    session: Session,
2593) -> Result<()> {
2594    let requester_id = session.user_id();
2595    let responder_id = UserId::from_proto(request.user_id);
2596    let db = session.db().await;
2597    let (contact_accepted, deleted_notification_id) =
2598        db.remove_contact(requester_id, responder_id).await?;
2599
2600    let pool = session.connection_pool().await;
2601    // Update outgoing contact requests of requester
2602    let mut update = proto::UpdateContacts::default();
2603    if contact_accepted {
2604        update.remove_contacts.push(responder_id.to_proto());
2605    } else {
2606        update
2607            .remove_outgoing_requests
2608            .push(responder_id.to_proto());
2609    }
2610    for connection_id in pool.user_connection_ids(requester_id) {
2611        session.peer.send(connection_id, update.clone())?;
2612    }
2613
2614    // Update incoming contact requests of responder
2615    let mut update = proto::UpdateContacts::default();
2616    if contact_accepted {
2617        update.remove_contacts.push(requester_id.to_proto());
2618    } else {
2619        update
2620            .remove_incoming_requests
2621            .push(requester_id.to_proto());
2622    }
2623    for connection_id in pool.user_connection_ids(responder_id) {
2624        session.peer.send(connection_id, update.clone())?;
2625        if let Some(notification_id) = deleted_notification_id {
2626            session.peer.send(
2627                connection_id,
2628                proto::DeleteNotification {
2629                    notification_id: notification_id.to_proto(),
2630                },
2631            )?;
2632        }
2633    }
2634
2635    response.send(proto::Ack {})?;
2636    Ok(())
2637}
2638
2639fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2640    version.0.minor() < 139
2641}
2642
2643async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2644    let plan = session.current_plan(&session.db().await).await?;
2645
2646    session
2647        .peer
2648        .send(
2649            session.connection_id,
2650            proto::UpdateUserPlan { plan: plan.into() },
2651        )
2652        .trace_err();
2653
2654    Ok(())
2655}
2656
2657async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2658    subscribe_user_to_channels(session.user_id(), &session).await?;
2659    Ok(())
2660}
2661
2662async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2663    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2664    let mut pool = session.connection_pool().await;
2665    for membership in &channels_for_user.channel_memberships {
2666        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2667    }
2668    session.peer.send(
2669        session.connection_id,
2670        build_update_user_channels(&channels_for_user),
2671    )?;
2672    session.peer.send(
2673        session.connection_id,
2674        build_channels_update(channels_for_user),
2675    )?;
2676    Ok(())
2677}
2678
2679/// Creates a new channel.
2680async fn create_channel(
2681    request: proto::CreateChannel,
2682    response: Response<proto::CreateChannel>,
2683    session: Session,
2684) -> Result<()> {
2685    let db = session.db().await;
2686
2687    let parent_id = request.parent_id.map(ChannelId::from_proto);
2688    let (channel, membership) = db
2689        .create_channel(&request.name, parent_id, session.user_id())
2690        .await?;
2691
2692    let root_id = channel.root_id();
2693    let channel = Channel::from_model(channel);
2694
2695    response.send(proto::CreateChannelResponse {
2696        channel: Some(channel.to_proto()),
2697        parent_id: request.parent_id,
2698    })?;
2699
2700    let mut connection_pool = session.connection_pool().await;
2701    if let Some(membership) = membership {
2702        connection_pool.subscribe_to_channel(
2703            membership.user_id,
2704            membership.channel_id,
2705            membership.role,
2706        );
2707        let update = proto::UpdateUserChannels {
2708            channel_memberships: vec![proto::ChannelMembership {
2709                channel_id: membership.channel_id.to_proto(),
2710                role: membership.role.into(),
2711            }],
2712            ..Default::default()
2713        };
2714        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2715            session.peer.send(connection_id, update.clone())?;
2716        }
2717    }
2718
2719    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2720        if !role.can_see_channel(channel.visibility) {
2721            continue;
2722        }
2723
2724        let update = proto::UpdateChannels {
2725            channels: vec![channel.to_proto()],
2726            ..Default::default()
2727        };
2728        session.peer.send(connection_id, update.clone())?;
2729    }
2730
2731    Ok(())
2732}
2733
2734/// Delete a channel
2735async fn delete_channel(
2736    request: proto::DeleteChannel,
2737    response: Response<proto::DeleteChannel>,
2738    session: Session,
2739) -> Result<()> {
2740    let db = session.db().await;
2741
2742    let channel_id = request.channel_id;
2743    let (root_channel, removed_channels) = db
2744        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2745        .await?;
2746    response.send(proto::Ack {})?;
2747
2748    // Notify members of removed channels
2749    let mut update = proto::UpdateChannels::default();
2750    update
2751        .delete_channels
2752        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2753
2754    let connection_pool = session.connection_pool().await;
2755    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2756        session.peer.send(connection_id, update.clone())?;
2757    }
2758
2759    Ok(())
2760}
2761
2762/// Invite someone to join a channel.
2763async fn invite_channel_member(
2764    request: proto::InviteChannelMember,
2765    response: Response<proto::InviteChannelMember>,
2766    session: Session,
2767) -> Result<()> {
2768    let db = session.db().await;
2769    let channel_id = ChannelId::from_proto(request.channel_id);
2770    let invitee_id = UserId::from_proto(request.user_id);
2771    let InviteMemberResult {
2772        channel,
2773        notifications,
2774    } = db
2775        .invite_channel_member(
2776            channel_id,
2777            invitee_id,
2778            session.user_id(),
2779            request.role().into(),
2780        )
2781        .await?;
2782
2783    let update = proto::UpdateChannels {
2784        channel_invitations: vec![channel.to_proto()],
2785        ..Default::default()
2786    };
2787
2788    let connection_pool = session.connection_pool().await;
2789    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2790        session.peer.send(connection_id, update.clone())?;
2791    }
2792
2793    send_notifications(&connection_pool, &session.peer, notifications);
2794
2795    response.send(proto::Ack {})?;
2796    Ok(())
2797}
2798
2799/// remove someone from a channel
2800async fn remove_channel_member(
2801    request: proto::RemoveChannelMember,
2802    response: Response<proto::RemoveChannelMember>,
2803    session: Session,
2804) -> Result<()> {
2805    let db = session.db().await;
2806    let channel_id = ChannelId::from_proto(request.channel_id);
2807    let member_id = UserId::from_proto(request.user_id);
2808
2809    let RemoveChannelMemberResult {
2810        membership_update,
2811        notification_id,
2812    } = db
2813        .remove_channel_member(channel_id, member_id, session.user_id())
2814        .await?;
2815
2816    let mut connection_pool = session.connection_pool().await;
2817    notify_membership_updated(
2818        &mut connection_pool,
2819        membership_update,
2820        member_id,
2821        &session.peer,
2822    );
2823    for connection_id in connection_pool.user_connection_ids(member_id) {
2824        if let Some(notification_id) = notification_id {
2825            session
2826                .peer
2827                .send(
2828                    connection_id,
2829                    proto::DeleteNotification {
2830                        notification_id: notification_id.to_proto(),
2831                    },
2832                )
2833                .trace_err();
2834        }
2835    }
2836
2837    response.send(proto::Ack {})?;
2838    Ok(())
2839}
2840
2841/// Toggle the channel between public and private.
2842/// Care is taken to maintain the invariant that public channels only descend from public channels,
2843/// (though members-only channels can appear at any point in the hierarchy).
2844async fn set_channel_visibility(
2845    request: proto::SetChannelVisibility,
2846    response: Response<proto::SetChannelVisibility>,
2847    session: Session,
2848) -> Result<()> {
2849    let db = session.db().await;
2850    let channel_id = ChannelId::from_proto(request.channel_id);
2851    let visibility = request.visibility().into();
2852
2853    let channel_model = db
2854        .set_channel_visibility(channel_id, visibility, session.user_id())
2855        .await?;
2856    let root_id = channel_model.root_id();
2857    let channel = Channel::from_model(channel_model);
2858
2859    let mut connection_pool = session.connection_pool().await;
2860    for (user_id, role) in connection_pool
2861        .channel_user_ids(root_id)
2862        .collect::<Vec<_>>()
2863        .into_iter()
2864    {
2865        let update = if role.can_see_channel(channel.visibility) {
2866            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2867            proto::UpdateChannels {
2868                channels: vec![channel.to_proto()],
2869                ..Default::default()
2870            }
2871        } else {
2872            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2873            proto::UpdateChannels {
2874                delete_channels: vec![channel.id.to_proto()],
2875                ..Default::default()
2876            }
2877        };
2878
2879        for connection_id in connection_pool.user_connection_ids(user_id) {
2880            session.peer.send(connection_id, update.clone())?;
2881        }
2882    }
2883
2884    response.send(proto::Ack {})?;
2885    Ok(())
2886}
2887
2888/// Alter the role for a user in the channel.
2889async fn set_channel_member_role(
2890    request: proto::SetChannelMemberRole,
2891    response: Response<proto::SetChannelMemberRole>,
2892    session: Session,
2893) -> Result<()> {
2894    let db = session.db().await;
2895    let channel_id = ChannelId::from_proto(request.channel_id);
2896    let member_id = UserId::from_proto(request.user_id);
2897    let result = db
2898        .set_channel_member_role(
2899            channel_id,
2900            session.user_id(),
2901            member_id,
2902            request.role().into(),
2903        )
2904        .await?;
2905
2906    match result {
2907        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2908            let mut connection_pool = session.connection_pool().await;
2909            notify_membership_updated(
2910                &mut connection_pool,
2911                membership_update,
2912                member_id,
2913                &session.peer,
2914            )
2915        }
2916        db::SetMemberRoleResult::InviteUpdated(channel) => {
2917            let update = proto::UpdateChannels {
2918                channel_invitations: vec![channel.to_proto()],
2919                ..Default::default()
2920            };
2921
2922            for connection_id in session
2923                .connection_pool()
2924                .await
2925                .user_connection_ids(member_id)
2926            {
2927                session.peer.send(connection_id, update.clone())?;
2928            }
2929        }
2930    }
2931
2932    response.send(proto::Ack {})?;
2933    Ok(())
2934}
2935
2936/// Change the name of a channel
2937async fn rename_channel(
2938    request: proto::RenameChannel,
2939    response: Response<proto::RenameChannel>,
2940    session: Session,
2941) -> Result<()> {
2942    let db = session.db().await;
2943    let channel_id = ChannelId::from_proto(request.channel_id);
2944    let channel_model = db
2945        .rename_channel(channel_id, session.user_id(), &request.name)
2946        .await?;
2947    let root_id = channel_model.root_id();
2948    let channel = Channel::from_model(channel_model);
2949
2950    response.send(proto::RenameChannelResponse {
2951        channel: Some(channel.to_proto()),
2952    })?;
2953
2954    let connection_pool = session.connection_pool().await;
2955    let update = proto::UpdateChannels {
2956        channels: vec![channel.to_proto()],
2957        ..Default::default()
2958    };
2959    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2960        if role.can_see_channel(channel.visibility) {
2961            session.peer.send(connection_id, update.clone())?;
2962        }
2963    }
2964
2965    Ok(())
2966}
2967
2968/// Move a channel to a new parent.
2969async fn move_channel(
2970    request: proto::MoveChannel,
2971    response: Response<proto::MoveChannel>,
2972    session: Session,
2973) -> Result<()> {
2974    let channel_id = ChannelId::from_proto(request.channel_id);
2975    let to = ChannelId::from_proto(request.to);
2976
2977    let (root_id, channels) = session
2978        .db()
2979        .await
2980        .move_channel(channel_id, to, session.user_id())
2981        .await?;
2982
2983    let connection_pool = session.connection_pool().await;
2984    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2985        let channels = channels
2986            .iter()
2987            .filter_map(|channel| {
2988                if role.can_see_channel(channel.visibility) {
2989                    Some(channel.to_proto())
2990                } else {
2991                    None
2992                }
2993            })
2994            .collect::<Vec<_>>();
2995        if channels.is_empty() {
2996            continue;
2997        }
2998
2999        let update = proto::UpdateChannels {
3000            channels,
3001            ..Default::default()
3002        };
3003
3004        session.peer.send(connection_id, update.clone())?;
3005    }
3006
3007    response.send(Ack {})?;
3008    Ok(())
3009}
3010
3011/// Get the list of channel members
3012async fn get_channel_members(
3013    request: proto::GetChannelMembers,
3014    response: Response<proto::GetChannelMembers>,
3015    session: Session,
3016) -> Result<()> {
3017    let db = session.db().await;
3018    let channel_id = ChannelId::from_proto(request.channel_id);
3019    let limit = if request.limit == 0 {
3020        u16::MAX as u64
3021    } else {
3022        request.limit
3023    };
3024    let (members, users) = db
3025        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3026        .await?;
3027    response.send(proto::GetChannelMembersResponse { members, users })?;
3028    Ok(())
3029}
3030
3031/// Accept or decline a channel invitation.
3032async fn respond_to_channel_invite(
3033    request: proto::RespondToChannelInvite,
3034    response: Response<proto::RespondToChannelInvite>,
3035    session: Session,
3036) -> Result<()> {
3037    let db = session.db().await;
3038    let channel_id = ChannelId::from_proto(request.channel_id);
3039    let RespondToChannelInvite {
3040        membership_update,
3041        notifications,
3042    } = db
3043        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3044        .await?;
3045
3046    let mut connection_pool = session.connection_pool().await;
3047    if let Some(membership_update) = membership_update {
3048        notify_membership_updated(
3049            &mut connection_pool,
3050            membership_update,
3051            session.user_id(),
3052            &session.peer,
3053        );
3054    } else {
3055        let update = proto::UpdateChannels {
3056            remove_channel_invitations: vec![channel_id.to_proto()],
3057            ..Default::default()
3058        };
3059
3060        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3061            session.peer.send(connection_id, update.clone())?;
3062        }
3063    };
3064
3065    send_notifications(&connection_pool, &session.peer, notifications);
3066
3067    response.send(proto::Ack {})?;
3068
3069    Ok(())
3070}
3071
3072/// Join the channels' room
3073async fn join_channel(
3074    request: proto::JoinChannel,
3075    response: Response<proto::JoinChannel>,
3076    session: Session,
3077) -> Result<()> {
3078    let channel_id = ChannelId::from_proto(request.channel_id);
3079    join_channel_internal(channel_id, Box::new(response), session).await
3080}
3081
3082trait JoinChannelInternalResponse {
3083    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3084}
3085impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3086    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3087        Response::<proto::JoinChannel>::send(self, result)
3088    }
3089}
3090impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3091    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3092        Response::<proto::JoinRoom>::send(self, result)
3093    }
3094}
3095
3096async fn join_channel_internal(
3097    channel_id: ChannelId,
3098    response: Box<impl JoinChannelInternalResponse>,
3099    session: Session,
3100) -> Result<()> {
3101    let joined_room = {
3102        let mut db = session.db().await;
3103        // If zed quits without leaving the room, and the user re-opens zed before the
3104        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3105        // room they were in.
3106        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3107            tracing::info!(
3108                stale_connection_id = %connection,
3109                "cleaning up stale connection",
3110            );
3111            drop(db);
3112            leave_room_for_session(&session, connection).await?;
3113            db = session.db().await;
3114        }
3115
3116        let (joined_room, membership_updated, role) = db
3117            .join_channel(channel_id, session.user_id(), session.connection_id)
3118            .await?;
3119
3120        let live_kit_connection_info =
3121            session
3122                .app_state
3123                .livekit_client
3124                .as_ref()
3125                .and_then(|live_kit| {
3126                    let (can_publish, token) = if role == ChannelRole::Guest {
3127                        (
3128                            false,
3129                            live_kit
3130                                .guest_token(
3131                                    &joined_room.room.livekit_room,
3132                                    &session.user_id().to_string(),
3133                                )
3134                                .trace_err()?,
3135                        )
3136                    } else {
3137                        (
3138                            true,
3139                            live_kit
3140                                .room_token(
3141                                    &joined_room.room.livekit_room,
3142                                    &session.user_id().to_string(),
3143                                )
3144                                .trace_err()?,
3145                        )
3146                    };
3147
3148                    Some(LiveKitConnectionInfo {
3149                        server_url: live_kit.url().into(),
3150                        token,
3151                        can_publish,
3152                    })
3153                });
3154
3155        response.send(proto::JoinRoomResponse {
3156            room: Some(joined_room.room.clone()),
3157            channel_id: joined_room
3158                .channel
3159                .as_ref()
3160                .map(|channel| channel.id.to_proto()),
3161            live_kit_connection_info,
3162        })?;
3163
3164        let mut connection_pool = session.connection_pool().await;
3165        if let Some(membership_updated) = membership_updated {
3166            notify_membership_updated(
3167                &mut connection_pool,
3168                membership_updated,
3169                session.user_id(),
3170                &session.peer,
3171            );
3172        }
3173
3174        room_updated(&joined_room.room, &session.peer);
3175
3176        joined_room
3177    };
3178
3179    channel_updated(
3180        &joined_room
3181            .channel
3182            .ok_or_else(|| anyhow!("channel not returned"))?,
3183        &joined_room.room,
3184        &session.peer,
3185        &*session.connection_pool().await,
3186    );
3187
3188    update_user_contacts(session.user_id(), &session).await?;
3189    Ok(())
3190}
3191
3192/// Start editing the channel notes
3193async fn join_channel_buffer(
3194    request: proto::JoinChannelBuffer,
3195    response: Response<proto::JoinChannelBuffer>,
3196    session: Session,
3197) -> Result<()> {
3198    let db = session.db().await;
3199    let channel_id = ChannelId::from_proto(request.channel_id);
3200
3201    let open_response = db
3202        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3203        .await?;
3204
3205    let collaborators = open_response.collaborators.clone();
3206    response.send(open_response)?;
3207
3208    let update = UpdateChannelBufferCollaborators {
3209        channel_id: channel_id.to_proto(),
3210        collaborators: collaborators.clone(),
3211    };
3212    channel_buffer_updated(
3213        session.connection_id,
3214        collaborators
3215            .iter()
3216            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3217        &update,
3218        &session.peer,
3219    );
3220
3221    Ok(())
3222}
3223
3224/// Edit the channel notes
3225async fn update_channel_buffer(
3226    request: proto::UpdateChannelBuffer,
3227    session: Session,
3228) -> Result<()> {
3229    let db = session.db().await;
3230    let channel_id = ChannelId::from_proto(request.channel_id);
3231
3232    let (collaborators, epoch, version) = db
3233        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3234        .await?;
3235
3236    channel_buffer_updated(
3237        session.connection_id,
3238        collaborators.clone(),
3239        &proto::UpdateChannelBuffer {
3240            channel_id: channel_id.to_proto(),
3241            operations: request.operations,
3242        },
3243        &session.peer,
3244    );
3245
3246    let pool = &*session.connection_pool().await;
3247
3248    let non_collaborators =
3249        pool.channel_connection_ids(channel_id)
3250            .filter_map(|(connection_id, _)| {
3251                if collaborators.contains(&connection_id) {
3252                    None
3253                } else {
3254                    Some(connection_id)
3255                }
3256            });
3257
3258    broadcast(None, non_collaborators, |peer_id| {
3259        session.peer.send(
3260            peer_id,
3261            proto::UpdateChannels {
3262                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3263                    channel_id: channel_id.to_proto(),
3264                    epoch: epoch as u64,
3265                    version: version.clone(),
3266                }],
3267                ..Default::default()
3268            },
3269        )
3270    });
3271
3272    Ok(())
3273}
3274
3275/// Rejoin the channel notes after a connection blip
3276async fn rejoin_channel_buffers(
3277    request: proto::RejoinChannelBuffers,
3278    response: Response<proto::RejoinChannelBuffers>,
3279    session: Session,
3280) -> Result<()> {
3281    let db = session.db().await;
3282    let buffers = db
3283        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3284        .await?;
3285
3286    for rejoined_buffer in &buffers {
3287        let collaborators_to_notify = rejoined_buffer
3288            .buffer
3289            .collaborators
3290            .iter()
3291            .filter_map(|c| Some(c.peer_id?.into()));
3292        channel_buffer_updated(
3293            session.connection_id,
3294            collaborators_to_notify,
3295            &proto::UpdateChannelBufferCollaborators {
3296                channel_id: rejoined_buffer.buffer.channel_id,
3297                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3298            },
3299            &session.peer,
3300        );
3301    }
3302
3303    response.send(proto::RejoinChannelBuffersResponse {
3304        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3305    })?;
3306
3307    Ok(())
3308}
3309
3310/// Stop editing the channel notes
3311async fn leave_channel_buffer(
3312    request: proto::LeaveChannelBuffer,
3313    response: Response<proto::LeaveChannelBuffer>,
3314    session: Session,
3315) -> Result<()> {
3316    let db = session.db().await;
3317    let channel_id = ChannelId::from_proto(request.channel_id);
3318
3319    let left_buffer = db
3320        .leave_channel_buffer(channel_id, session.connection_id)
3321        .await?;
3322
3323    response.send(Ack {})?;
3324
3325    channel_buffer_updated(
3326        session.connection_id,
3327        left_buffer.connections,
3328        &proto::UpdateChannelBufferCollaborators {
3329            channel_id: channel_id.to_proto(),
3330            collaborators: left_buffer.collaborators,
3331        },
3332        &session.peer,
3333    );
3334
3335    Ok(())
3336}
3337
3338fn channel_buffer_updated<T: EnvelopedMessage>(
3339    sender_id: ConnectionId,
3340    collaborators: impl IntoIterator<Item = ConnectionId>,
3341    message: &T,
3342    peer: &Peer,
3343) {
3344    broadcast(Some(sender_id), collaborators, |peer_id| {
3345        peer.send(peer_id, message.clone())
3346    });
3347}
3348
3349fn send_notifications(
3350    connection_pool: &ConnectionPool,
3351    peer: &Peer,
3352    notifications: db::NotificationBatch,
3353) {
3354    for (user_id, notification) in notifications {
3355        for connection_id in connection_pool.user_connection_ids(user_id) {
3356            if let Err(error) = peer.send(
3357                connection_id,
3358                proto::AddNotification {
3359                    notification: Some(notification.clone()),
3360                },
3361            ) {
3362                tracing::error!(
3363                    "failed to send notification to {:?} {}",
3364                    connection_id,
3365                    error
3366                );
3367            }
3368        }
3369    }
3370}
3371
3372/// Send a message to the channel
3373async fn send_channel_message(
3374    request: proto::SendChannelMessage,
3375    response: Response<proto::SendChannelMessage>,
3376    session: Session,
3377) -> Result<()> {
3378    // Validate the message body.
3379    let body = request.body.trim().to_string();
3380    if body.len() > MAX_MESSAGE_LEN {
3381        return Err(anyhow!("message is too long"))?;
3382    }
3383    if body.is_empty() {
3384        return Err(anyhow!("message can't be blank"))?;
3385    }
3386
3387    // TODO: adjust mentions if body is trimmed
3388
3389    let timestamp = OffsetDateTime::now_utc();
3390    let nonce = request
3391        .nonce
3392        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3393
3394    let channel_id = ChannelId::from_proto(request.channel_id);
3395    let CreatedChannelMessage {
3396        message_id,
3397        participant_connection_ids,
3398        notifications,
3399    } = session
3400        .db()
3401        .await
3402        .create_channel_message(
3403            channel_id,
3404            session.user_id(),
3405            &body,
3406            &request.mentions,
3407            timestamp,
3408            nonce.clone().into(),
3409            request.reply_to_message_id.map(MessageId::from_proto),
3410        )
3411        .await?;
3412
3413    let message = proto::ChannelMessage {
3414        sender_id: session.user_id().to_proto(),
3415        id: message_id.to_proto(),
3416        body,
3417        mentions: request.mentions,
3418        timestamp: timestamp.unix_timestamp() as u64,
3419        nonce: Some(nonce),
3420        reply_to_message_id: request.reply_to_message_id,
3421        edited_at: None,
3422    };
3423    broadcast(
3424        Some(session.connection_id),
3425        participant_connection_ids.clone(),
3426        |connection| {
3427            session.peer.send(
3428                connection,
3429                proto::ChannelMessageSent {
3430                    channel_id: channel_id.to_proto(),
3431                    message: Some(message.clone()),
3432                },
3433            )
3434        },
3435    );
3436    response.send(proto::SendChannelMessageResponse {
3437        message: Some(message),
3438    })?;
3439
3440    let pool = &*session.connection_pool().await;
3441    let non_participants =
3442        pool.channel_connection_ids(channel_id)
3443            .filter_map(|(connection_id, _)| {
3444                if participant_connection_ids.contains(&connection_id) {
3445                    None
3446                } else {
3447                    Some(connection_id)
3448                }
3449            });
3450    broadcast(None, non_participants, |peer_id| {
3451        session.peer.send(
3452            peer_id,
3453            proto::UpdateChannels {
3454                latest_channel_message_ids: vec![proto::ChannelMessageId {
3455                    channel_id: channel_id.to_proto(),
3456                    message_id: message_id.to_proto(),
3457                }],
3458                ..Default::default()
3459            },
3460        )
3461    });
3462    send_notifications(pool, &session.peer, notifications);
3463
3464    Ok(())
3465}
3466
3467/// Delete a channel message
3468async fn remove_channel_message(
3469    request: proto::RemoveChannelMessage,
3470    response: Response<proto::RemoveChannelMessage>,
3471    session: Session,
3472) -> Result<()> {
3473    let channel_id = ChannelId::from_proto(request.channel_id);
3474    let message_id = MessageId::from_proto(request.message_id);
3475    let (connection_ids, existing_notification_ids) = session
3476        .db()
3477        .await
3478        .remove_channel_message(channel_id, message_id, session.user_id())
3479        .await?;
3480
3481    broadcast(
3482        Some(session.connection_id),
3483        connection_ids,
3484        move |connection| {
3485            session.peer.send(connection, request.clone())?;
3486
3487            for notification_id in &existing_notification_ids {
3488                session.peer.send(
3489                    connection,
3490                    proto::DeleteNotification {
3491                        notification_id: (*notification_id).to_proto(),
3492                    },
3493                )?;
3494            }
3495
3496            Ok(())
3497        },
3498    );
3499    response.send(proto::Ack {})?;
3500    Ok(())
3501}
3502
3503async fn update_channel_message(
3504    request: proto::UpdateChannelMessage,
3505    response: Response<proto::UpdateChannelMessage>,
3506    session: Session,
3507) -> Result<()> {
3508    let channel_id = ChannelId::from_proto(request.channel_id);
3509    let message_id = MessageId::from_proto(request.message_id);
3510    let updated_at = OffsetDateTime::now_utc();
3511    let UpdatedChannelMessage {
3512        message_id,
3513        participant_connection_ids,
3514        notifications,
3515        reply_to_message_id,
3516        timestamp,
3517        deleted_mention_notification_ids,
3518        updated_mention_notifications,
3519    } = session
3520        .db()
3521        .await
3522        .update_channel_message(
3523            channel_id,
3524            message_id,
3525            session.user_id(),
3526            request.body.as_str(),
3527            &request.mentions,
3528            updated_at,
3529        )
3530        .await?;
3531
3532    let nonce = request
3533        .nonce
3534        .clone()
3535        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3536
3537    let message = proto::ChannelMessage {
3538        sender_id: session.user_id().to_proto(),
3539        id: message_id.to_proto(),
3540        body: request.body.clone(),
3541        mentions: request.mentions.clone(),
3542        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3543        nonce: Some(nonce),
3544        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3545        edited_at: Some(updated_at.unix_timestamp() as u64),
3546    };
3547
3548    response.send(proto::Ack {})?;
3549
3550    let pool = &*session.connection_pool().await;
3551    broadcast(
3552        Some(session.connection_id),
3553        participant_connection_ids,
3554        |connection| {
3555            session.peer.send(
3556                connection,
3557                proto::ChannelMessageUpdate {
3558                    channel_id: channel_id.to_proto(),
3559                    message: Some(message.clone()),
3560                },
3561            )?;
3562
3563            for notification_id in &deleted_mention_notification_ids {
3564                session.peer.send(
3565                    connection,
3566                    proto::DeleteNotification {
3567                        notification_id: (*notification_id).to_proto(),
3568                    },
3569                )?;
3570            }
3571
3572            for notification in &updated_mention_notifications {
3573                session.peer.send(
3574                    connection,
3575                    proto::UpdateNotification {
3576                        notification: Some(notification.clone()),
3577                    },
3578                )?;
3579            }
3580
3581            Ok(())
3582        },
3583    );
3584
3585    send_notifications(pool, &session.peer, notifications);
3586
3587    Ok(())
3588}
3589
3590/// Mark a channel message as read
3591async fn acknowledge_channel_message(
3592    request: proto::AckChannelMessage,
3593    session: Session,
3594) -> Result<()> {
3595    let channel_id = ChannelId::from_proto(request.channel_id);
3596    let message_id = MessageId::from_proto(request.message_id);
3597    let notifications = session
3598        .db()
3599        .await
3600        .observe_channel_message(channel_id, session.user_id(), message_id)
3601        .await?;
3602    send_notifications(
3603        &*session.connection_pool().await,
3604        &session.peer,
3605        notifications,
3606    );
3607    Ok(())
3608}
3609
3610/// Mark a buffer version as synced
3611async fn acknowledge_buffer_version(
3612    request: proto::AckBufferOperation,
3613    session: Session,
3614) -> Result<()> {
3615    let buffer_id = BufferId::from_proto(request.buffer_id);
3616    session
3617        .db()
3618        .await
3619        .observe_buffer_version(
3620            buffer_id,
3621            session.user_id(),
3622            request.epoch as i32,
3623            &request.version,
3624        )
3625        .await?;
3626    Ok(())
3627}
3628
3629async fn count_language_model_tokens(
3630    request: proto::CountLanguageModelTokens,
3631    response: Response<proto::CountLanguageModelTokens>,
3632    session: Session,
3633    config: &Config,
3634) -> Result<()> {
3635    authorize_access_to_legacy_llm_endpoints(&session).await?;
3636
3637    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3638        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3639        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3640    };
3641
3642    session
3643        .app_state
3644        .rate_limiter
3645        .check(&*rate_limit, session.user_id())
3646        .await?;
3647
3648    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3649        Some(proto::LanguageModelProvider::Google) => {
3650            let api_key = config
3651                .google_ai_api_key
3652                .as_ref()
3653                .context("no Google AI API key configured on the server")?;
3654            google_ai::count_tokens(
3655                session.http_client.as_ref(),
3656                google_ai::API_URL,
3657                api_key,
3658                serde_json::from_str(&request.request)?,
3659            )
3660            .await?
3661        }
3662        _ => return Err(anyhow!("unsupported provider"))?,
3663    };
3664
3665    response.send(proto::CountLanguageModelTokensResponse {
3666        token_count: result.total_tokens as u32,
3667    })?;
3668
3669    Ok(())
3670}
3671
3672struct ZedProCountLanguageModelTokensRateLimit;
3673
3674impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3675    fn capacity(&self) -> usize {
3676        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3677            .ok()
3678            .and_then(|v| v.parse().ok())
3679            .unwrap_or(600) // Picked arbitrarily
3680    }
3681
3682    fn refill_duration(&self) -> chrono::Duration {
3683        chrono::Duration::hours(1)
3684    }
3685
3686    fn db_name(&self) -> &'static str {
3687        "zed-pro:count-language-model-tokens"
3688    }
3689}
3690
3691struct FreeCountLanguageModelTokensRateLimit;
3692
3693impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3694    fn capacity(&self) -> usize {
3695        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3696            .ok()
3697            .and_then(|v| v.parse().ok())
3698            .unwrap_or(600 / 10) // Picked arbitrarily
3699    }
3700
3701    fn refill_duration(&self) -> chrono::Duration {
3702        chrono::Duration::hours(1)
3703    }
3704
3705    fn db_name(&self) -> &'static str {
3706        "free:count-language-model-tokens"
3707    }
3708}
3709
3710struct ZedProComputeEmbeddingsRateLimit;
3711
3712impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3713    fn capacity(&self) -> usize {
3714        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3715            .ok()
3716            .and_then(|v| v.parse().ok())
3717            .unwrap_or(5000) // Picked arbitrarily
3718    }
3719
3720    fn refill_duration(&self) -> chrono::Duration {
3721        chrono::Duration::hours(1)
3722    }
3723
3724    fn db_name(&self) -> &'static str {
3725        "zed-pro:compute-embeddings"
3726    }
3727}
3728
3729struct FreeComputeEmbeddingsRateLimit;
3730
3731impl RateLimit for FreeComputeEmbeddingsRateLimit {
3732    fn capacity(&self) -> usize {
3733        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3734            .ok()
3735            .and_then(|v| v.parse().ok())
3736            .unwrap_or(5000 / 10) // Picked arbitrarily
3737    }
3738
3739    fn refill_duration(&self) -> chrono::Duration {
3740        chrono::Duration::hours(1)
3741    }
3742
3743    fn db_name(&self) -> &'static str {
3744        "free:compute-embeddings"
3745    }
3746}
3747
3748async fn compute_embeddings(
3749    request: proto::ComputeEmbeddings,
3750    response: Response<proto::ComputeEmbeddings>,
3751    session: Session,
3752    api_key: Option<Arc<str>>,
3753) -> Result<()> {
3754    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3755    authorize_access_to_legacy_llm_endpoints(&session).await?;
3756
3757    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3758        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3759        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3760    };
3761
3762    session
3763        .app_state
3764        .rate_limiter
3765        .check(&*rate_limit, session.user_id())
3766        .await?;
3767
3768    let embeddings = match request.model.as_str() {
3769        "openai/text-embedding-3-small" => {
3770            open_ai::embed(
3771                session.http_client.as_ref(),
3772                OPEN_AI_API_URL,
3773                &api_key,
3774                OpenAiEmbeddingModel::TextEmbedding3Small,
3775                request.texts.iter().map(|text| text.as_str()),
3776            )
3777            .await?
3778        }
3779        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3780    };
3781
3782    let embeddings = request
3783        .texts
3784        .iter()
3785        .map(|text| {
3786            let mut hasher = sha2::Sha256::new();
3787            hasher.update(text.as_bytes());
3788            let result = hasher.finalize();
3789            result.to_vec()
3790        })
3791        .zip(
3792            embeddings
3793                .data
3794                .into_iter()
3795                .map(|embedding| embedding.embedding),
3796        )
3797        .collect::<HashMap<_, _>>();
3798
3799    let db = session.db().await;
3800    db.save_embeddings(&request.model, &embeddings)
3801        .await
3802        .context("failed to save embeddings")
3803        .trace_err();
3804
3805    response.send(proto::ComputeEmbeddingsResponse {
3806        embeddings: embeddings
3807            .into_iter()
3808            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3809            .collect(),
3810    })?;
3811    Ok(())
3812}
3813
3814async fn get_cached_embeddings(
3815    request: proto::GetCachedEmbeddings,
3816    response: Response<proto::GetCachedEmbeddings>,
3817    session: Session,
3818) -> Result<()> {
3819    authorize_access_to_legacy_llm_endpoints(&session).await?;
3820
3821    let db = session.db().await;
3822    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3823
3824    response.send(proto::GetCachedEmbeddingsResponse {
3825        embeddings: embeddings
3826            .into_iter()
3827            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3828            .collect(),
3829    })?;
3830    Ok(())
3831}
3832
3833/// This is leftover from before the LLM service.
3834///
3835/// The endpoints protected by this check will be moved there eventually.
3836async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3837    if session.is_staff() {
3838        Ok(())
3839    } else {
3840        Err(anyhow!("permission denied"))?
3841    }
3842}
3843
3844/// Get a Supermaven API key for the user
3845async fn get_supermaven_api_key(
3846    _request: proto::GetSupermavenApiKey,
3847    response: Response<proto::GetSupermavenApiKey>,
3848    session: Session,
3849) -> Result<()> {
3850    let user_id: String = session.user_id().to_string();
3851    if !session.is_staff() {
3852        return Err(anyhow!("supermaven not enabled for this account"))?;
3853    }
3854
3855    let email = session
3856        .email()
3857        .ok_or_else(|| anyhow!("user must have an email"))?;
3858
3859    let supermaven_admin_api = session
3860        .supermaven_client
3861        .as_ref()
3862        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3863
3864    let result = supermaven_admin_api
3865        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3866        .await?;
3867
3868    response.send(proto::GetSupermavenApiKeyResponse {
3869        api_key: result.api_key,
3870    })?;
3871
3872    Ok(())
3873}
3874
3875/// Start receiving chat updates for a channel
3876async fn join_channel_chat(
3877    request: proto::JoinChannelChat,
3878    response: Response<proto::JoinChannelChat>,
3879    session: Session,
3880) -> Result<()> {
3881    let channel_id = ChannelId::from_proto(request.channel_id);
3882
3883    let db = session.db().await;
3884    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3885        .await?;
3886    let messages = db
3887        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3888        .await?;
3889    response.send(proto::JoinChannelChatResponse {
3890        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891        messages,
3892    })?;
3893    Ok(())
3894}
3895
3896/// Stop receiving chat updates for a channel
3897async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3898    let channel_id = ChannelId::from_proto(request.channel_id);
3899    session
3900        .db()
3901        .await
3902        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3903        .await?;
3904    Ok(())
3905}
3906
3907/// Retrieve the chat history for a channel
3908async fn get_channel_messages(
3909    request: proto::GetChannelMessages,
3910    response: Response<proto::GetChannelMessages>,
3911    session: Session,
3912) -> Result<()> {
3913    let channel_id = ChannelId::from_proto(request.channel_id);
3914    let messages = session
3915        .db()
3916        .await
3917        .get_channel_messages(
3918            channel_id,
3919            session.user_id(),
3920            MESSAGE_COUNT_PER_PAGE,
3921            Some(MessageId::from_proto(request.before_message_id)),
3922        )
3923        .await?;
3924    response.send(proto::GetChannelMessagesResponse {
3925        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3926        messages,
3927    })?;
3928    Ok(())
3929}
3930
3931/// Retrieve specific chat messages
3932async fn get_channel_messages_by_id(
3933    request: proto::GetChannelMessagesById,
3934    response: Response<proto::GetChannelMessagesById>,
3935    session: Session,
3936) -> Result<()> {
3937    let message_ids = request
3938        .message_ids
3939        .iter()
3940        .map(|id| MessageId::from_proto(*id))
3941        .collect::<Vec<_>>();
3942    let messages = session
3943        .db()
3944        .await
3945        .get_channel_messages_by_id(session.user_id(), &message_ids)
3946        .await?;
3947    response.send(proto::GetChannelMessagesResponse {
3948        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3949        messages,
3950    })?;
3951    Ok(())
3952}
3953
3954/// Retrieve the current users notifications
3955async fn get_notifications(
3956    request: proto::GetNotifications,
3957    response: Response<proto::GetNotifications>,
3958    session: Session,
3959) -> Result<()> {
3960    let notifications = session
3961        .db()
3962        .await
3963        .get_notifications(
3964            session.user_id(),
3965            NOTIFICATION_COUNT_PER_PAGE,
3966            request.before_id.map(db::NotificationId::from_proto),
3967        )
3968        .await?;
3969    response.send(proto::GetNotificationsResponse {
3970        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3971        notifications,
3972    })?;
3973    Ok(())
3974}
3975
3976/// Mark notifications as read
3977async fn mark_notification_as_read(
3978    request: proto::MarkNotificationRead,
3979    response: Response<proto::MarkNotificationRead>,
3980    session: Session,
3981) -> Result<()> {
3982    let database = &session.db().await;
3983    let notifications = database
3984        .mark_notification_as_read_by_id(
3985            session.user_id(),
3986            NotificationId::from_proto(request.notification_id),
3987        )
3988        .await?;
3989    send_notifications(
3990        &*session.connection_pool().await,
3991        &session.peer,
3992        notifications,
3993    );
3994    response.send(proto::Ack {})?;
3995    Ok(())
3996}
3997
3998/// Get the current users information
3999async fn get_private_user_info(
4000    _request: proto::GetPrivateUserInfo,
4001    response: Response<proto::GetPrivateUserInfo>,
4002    session: Session,
4003) -> Result<()> {
4004    let db = session.db().await;
4005
4006    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4007    let user = db
4008        .get_user_by_id(session.user_id())
4009        .await?
4010        .ok_or_else(|| anyhow!("user not found"))?;
4011    let flags = db.get_user_flags(session.user_id()).await?;
4012
4013    response.send(proto::GetPrivateUserInfoResponse {
4014        metrics_id,
4015        staff: user.admin,
4016        flags,
4017        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4018    })?;
4019    Ok(())
4020}
4021
4022/// Accept the terms of service (tos) on behalf of the current user
4023async fn accept_terms_of_service(
4024    _request: proto::AcceptTermsOfService,
4025    response: Response<proto::AcceptTermsOfService>,
4026    session: Session,
4027) -> Result<()> {
4028    let db = session.db().await;
4029
4030    let accepted_tos_at = Utc::now();
4031    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4032        .await?;
4033
4034    response.send(proto::AcceptTermsOfServiceResponse {
4035        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4036    })?;
4037    Ok(())
4038}
4039
/// The minimum age (30 days) an account must have reached in order to use the
/// LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4042
4043async fn get_llm_api_token(
4044    _request: proto::GetLlmToken,
4045    response: Response<proto::GetLlmToken>,
4046    session: Session,
4047) -> Result<()> {
4048    let db = session.db().await;
4049
4050    let flags = db.get_user_flags(session.user_id()).await?;
4051    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4052
4053    if !session.is_staff() && !has_language_models_feature_flag {
4054        Err(anyhow!("permission denied"))?
4055    }
4056
4057    let user_id = session.user_id();
4058    let user = db
4059        .get_user_by_id(user_id)
4060        .await?
4061        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4062
4063    if user.accepted_tos_at.is_none() {
4064        Err(anyhow!("terms of service not accepted"))?
4065    }
4066
4067    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4068    let billing_preferences = db.get_billing_preferences(user.id).await?;
4069
4070    let token = LlmTokenClaims::create(
4071        &user,
4072        session.is_staff(),
4073        billing_preferences,
4074        &flags,
4075        has_llm_subscription,
4076        session.current_plan(&db).await?,
4077        session.system_id.clone(),
4078        &session.app_state.config,
4079    )?;
4080    response.send(proto::GetLlmTokenResponse { token })?;
4081    Ok(())
4082}
4083
4084fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4085    let message = match message {
4086        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4087        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4088        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4089        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4090        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4091            code: frame.code.into(),
4092            reason: frame.reason,
4093        })),
4094        // We should never receive a frame while reading the message, according
4095        // to the `tungstenite` maintainers:
4096        //
4097        // > It cannot occur when you read messages from the WebSocket, but it
4098        // > can be used when you want to send the raw frames (e.g. you want to
4099        // > send the frames to the WebSocket without composing the full message first).
4100        // >
4101        // > — https://github.com/snapview/tungstenite-rs/issues/268
4102        TungsteniteMessage::Frame(_) => {
4103            bail!("received an unexpected frame while reading the message")
4104        }
4105    };
4106
4107    Ok(message)
4108}
4109
4110fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4111    match message {
4112        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4113        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4114        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4115        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4116        AxumMessage::Close(frame) => {
4117            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4118                code: frame.code.into(),
4119                reason: frame.reason,
4120            }))
4121        }
4122    }
4123}
4124
4125fn notify_membership_updated(
4126    connection_pool: &mut ConnectionPool,
4127    result: MembershipUpdated,
4128    user_id: UserId,
4129    peer: &Peer,
4130) {
4131    for membership in &result.new_channels.channel_memberships {
4132        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4133    }
4134    for channel_id in &result.removed_channels {
4135        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4136    }
4137
4138    let user_channels_update = proto::UpdateUserChannels {
4139        channel_memberships: result
4140            .new_channels
4141            .channel_memberships
4142            .iter()
4143            .map(|cm| proto::ChannelMembership {
4144                channel_id: cm.channel_id.to_proto(),
4145                role: cm.role.into(),
4146            })
4147            .collect(),
4148        ..Default::default()
4149    };
4150
4151    let mut update = build_channels_update(result.new_channels);
4152    update.delete_channels = result
4153        .removed_channels
4154        .into_iter()
4155        .map(|id| id.to_proto())
4156        .collect();
4157    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4158
4159    for connection_id in connection_pool.user_connection_ids(user_id) {
4160        peer.send(connection_id, user_channels_update.clone())
4161            .trace_err();
4162        peer.send(connection_id, update.clone()).trace_err();
4163    }
4164}
4165
4166fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4167    proto::UpdateUserChannels {
4168        channel_memberships: channels
4169            .channel_memberships
4170            .iter()
4171            .map(|m| proto::ChannelMembership {
4172                channel_id: m.channel_id.to_proto(),
4173                role: m.role.into(),
4174            })
4175            .collect(),
4176        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4177        observed_channel_message_id: channels.observed_channel_messages.clone(),
4178    }
4179}
4180
4181fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4182    let mut update = proto::UpdateChannels::default();
4183
4184    for channel in channels.channels {
4185        update.channels.push(channel.to_proto());
4186    }
4187
4188    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4189    update.latest_channel_message_ids = channels.latest_channel_messages;
4190
4191    for (channel_id, participants) in channels.channel_participants {
4192        update
4193            .channel_participants
4194            .push(proto::ChannelParticipants {
4195                channel_id: channel_id.to_proto(),
4196                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4197            });
4198    }
4199
4200    for channel in channels.invited_channels {
4201        update.channel_invitations.push(channel.to_proto());
4202    }
4203
4204    update
4205}
4206
4207fn build_initial_contacts_update(
4208    contacts: Vec<db::Contact>,
4209    pool: &ConnectionPool,
4210) -> proto::UpdateContacts {
4211    let mut update = proto::UpdateContacts::default();
4212
4213    for contact in contacts {
4214        match contact {
4215            db::Contact::Accepted { user_id, busy } => {
4216                update.contacts.push(contact_for_user(user_id, busy, pool));
4217            }
4218            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4219            db::Contact::Incoming { user_id } => {
4220                update
4221                    .incoming_requests
4222                    .push(proto::IncomingContactRequest {
4223                        requester_id: user_id.to_proto(),
4224                    })
4225            }
4226        }
4227    }
4228
4229    update
4230}
4231
4232fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4233    proto::Contact {
4234        user_id: user_id.to_proto(),
4235        online: pool.is_user_online(user_id),
4236        busy,
4237    }
4238}
4239
4240fn room_updated(room: &proto::Room, peer: &Peer) {
4241    broadcast(
4242        None,
4243        room.participants
4244            .iter()
4245            .filter_map(|participant| Some(participant.peer_id?.into())),
4246        |peer_id| {
4247            peer.send(
4248                peer_id,
4249                proto::RoomUpdated {
4250                    room: Some(room.clone()),
4251                },
4252            )
4253        },
4254    );
4255}
4256
4257fn channel_updated(
4258    channel: &db::channel::Model,
4259    room: &proto::Room,
4260    peer: &Peer,
4261    pool: &ConnectionPool,
4262) {
4263    let participants = room
4264        .participants
4265        .iter()
4266        .map(|p| p.user_id)
4267        .collect::<Vec<_>>();
4268
4269    broadcast(
4270        None,
4271        pool.channel_connection_ids(channel.root_id())
4272            .filter_map(|(channel_id, role)| {
4273                role.can_see_channel(channel.visibility)
4274                    .then_some(channel_id)
4275            }),
4276        |peer_id| {
4277            peer.send(
4278                peer_id,
4279                proto::UpdateChannels {
4280                    channel_participants: vec![proto::ChannelParticipants {
4281                        channel_id: channel.id.to_proto(),
4282                        participant_user_ids: participants.clone(),
4283                    }],
4284                    ..Default::default()
4285                },
4286            )
4287        },
4288    );
4289}
4290
4291async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4292    let db = session.db().await;
4293
4294    let contacts = db.get_contacts(user_id).await?;
4295    let busy = db.is_user_busy(user_id).await?;
4296
4297    let pool = session.connection_pool().await;
4298    let updated_contact = contact_for_user(user_id, busy, &pool);
4299    for contact in contacts {
4300        if let db::Contact::Accepted {
4301            user_id: contact_user_id,
4302            ..
4303        } = contact
4304        {
4305            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4306                session
4307                    .peer
4308                    .send(
4309                        contact_conn_id,
4310                        proto::UpdateContacts {
4311                            contacts: vec![updated_contact.clone()],
4312                            remove_contacts: Default::default(),
4313                            incoming_requests: Default::default(),
4314                            remove_incoming_requests: Default::default(),
4315                            outgoing_requests: Default::default(),
4316                            remove_outgoing_requests: Default::default(),
4317                        },
4318                    )
4319                    .trace_err();
4320            }
4321        }
4322    }
4323    Ok(())
4324}
4325
4326async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4327    let mut contacts_to_update = HashSet::default();
4328
4329    let room_id;
4330    let canceled_calls_to_user_ids;
4331    let livekit_room;
4332    let delete_livekit_room;
4333    let room;
4334    let channel;
4335
4336    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4337        contacts_to_update.insert(session.user_id());
4338
4339        for project in left_room.left_projects.values() {
4340            project_left(project, session);
4341        }
4342
4343        room_id = RoomId::from_proto(left_room.room.id);
4344        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4345        livekit_room = mem::take(&mut left_room.room.livekit_room);
4346        delete_livekit_room = left_room.deleted;
4347        room = mem::take(&mut left_room.room);
4348        channel = mem::take(&mut left_room.channel);
4349
4350        room_updated(&room, &session.peer);
4351    } else {
4352        return Ok(());
4353    }
4354
4355    if let Some(channel) = channel {
4356        channel_updated(
4357            &channel,
4358            &room,
4359            &session.peer,
4360            &*session.connection_pool().await,
4361        );
4362    }
4363
4364    {
4365        let pool = session.connection_pool().await;
4366        for canceled_user_id in canceled_calls_to_user_ids {
4367            for connection_id in pool.user_connection_ids(canceled_user_id) {
4368                session
4369                    .peer
4370                    .send(
4371                        connection_id,
4372                        proto::CallCanceled {
4373                            room_id: room_id.to_proto(),
4374                        },
4375                    )
4376                    .trace_err();
4377            }
4378            contacts_to_update.insert(canceled_user_id);
4379        }
4380    }
4381
4382    for contact_user_id in contacts_to_update {
4383        update_user_contacts(contact_user_id, session).await?;
4384    }
4385
4386    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4387        live_kit
4388            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4389            .await
4390            .trace_err();
4391
4392        if delete_livekit_room {
4393            live_kit.delete_room(livekit_room).await.trace_err();
4394        }
4395    }
4396
4397    Ok(())
4398}
4399
4400async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4401    let left_channel_buffers = session
4402        .db()
4403        .await
4404        .leave_channel_buffers(session.connection_id)
4405        .await?;
4406
4407    for left_buffer in left_channel_buffers {
4408        channel_buffer_updated(
4409            session.connection_id,
4410            left_buffer.connections,
4411            &proto::UpdateChannelBufferCollaborators {
4412                channel_id: left_buffer.channel_id.to_proto(),
4413                collaborators: left_buffer.collaborators,
4414            },
4415            &session.peer,
4416        );
4417    }
4418
4419    Ok(())
4420}
4421
4422fn project_left(project: &db::LeftProject, session: &Session) {
4423    for connection_id in &project.connection_ids {
4424        if project.should_unshare {
4425            session
4426                .peer
4427                .send(
4428                    *connection_id,
4429                    proto::UnshareProject {
4430                        project_id: project.id.to_proto(),
4431                    },
4432                )
4433                .trace_err();
4434        } else {
4435            session
4436                .peer
4437                .send(
4438                    *connection_id,
4439                    proto::RemoveProjectCollaborator {
4440                        project_id: project.id.to_proto(),
4441                        peer_id: Some(session.connection_id.into()),
4442                    },
4443                )
4444                .trace_err();
4445        }
4446    }
4447}
4448
/// Extension trait that converts a fallible value into an `Option`.
///
/// The implementation for `Result` in this file logs the error and yields
/// `None`, which lets best-effort call sites discard a failure without
/// propagating it.
pub trait ResultExt {
    /// The success type produced when no error occurred.
    type Ok;

    /// Consumes `self`, returning the success value if there is one.
    fn trace_err(self) -> Option<Self::Ok>;
}
4454
4455impl<T, E> ResultExt for Result<T, E>
4456where
4457    E: std::fmt::Debug,
4458{
4459    type Ok = T;
4460
4461    #[track_caller]
4462    fn trace_err(self) -> Option<T> {
4463        match self {
4464            Ok(value) => Some(value),
4465            Err(error) => {
4466                tracing::error!("{:?}", error);
4467                None
4468            }
4469        }
4470    }
4471}