rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed to every message handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the [`Server`]; lock it via [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read by any code yet, hence the `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server: owns the peer, the pool of live
/// connections and the per-message-type handler registry.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can install a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by `TypeId` of the message type; filled in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; `handle_connection` subscribes and refuses to serve a
    /// connection while it reads `true`.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard over the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, which makes this guard `!Send` too, so it cannot be
    /// held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable snapshot of server state: the peer plus the connection pool.
/// The pool is serialized through its lock guard via [`serialize_deref`].
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 332            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 333            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 342            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 345            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 347            .add_message_handler(create_buffer_for_peer)
 348            .add_request_handler(update_buffer)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 354            .add_request_handler(get_users)
 355            .add_request_handler(fuzzy_search_users)
 356            .add_request_handler(request_contact)
 357            .add_request_handler(remove_contact)
 358            .add_request_handler(respond_to_contact_request)
 359            .add_message_handler(subscribe_to_channels)
 360            .add_request_handler(create_channel)
 361            .add_request_handler(delete_channel)
 362            .add_request_handler(invite_channel_member)
 363            .add_request_handler(remove_channel_member)
 364            .add_request_handler(set_channel_member_role)
 365            .add_request_handler(set_channel_visibility)
 366            .add_request_handler(rename_channel)
 367            .add_request_handler(join_channel_buffer)
 368            .add_request_handler(leave_channel_buffer)
 369            .add_message_handler(update_channel_buffer)
 370            .add_request_handler(rejoin_channel_buffers)
 371            .add_request_handler(get_channel_members)
 372            .add_request_handler(respond_to_channel_invite)
 373            .add_request_handler(join_channel)
 374            .add_request_handler(join_channel_chat)
 375            .add_message_handler(leave_channel_chat)
 376            .add_request_handler(send_channel_message)
 377            .add_request_handler(remove_channel_message)
 378            .add_request_handler(update_channel_message)
 379            .add_request_handler(get_channel_messages)
 380            .add_request_handler(get_channel_messages_by_id)
 381            .add_request_handler(get_notifications)
 382            .add_request_handler(mark_notification_as_read)
 383            .add_request_handler(move_channel)
 384            .add_request_handler(follow)
 385            .add_message_handler(unfollow)
 386            .add_message_handler(update_followers)
 387            .add_request_handler(get_private_user_info)
 388            .add_request_handler(get_llm_api_token)
 389            .add_request_handler(accept_terms_of_service)
 390            .add_message_handler(acknowledge_channel_message)
 391            .add_message_handler(acknowledge_buffer_version)
 392            .add_request_handler(get_supermaven_api_key)
 393            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 395            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 400            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 401            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 403            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 404            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 405            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 406            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 407            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 410            .add_message_handler(update_context)
 411            .add_request_handler({
 412                let app_state = app_state.clone();
 413                move |request, response, session| {
 414                    let app_state = app_state.clone();
 415                    async move {
 416                        count_language_model_tokens(request, response, session, &app_state.config)
 417                            .await
 418                    }
 419                }
 420            })
 421            .add_request_handler(get_cached_embeddings)
 422            .add_request_handler({
 423                let app_state = app_state.clone();
 424                move |request, response, session| {
 425                    compute_embeddings(
 426                        request,
 427                        response,
 428                        session,
 429                        app_state.config.openai_api_key.clone(),
 430                    )
 431                }
 432            });
 433
 434        Arc::new(server)
 435    }
 436
    /// Schedules a one-off cleanup of stale server resources on a detached
    /// task and returns `Ok(())` immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT`, then fetches the room ids and
    /// channel ids that `stale_server_resource_ids` reports for this
    /// environment and `server_id`:
    /// - for each channel, stale buffer collaborators are cleared and every
    ///   remaining connection receives `UpdateChannelBufferCollaborators`;
    /// - for each room, stale participants are cleared and the room (and its
    ///   channel, if any) is re-broadcast. Users whose calls were canceled get
    ///   `CallCanceled`, and accepted contacts of affected users get an
    ///   `UpdateContacts` entry. The LiveKit room is deleted once the room has
    ///   no participants left.
    ///
    /// Finally `delete_stale_servers` runs. Every fallible step is logged via
    /// `trace_err` and skipped, so one failure does not abort the rest.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining collaborator about the pruned list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool guard is a temporary, released at the end of this statement.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move (rather than clone) the data still needed after this block.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only an empty room's LiveKit room is deleted below.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user unless both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 582
 583    pub fn teardown(&self) {
 584        self.peer.teardown();
 585        self.connection_pool.lock().reset();
 586        let _ = self.teardown.send(true);
 587    }
 588
 589    #[cfg(test)]
 590    pub fn reset(&self, id: ServerId) {
 591        self.teardown();
 592        *self.id.lock() = id;
 593        self.peer.reset(id.0 as u32);
 594        let _ = self.teardown.send(false);
 595    }
 596
 597    #[cfg(test)]
 598    pub fn id(&self) -> ServerId {
 599        *self.id.lock()
 600    }
 601
 602    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 603    where
 604        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 605        Fut: 'static + Send + Future<Output = Result<()>>,
 606        M: EnvelopedMessage,
 607    {
 608        let prev_handler = self.handlers.insert(
 609            TypeId::of::<M>(),
 610            Box::new(move |envelope, session| {
 611                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 612                let received_at = envelope.received_at;
 613                tracing::info!("message received");
 614                let start_time = Instant::now();
 615                let future = (handler)(*envelope, session);
 616                async move {
 617                    let result = future.await;
 618                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 619                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 620                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 621                    let payload_type = M::NAME;
 622
 623                    match result {
 624                        Err(error) => {
 625                            tracing::error!(
 626                                ?error,
 627                                total_duration_ms,
 628                                processing_duration_ms,
 629                                queue_duration_ms,
 630                                payload_type,
 631                                "error handling message"
 632                            )
 633                        }
 634                        Ok(()) => tracing::info!(
 635                            total_duration_ms,
 636                            processing_duration_ms,
 637                            queue_duration_ms,
 638                            "finished handling message"
 639                        ),
 640                    }
 641                }
 642                .boxed()
 643            }),
 644        );
 645        if prev_handler.is_some() {
 646            panic!("registered a handler for the same message twice");
 647        }
 648        self
 649    }
 650
 651    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 652    where
 653        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 654        Fut: 'static + Send + Future<Output = Result<()>>,
 655        M: EnvelopedMessage,
 656    {
 657        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 658        self
 659    }
 660
    /// Registers a request handler for `M`. The handler receives the payload,
    /// a [`Response`] it must use to reply, and the session.
    ///
    /// Once the handler finishes:
    /// - `Ok(())` without a prior `Response::send` is reported as an error
    ///   ("handler did not send a response"). Nothing is sent to the client.
    /// - `Err(error)` is converted to a protocol error and sent back with
    ///   `respond_with_error`. `Error::Internal` keeps its own code via
    ///   `to_proto`; anything else becomes `ErrorCode::Internal` carrying the
    ///   error's `Display` text. The error is then returned so `add_handler`
    ///   logs it.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response` so we can tell afterwards whether it was used.
                let responded = Arc::new(AtomicBool::default());
                // `receipt` is reused below on the error path (this relies on it being `Copy`).
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 699
 700    pub fn handle_connection(
 701        self: &Arc<Self>,
 702        connection: Connection,
 703        address: String,
 704        principal: Principal,
 705        zed_version: ZedVersion,
 706        geoip_country_code: Option<String>,
 707        system_id: Option<String>,
 708        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 709        executor: Executor,
 710    ) -> impl Future<Output = ()> {
 711        let this = self.clone();
 712        let span = info_span!("handle connection", %address,
 713            connection_id=field::Empty,
 714            user_id=field::Empty,
 715            login=field::Empty,
 716            impersonator=field::Empty,
 717            geoip_country_code=field::Empty
 718        );
 719        principal.update_span(&span);
 720        if let Some(country_code) = geoip_country_code.as_ref() {
 721            span.record("geoip_country_code", country_code);
 722        }
 723
 724        let mut teardown = self.teardown.subscribe();
 725        async move {
 726            if *teardown.borrow() {
 727                tracing::error!("server is tearing down");
 728                return
 729            }
 730            let (connection_id, handle_io, mut incoming_rx) = this
 731                .peer
 732                .add_connection(connection, {
 733                    let executor = executor.clone();
 734                    move |duration| executor.sleep(duration)
 735                });
 736            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 737
 738            tracing::info!("connection opened");
 739
 740            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 741            let http_client = match ReqwestClient::user_agent(&user_agent) {
 742                Ok(http_client) => Arc::new(http_client),
 743                Err(error) => {
 744                    tracing::error!(?error, "failed to create HTTP client");
 745                    return;
 746                }
 747            };
 748
 749            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 750                    supermaven_admin_api_key.to_string(),
 751                    http_client.clone(),
 752                )));
 753
 754            let session = Session {
 755                principal: principal.clone(),
 756                connection_id,
 757                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 758                peer: this.peer.clone(),
 759                connection_pool: this.connection_pool.clone(),
 760                app_state: this.app_state.clone(),
 761                http_client,
 762                geoip_country_code,
 763                system_id,
 764                _executor: executor.clone(),
 765                supermaven_client,
 766            };
 767
 768            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 769                tracing::error!(?error, "failed to send initial client update");
 770                return;
 771            }
 772
 773            let handle_io = handle_io.fuse();
 774            futures::pin_mut!(handle_io);
 775
 776            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 777            // This prevents deadlocks when e.g., client A performs a request to client B and
 778            // client B performs a request to client A. If both clients stop processing further
 779            // messages until their respective request completes, they won't have a chance to
 780            // respond to the other client's request and cause a deadlock.
 781            //
 782            // This arrangement ensures we will attempt to process earlier messages first, but fall
 783            // back to processing messages arrived later in the spirit of making progress.
 784            let mut foreground_message_handlers = FuturesUnordered::new();
 785            let concurrent_handlers = Arc::new(Semaphore::new(256));
 786            loop {
 787                let next_message = async {
 788                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 789                    let message = incoming_rx.next().await;
 790                    (permit, message)
 791                }.fuse();
 792                futures::pin_mut!(next_message);
 793                futures::select_biased! {
 794                    _ = teardown.changed().fuse() => return,
 795                    result = handle_io => {
 796                        if let Err(error) = result {
 797                            tracing::error!(?error, "error handling I/O");
 798                        }
 799                        break;
 800                    }
 801                    _ = foreground_message_handlers.next() => {}
 802                    next_message = next_message => {
 803                        let (permit, message) = next_message;
 804                        if let Some(message) = message {
 805                            let type_name = message.payload_type_name();
 806                            // note: we copy all the fields from the parent span so we can query them in the logs.
 807                            // (https://github.com/tokio-rs/tracing/issues/2670).
 808                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 809                                user_id=field::Empty,
 810                                login=field::Empty,
 811                                impersonator=field::Empty,
 812                            );
 813                            principal.update_span(&span);
 814                            let span_enter = span.enter();
 815                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 816                                let is_background = message.is_background();
 817                                let handle_message = (handler)(message, session.clone());
 818                                drop(span_enter);
 819
 820                                let handle_message = async move {
 821                                    handle_message.await;
 822                                    drop(permit);
 823                                }.instrument(span);
 824                                if is_background {
 825                                    executor.spawn_detached(handle_message);
 826                                } else {
 827                                    foreground_message_handlers.push(handle_message);
 828                                }
 829                            } else {
 830                                tracing::error!("no message handler");
 831                            }
 832                        } else {
 833                            tracing::info!("connection closed");
 834                            break;
 835                        }
 836                    }
 837                }
 838            }
 839
 840            drop(foreground_message_handlers);
 841            tracing::info!("signing out");
 842            if let Err(error) = connection_lost(session, teardown, executor).await {
 843                tracing::error!(?error, "error signing out");
 844            }
 845
 846        }.instrument(span)
 847    }
 848
 849    async fn send_initial_client_update(
 850        &self,
 851        connection_id: ConnectionId,
 852        principal: &Principal,
 853        zed_version: ZedVersion,
 854        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 855        session: &Session,
 856    ) -> Result<()> {
 857        self.peer.send(
 858            connection_id,
 859            proto::Hello {
 860                peer_id: Some(connection_id.into()),
 861            },
 862        )?;
 863        tracing::info!("sent hello message");
 864        if let Some(send_connection_id) = send_connection_id.take() {
 865            let _ = send_connection_id.send(connection_id);
 866        }
 867
 868        match principal {
 869            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 870                if !user.connected_once {
 871                    self.peer.send(connection_id, proto::ShowContacts {})?;
 872                    self.app_state
 873                        .db
 874                        .set_user_connected_once(user.id, true)
 875                        .await?;
 876                }
 877
 878                update_user_plan(user.id, session).await?;
 879
 880                let contacts = self.app_state.db.get_contacts(user.id).await?;
 881
 882                {
 883                    let mut pool = self.connection_pool.lock();
 884                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 885                    self.peer.send(
 886                        connection_id,
 887                        build_initial_contacts_update(contacts, &pool),
 888                    )?;
 889                }
 890
 891                if should_auto_subscribe_to_channels(zed_version) {
 892                    subscribe_user_to_channels(user.id, session).await?;
 893                }
 894
 895                if let Some(incoming_call) =
 896                    self.app_state.db.incoming_call_for_user(user.id).await?
 897                {
 898                    self.peer.send(connection_id, incoming_call)?;
 899                }
 900
 901                update_user_contacts(user.id, session).await?;
 902            }
 903        }
 904
 905        Ok(())
 906    }
 907
 908    pub async fn invite_code_redeemed(
 909        self: &Arc<Self>,
 910        inviter_id: UserId,
 911        invitee_id: UserId,
 912    ) -> Result<()> {
 913        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 914            if let Some(code) = &user.invite_code {
 915                let pool = self.connection_pool.lock();
 916                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 917                for connection_id in pool.user_connection_ids(inviter_id) {
 918                    self.peer.send(
 919                        connection_id,
 920                        proto::UpdateContacts {
 921                            contacts: vec![invitee_contact.clone()],
 922                            ..Default::default()
 923                        },
 924                    )?;
 925                    self.peer.send(
 926                        connection_id,
 927                        proto::UpdateInviteInfo {
 928                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 929                            count: user.invite_count as u32,
 930                        },
 931                    )?;
 932                }
 933            }
 934        }
 935        Ok(())
 936    }
 937
 938    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 939        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 940            if let Some(invite_code) = &user.invite_code {
 941                let pool = self.connection_pool.lock();
 942                for connection_id in pool.user_connection_ids(user_id) {
 943                    self.peer.send(
 944                        connection_id,
 945                        proto::UpdateInviteInfo {
 946                            url: format!(
 947                                "{}{}",
 948                                self.app_state.config.invite_link_prefix, invite_code
 949                            ),
 950                            count: user.invite_count as u32,
 951                        },
 952                    )?;
 953                }
 954            }
 955        }
 956        Ok(())
 957    }
 958
 959    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 960        let pool = self.connection_pool.lock();
 961        for connection_id in pool.user_connection_ids(user_id) {
 962            self.peer
 963                .send(connection_id, proto::RefreshLlmToken {})
 964                .trace_err();
 965        }
 966    }
 967
 968    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 969        ServerSnapshot {
 970            connection_pool: ConnectionPoolGuard {
 971                guard: self.connection_pool.lock(),
 972                _not_send: PhantomData,
 973            },
 974            peer: &self.peer,
 975        }
 976    }
 977}
 978
 979impl Deref for ConnectionPoolGuard<'_> {
 980    type Target = ConnectionPool;
 981
 982    fn deref(&self) -> &Self::Target {
 983        &self.guard
 984    }
 985}
 986
 987impl DerefMut for ConnectionPoolGuard<'_> {
 988    fn deref_mut(&mut self) -> &mut Self::Target {
 989        &mut self.guard
 990    }
 991}
 992
 993impl Drop for ConnectionPoolGuard<'_> {
 994    fn drop(&mut self) {
 995        #[cfg(test)]
 996        self.check_invariants();
 997    }
 998}
 999
1000fn broadcast<F>(
1001    sender_id: Option<ConnectionId>,
1002    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1003    mut f: F,
1004) where
1005    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1006{
1007    for receiver_id in receiver_ids {
1008        if Some(receiver_id) != sender_id {
1009            if let Err(error) = f(receiver_id) {
1010                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1011            }
1012        }
1013    }
1014}
1015
1016pub struct ProtocolVersion(u32);
1017
1018impl Header for ProtocolVersion {
1019    fn name() -> &'static HeaderName {
1020        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1021        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1022    }
1023
1024    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1025    where
1026        Self: Sized,
1027        I: Iterator<Item = &'i axum::http::HeaderValue>,
1028    {
1029        let version = values
1030            .next()
1031            .ok_or_else(axum::headers::Error::invalid)?
1032            .to_str()
1033            .map_err(|_| axum::headers::Error::invalid())?
1034            .parse()
1035            .map_err(|_| axum::headers::Error::invalid())?;
1036        Ok(Self(version))
1037    }
1038
1039    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1040        values.extend([self.0.to_string().parse().unwrap()]);
1041    }
1042}
1043
1044pub struct AppVersionHeader(SemanticVersion);
1045impl Header for AppVersionHeader {
1046    fn name() -> &'static HeaderName {
1047        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1048        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1049    }
1050
1051    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1052    where
1053        Self: Sized,
1054        I: Iterator<Item = &'i axum::http::HeaderValue>,
1055    {
1056        let version = values
1057            .next()
1058            .ok_or_else(axum::headers::Error::invalid)?
1059            .to_str()
1060            .map_err(|_| axum::headers::Error::invalid())?
1061            .parse()
1062            .map_err(|_| axum::headers::Error::invalid())?;
1063        Ok(Self(version))
1064    }
1065
1066    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1067        values.extend([self.0.to_string().parse().unwrap()]);
1068    }
1069}
1070
1071pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1072    Router::new()
1073        .route("/rpc", get(handle_websocket_request))
1074        .layer(
1075            ServiceBuilder::new()
1076                .layer(Extension(server.app_state.clone()))
1077                .layer(middleware::from_fn(auth::validate_header)),
1078        )
1079        .route("/metrics", get(handle_metrics))
1080        .layer(Extension(server))
1081}
1082
1083pub async fn handle_websocket_request(
1084    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1085    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1086    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1087    Extension(server): Extension<Arc<Server>>,
1088    Extension(principal): Extension<Principal>,
1089    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1090    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1091    ws: WebSocketUpgrade,
1092) -> axum::response::Response {
1093    if protocol_version != rpc::PROTOCOL_VERSION {
1094        return (
1095            StatusCode::UPGRADE_REQUIRED,
1096            "client must be upgraded".to_string(),
1097        )
1098            .into_response();
1099    }
1100
1101    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1102        return (
1103            StatusCode::UPGRADE_REQUIRED,
1104            "no version header found".to_string(),
1105        )
1106            .into_response();
1107    };
1108
1109    if !version.can_collaborate() {
1110        return (
1111            StatusCode::UPGRADE_REQUIRED,
1112            "client must be upgraded".to_string(),
1113        )
1114            .into_response();
1115    }
1116
1117    let socket_address = socket_address.to_string();
1118    ws.on_upgrade(move |socket| {
1119        let socket = socket
1120            .map_ok(to_tungstenite_message)
1121            .err_into()
1122            .with(|message| async move { to_axum_message(message) });
1123        let connection = Connection::new(Box::pin(socket));
1124        async move {
1125            server
1126                .handle_connection(
1127                    connection,
1128                    socket_address,
1129                    principal,
1130                    version,
1131                    country_code_header.map(|header| header.to_string()),
1132                    system_id_header.map(|header| header.to_string()),
1133                    None,
1134                    Executor::Production,
1135                )
1136                .await;
1137        }
1138    })
1139}
1140
1141pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1142    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1143    let connections_metric = CONNECTIONS_METRIC
1144        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1145
1146    let connections = server
1147        .connection_pool
1148        .lock()
1149        .connections()
1150        .filter(|connection| !connection.admin)
1151        .count();
1152    connections_metric.set(connections as _);
1153
1154    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1155    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1156        register_int_gauge!(
1157            "shared_projects",
1158            "number of open projects with one or more guests"
1159        )
1160        .unwrap()
1161    });
1162
1163    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1164    shared_projects_metric.set(shared_projects as _);
1165
1166    let encoder = prometheus::TextEncoder::new();
1167    let metric_families = prometheus::gather();
1168    let encoded_metrics = encoder
1169        .encode_to_string(&metric_families)
1170        .map_err(|err| anyhow!("{}", err))?;
1171    Ok(encoded_metrics)
1172}
1173
1174#[instrument(err, skip(executor))]
1175async fn connection_lost(
1176    session: Session,
1177    mut teardown: watch::Receiver<bool>,
1178    executor: Executor,
1179) -> Result<()> {
1180    session.peer.disconnect(session.connection_id);
1181    session
1182        .connection_pool()
1183        .await
1184        .remove_connection(session.connection_id)?;
1185
1186    session
1187        .db()
1188        .await
1189        .connection_lost(session.connection_id)
1190        .await
1191        .trace_err();
1192
1193    futures::select_biased! {
1194        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1195
1196            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1197            leave_room_for_session(&session, session.connection_id).await.trace_err();
1198            leave_channel_buffers_for_session(&session)
1199                .await
1200                .trace_err();
1201
1202            if !session
1203                .connection_pool()
1204                .await
1205                .is_user_online(session.user_id())
1206            {
1207                let db = session.db().await;
1208                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1209                    room_updated(&room, &session.peer);
1210                }
1211            }
1212
1213            update_user_contacts(session.user_id(), &session).await?;
1214        },
1215        _ = teardown.changed().fuse() => {}
1216    }
1217
1218    Ok(())
1219}
1220
1221/// Acknowledges a ping from a client, used to keep the connection alive.
1222async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1223    response.send(proto::Ack {})?;
1224    Ok(())
1225}
1226
1227/// Creates a new room for calling (outside of channels)
1228async fn create_room(
1229    _request: proto::CreateRoom,
1230    response: Response<proto::CreateRoom>,
1231    session: Session,
1232) -> Result<()> {
1233    let livekit_room = nanoid::nanoid!(30);
1234
1235    let live_kit_connection_info = util::maybe!(async {
1236        let live_kit = session.app_state.livekit_client.as_ref();
1237        let live_kit = live_kit?;
1238        let user_id = session.user_id().to_string();
1239
1240        let token = live_kit
1241            .room_token(&livekit_room, &user_id.to_string())
1242            .trace_err()?;
1243
1244        Some(proto::LiveKitConnectionInfo {
1245            server_url: live_kit.url().into(),
1246            token,
1247            can_publish: true,
1248        })
1249    })
1250    .await;
1251
1252    let room = session
1253        .db()
1254        .await
1255        .create_room(session.user_id(), session.connection_id, &livekit_room)
1256        .await?;
1257
1258    response.send(proto::CreateRoomResponse {
1259        room: Some(room.clone()),
1260        live_kit_connection_info,
1261    })?;
1262
1263    update_user_contacts(session.user_id(), &session).await?;
1264    Ok(())
1265}
1266
1267/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1268async fn join_room(
1269    request: proto::JoinRoom,
1270    response: Response<proto::JoinRoom>,
1271    session: Session,
1272) -> Result<()> {
1273    let room_id = RoomId::from_proto(request.id);
1274
1275    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1276
1277    if let Some(channel_id) = channel_id {
1278        return join_channel_internal(channel_id, Box::new(response), session).await;
1279    }
1280
1281    let joined_room = {
1282        let room = session
1283            .db()
1284            .await
1285            .join_room(room_id, session.user_id(), session.connection_id)
1286            .await?;
1287        room_updated(&room.room, &session.peer);
1288        room.into_inner()
1289    };
1290
1291    for connection_id in session
1292        .connection_pool()
1293        .await
1294        .user_connection_ids(session.user_id())
1295    {
1296        session
1297            .peer
1298            .send(
1299                connection_id,
1300                proto::CallCanceled {
1301                    room_id: room_id.to_proto(),
1302                },
1303            )
1304            .trace_err();
1305    }
1306
1307    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1308    {
1309        live_kit
1310            .room_token(
1311                &joined_room.room.livekit_room,
1312                &session.user_id().to_string(),
1313            )
1314            .trace_err()
1315            .map(|token| proto::LiveKitConnectionInfo {
1316                server_url: live_kit.url().into(),
1317                token,
1318                can_publish: true,
1319            })
1320    } else {
1321        None
1322    };
1323
1324    response.send(proto::JoinRoomResponse {
1325        room: Some(joined_room.room),
1326        channel_id: None,
1327        live_kit_connection_info,
1328    })?;
1329
1330    update_user_contacts(session.user_id(), &session).await?;
1331    Ok(())
1332}
1333
1334/// Rejoin room is used to reconnect to a room after connection errors.
1335async fn rejoin_room(
1336    request: proto::RejoinRoom,
1337    response: Response<proto::RejoinRoom>,
1338    session: Session,
1339) -> Result<()> {
1340    let room;
1341    let channel;
1342    {
1343        let mut rejoined_room = session
1344            .db()
1345            .await
1346            .rejoin_room(request, session.user_id(), session.connection_id)
1347            .await?;
1348
1349        response.send(proto::RejoinRoomResponse {
1350            room: Some(rejoined_room.room.clone()),
1351            reshared_projects: rejoined_room
1352                .reshared_projects
1353                .iter()
1354                .map(|project| proto::ResharedProject {
1355                    id: project.id.to_proto(),
1356                    collaborators: project
1357                        .collaborators
1358                        .iter()
1359                        .map(|collaborator| collaborator.to_proto())
1360                        .collect(),
1361                })
1362                .collect(),
1363            rejoined_projects: rejoined_room
1364                .rejoined_projects
1365                .iter()
1366                .map(|rejoined_project| rejoined_project.to_proto())
1367                .collect(),
1368        })?;
1369        room_updated(&rejoined_room.room, &session.peer);
1370
1371        for project in &rejoined_room.reshared_projects {
1372            for collaborator in &project.collaborators {
1373                session
1374                    .peer
1375                    .send(
1376                        collaborator.connection_id,
1377                        proto::UpdateProjectCollaborator {
1378                            project_id: project.id.to_proto(),
1379                            old_peer_id: Some(project.old_connection_id.into()),
1380                            new_peer_id: Some(session.connection_id.into()),
1381                        },
1382                    )
1383                    .trace_err();
1384            }
1385
1386            broadcast(
1387                Some(session.connection_id),
1388                project
1389                    .collaborators
1390                    .iter()
1391                    .map(|collaborator| collaborator.connection_id),
1392                |connection_id| {
1393                    session.peer.forward_send(
1394                        session.connection_id,
1395                        connection_id,
1396                        proto::UpdateProject {
1397                            project_id: project.id.to_proto(),
1398                            worktrees: project.worktrees.clone(),
1399                        },
1400                    )
1401                },
1402            );
1403        }
1404
1405        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1406
1407        let rejoined_room = rejoined_room.into_inner();
1408
1409        room = rejoined_room.room;
1410        channel = rejoined_room.channel;
1411    }
1412
1413    if let Some(channel) = channel {
1414        channel_updated(
1415            &channel,
1416            &room,
1417            &session.peer,
1418            &*session.connection_pool().await,
1419        );
1420    }
1421
1422    update_user_contacts(session.user_id(), &session).await?;
1423    Ok(())
1424}
1425
1426fn notify_rejoined_projects(
1427    rejoined_projects: &mut Vec<RejoinedProject>,
1428    session: &Session,
1429) -> Result<()> {
1430    for project in rejoined_projects.iter() {
1431        for collaborator in &project.collaborators {
1432            session
1433                .peer
1434                .send(
1435                    collaborator.connection_id,
1436                    proto::UpdateProjectCollaborator {
1437                        project_id: project.id.to_proto(),
1438                        old_peer_id: Some(project.old_connection_id.into()),
1439                        new_peer_id: Some(session.connection_id.into()),
1440                    },
1441                )
1442                .trace_err();
1443        }
1444    }
1445
1446    for project in rejoined_projects {
1447        for worktree in mem::take(&mut project.worktrees) {
1448            // Stream this worktree's entries.
1449            let message = proto::UpdateWorktree {
1450                project_id: project.id.to_proto(),
1451                worktree_id: worktree.id,
1452                abs_path: worktree.abs_path.clone(),
1453                root_name: worktree.root_name,
1454                updated_entries: worktree.updated_entries,
1455                removed_entries: worktree.removed_entries,
1456                scan_id: worktree.scan_id,
1457                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1458                updated_repositories: worktree.updated_repositories,
1459                removed_repositories: worktree.removed_repositories,
1460            };
1461            for update in proto::split_worktree_update(message) {
1462                session.peer.send(session.connection_id, update.clone())?;
1463            }
1464
1465            // Stream this worktree's diagnostics.
1466            for summary in worktree.diagnostic_summaries {
1467                session.peer.send(
1468                    session.connection_id,
1469                    proto::UpdateDiagnosticSummary {
1470                        project_id: project.id.to_proto(),
1471                        worktree_id: worktree.id,
1472                        summary: Some(summary),
1473                    },
1474                )?;
1475            }
1476
1477            for settings_file in worktree.settings_files {
1478                session.peer.send(
1479                    session.connection_id,
1480                    proto::UpdateWorktreeSettings {
1481                        project_id: project.id.to_proto(),
1482                        worktree_id: worktree.id,
1483                        path: settings_file.path,
1484                        content: Some(settings_file.content),
1485                        kind: Some(settings_file.kind.to_proto().into()),
1486                    },
1487                )?;
1488            }
1489        }
1490
1491        for language_server in &project.language_servers {
1492            session.peer.send(
1493                session.connection_id,
1494                proto::UpdateLanguageServer {
1495                    project_id: project.id.to_proto(),
1496                    language_server_id: language_server.id,
1497                    variant: Some(
1498                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1499                            proto::LspDiskBasedDiagnosticsUpdated {},
1500                        ),
1501                    ),
1502                },
1503            )?;
1504        }
1505    }
1506    Ok(())
1507}
1508
1509/// leave room disconnects from the room.
1510async fn leave_room(
1511    _: proto::LeaveRoom,
1512    response: Response<proto::LeaveRoom>,
1513    session: Session,
1514) -> Result<()> {
1515    leave_room_for_session(&session, session.connection_id).await?;
1516    response.send(proto::Ack {})?;
1517    Ok(())
1518}
1519
1520/// Updates the permissions of someone else in the room.
1521async fn set_room_participant_role(
1522    request: proto::SetRoomParticipantRole,
1523    response: Response<proto::SetRoomParticipantRole>,
1524    session: Session,
1525) -> Result<()> {
1526    let user_id = UserId::from_proto(request.user_id);
1527    let role = ChannelRole::from(request.role());
1528
1529    let (livekit_room, can_publish) = {
1530        let room = session
1531            .db()
1532            .await
1533            .set_room_participant_role(
1534                session.user_id(),
1535                RoomId::from_proto(request.room_id),
1536                user_id,
1537                role,
1538            )
1539            .await?;
1540
1541        let livekit_room = room.livekit_room.clone();
1542        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1543        room_updated(&room, &session.peer);
1544        (livekit_room, can_publish)
1545    };
1546
1547    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1548        live_kit
1549            .update_participant(
1550                livekit_room.clone(),
1551                request.user_id.to_string(),
1552                livekit_api::proto::ParticipantPermission {
1553                    can_subscribe: true,
1554                    can_publish,
1555                    can_publish_data: can_publish,
1556                    hidden: false,
1557                    recorder: false,
1558                },
1559            )
1560            .await
1561            .trace_err();
1562    }
1563
1564    response.send(proto::Ack {})?;
1565    Ok(())
1566}
1567
1568/// Call someone else into the current room
1569async fn call(
1570    request: proto::Call,
1571    response: Response<proto::Call>,
1572    session: Session,
1573) -> Result<()> {
1574    let room_id = RoomId::from_proto(request.room_id);
1575    let calling_user_id = session.user_id();
1576    let calling_connection_id = session.connection_id;
1577    let called_user_id = UserId::from_proto(request.called_user_id);
1578    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1579    if !session
1580        .db()
1581        .await
1582        .has_contact(calling_user_id, called_user_id)
1583        .await?
1584    {
1585        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1586    }
1587
1588    let incoming_call = {
1589        let (room, incoming_call) = &mut *session
1590            .db()
1591            .await
1592            .call(
1593                room_id,
1594                calling_user_id,
1595                calling_connection_id,
1596                called_user_id,
1597                initial_project_id,
1598            )
1599            .await?;
1600        room_updated(room, &session.peer);
1601        mem::take(incoming_call)
1602    };
1603    update_user_contacts(called_user_id, &session).await?;
1604
1605    let mut calls = session
1606        .connection_pool()
1607        .await
1608        .user_connection_ids(called_user_id)
1609        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1610        .collect::<FuturesUnordered<_>>();
1611
1612    while let Some(call_response) = calls.next().await {
1613        match call_response.as_ref() {
1614            Ok(_) => {
1615                response.send(proto::Ack {})?;
1616                return Ok(());
1617            }
1618            Err(_) => {
1619                call_response.trace_err();
1620            }
1621        }
1622    }
1623
1624    {
1625        let room = session
1626            .db()
1627            .await
1628            .call_failed(room_id, called_user_id)
1629            .await?;
1630        room_updated(&room, &session.peer);
1631    }
1632    update_user_contacts(called_user_id, &session).await?;
1633
1634    Err(anyhow!("failed to ring user"))?
1635}
1636
1637/// Cancel an outgoing call.
1638async fn cancel_call(
1639    request: proto::CancelCall,
1640    response: Response<proto::CancelCall>,
1641    session: Session,
1642) -> Result<()> {
1643    let called_user_id = UserId::from_proto(request.called_user_id);
1644    let room_id = RoomId::from_proto(request.room_id);
1645    {
1646        let room = session
1647            .db()
1648            .await
1649            .cancel_call(room_id, session.connection_id, called_user_id)
1650            .await?;
1651        room_updated(&room, &session.peer);
1652    }
1653
1654    for connection_id in session
1655        .connection_pool()
1656        .await
1657        .user_connection_ids(called_user_id)
1658    {
1659        session
1660            .peer
1661            .send(
1662                connection_id,
1663                proto::CallCanceled {
1664                    room_id: room_id.to_proto(),
1665                },
1666            )
1667            .trace_err();
1668    }
1669    response.send(proto::Ack {})?;
1670
1671    update_user_contacts(called_user_id, &session).await?;
1672    Ok(())
1673}
1674
1675/// Decline an incoming call.
1676async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1677    let room_id = RoomId::from_proto(message.room_id);
1678    {
1679        let room = session
1680            .db()
1681            .await
1682            .decline_call(Some(room_id), session.user_id())
1683            .await?
1684            .ok_or_else(|| anyhow!("failed to decline call"))?;
1685        room_updated(&room, &session.peer);
1686    }
1687
1688    for connection_id in session
1689        .connection_pool()
1690        .await
1691        .user_connection_ids(session.user_id())
1692    {
1693        session
1694            .peer
1695            .send(
1696                connection_id,
1697                proto::CallCanceled {
1698                    room_id: room_id.to_proto(),
1699                },
1700            )
1701            .trace_err();
1702    }
1703    update_user_contacts(session.user_id(), &session).await?;
1704    Ok(())
1705}
1706
1707/// Updates other participants in the room with your current location.
1708async fn update_participant_location(
1709    request: proto::UpdateParticipantLocation,
1710    response: Response<proto::UpdateParticipantLocation>,
1711    session: Session,
1712) -> Result<()> {
1713    let room_id = RoomId::from_proto(request.room_id);
1714    let location = request
1715        .location
1716        .ok_or_else(|| anyhow!("invalid location"))?;
1717
1718    let db = session.db().await;
1719    let room = db
1720        .update_room_participant_location(room_id, session.connection_id, location)
1721        .await?;
1722
1723    room_updated(&room, &session.peer);
1724    response.send(proto::Ack {})?;
1725    Ok(())
1726}
1727
1728/// Share a project into the room.
1729async fn share_project(
1730    request: proto::ShareProject,
1731    response: Response<proto::ShareProject>,
1732    session: Session,
1733) -> Result<()> {
1734    let (project_id, room) = &*session
1735        .db()
1736        .await
1737        .share_project(
1738            RoomId::from_proto(request.room_id),
1739            session.connection_id,
1740            &request.worktrees,
1741            request.is_ssh_project,
1742        )
1743        .await?;
1744    response.send(proto::ShareProjectResponse {
1745        project_id: project_id.to_proto(),
1746    })?;
1747    room_updated(room, &session.peer);
1748
1749    Ok(())
1750}
1751
1752/// Unshare a project from the room.
1753async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1754    let project_id = ProjectId::from_proto(message.project_id);
1755    unshare_project_internal(project_id, session.connection_id, &session).await
1756}
1757
1758async fn unshare_project_internal(
1759    project_id: ProjectId,
1760    connection_id: ConnectionId,
1761    session: &Session,
1762) -> Result<()> {
1763    let delete = {
1764        let room_guard = session
1765            .db()
1766            .await
1767            .unshare_project(project_id, connection_id)
1768            .await?;
1769
1770        let (delete, room, guest_connection_ids) = &*room_guard;
1771
1772        let message = proto::UnshareProject {
1773            project_id: project_id.to_proto(),
1774        };
1775
1776        broadcast(
1777            Some(connection_id),
1778            guest_connection_ids.iter().copied(),
1779            |conn_id| session.peer.send(conn_id, message.clone()),
1780        );
1781        if let Some(room) = room {
1782            room_updated(room, &session.peer);
1783        }
1784
1785        *delete
1786    };
1787
1788    if delete {
1789        let db = session.db().await;
1790        db.delete_project(project_id).await?;
1791    }
1792
1793    Ok(())
1794}
1795
1796/// Join someone elses shared project.
1797async fn join_project(
1798    request: proto::JoinProject,
1799    response: Response<proto::JoinProject>,
1800    session: Session,
1801) -> Result<()> {
1802    let project_id = ProjectId::from_proto(request.project_id);
1803
1804    tracing::info!(%project_id, "join project");
1805
1806    let db = session.db().await;
1807    let (project, replica_id) = &mut *db
1808        .join_project(project_id, session.connection_id, session.user_id())
1809        .await?;
1810    drop(db);
1811    tracing::info!(%project_id, "join remote project");
1812    join_project_internal(response, session, project, replica_id)
1813}
1814
1815trait JoinProjectInternalResponse {
1816    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1817}
1818impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1819    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1820        Response::<proto::JoinProject>::send(self, result)
1821    }
1822}
1823
1824fn join_project_internal(
1825    response: impl JoinProjectInternalResponse,
1826    session: Session,
1827    project: &mut Project,
1828    replica_id: &ReplicaId,
1829) -> Result<()> {
1830    let collaborators = project
1831        .collaborators
1832        .iter()
1833        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1834        .map(|collaborator| collaborator.to_proto())
1835        .collect::<Vec<_>>();
1836    let project_id = project.id;
1837    let guest_user_id = session.user_id();
1838
1839    let worktrees = project
1840        .worktrees
1841        .iter()
1842        .map(|(id, worktree)| proto::WorktreeMetadata {
1843            id: *id,
1844            root_name: worktree.root_name.clone(),
1845            visible: worktree.visible,
1846            abs_path: worktree.abs_path.clone(),
1847        })
1848        .collect::<Vec<_>>();
1849
1850    let add_project_collaborator = proto::AddProjectCollaborator {
1851        project_id: project_id.to_proto(),
1852        collaborator: Some(proto::Collaborator {
1853            peer_id: Some(session.connection_id.into()),
1854            replica_id: replica_id.0 as u32,
1855            user_id: guest_user_id.to_proto(),
1856            is_host: false,
1857        }),
1858    };
1859
1860    for collaborator in &collaborators {
1861        session
1862            .peer
1863            .send(
1864                collaborator.peer_id.unwrap().into(),
1865                add_project_collaborator.clone(),
1866            )
1867            .trace_err();
1868    }
1869
1870    // First, we send the metadata associated with each worktree.
1871    response.send(proto::JoinProjectResponse {
1872        project_id: project.id.0 as u64,
1873        worktrees: worktrees.clone(),
1874        replica_id: replica_id.0 as u32,
1875        collaborators: collaborators.clone(),
1876        language_servers: project.language_servers.clone(),
1877        role: project.role.into(),
1878    })?;
1879
1880    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1881        // Stream this worktree's entries.
1882        let message = proto::UpdateWorktree {
1883            project_id: project_id.to_proto(),
1884            worktree_id,
1885            abs_path: worktree.abs_path.clone(),
1886            root_name: worktree.root_name,
1887            updated_entries: worktree.entries,
1888            removed_entries: Default::default(),
1889            scan_id: worktree.scan_id,
1890            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1891            updated_repositories: worktree.repository_entries.into_values().collect(),
1892            removed_repositories: Default::default(),
1893        };
1894        for update in proto::split_worktree_update(message) {
1895            session.peer.send(session.connection_id, update.clone())?;
1896        }
1897
1898        // Stream this worktree's diagnostics.
1899        for summary in worktree.diagnostic_summaries {
1900            session.peer.send(
1901                session.connection_id,
1902                proto::UpdateDiagnosticSummary {
1903                    project_id: project_id.to_proto(),
1904                    worktree_id: worktree.id,
1905                    summary: Some(summary),
1906                },
1907            )?;
1908        }
1909
1910        for settings_file in worktree.settings_files {
1911            session.peer.send(
1912                session.connection_id,
1913                proto::UpdateWorktreeSettings {
1914                    project_id: project_id.to_proto(),
1915                    worktree_id: worktree.id,
1916                    path: settings_file.path,
1917                    content: Some(settings_file.content),
1918                    kind: Some(settings_file.kind.to_proto() as i32),
1919                },
1920            )?;
1921        }
1922    }
1923
1924    for language_server in &project.language_servers {
1925        session.peer.send(
1926            session.connection_id,
1927            proto::UpdateLanguageServer {
1928                project_id: project_id.to_proto(),
1929                language_server_id: language_server.id,
1930                variant: Some(
1931                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1932                        proto::LspDiskBasedDiagnosticsUpdated {},
1933                    ),
1934                ),
1935            },
1936        )?;
1937    }
1938
1939    Ok(())
1940}
1941
1942/// Leave someone elses shared project.
1943async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1944    let sender_id = session.connection_id;
1945    let project_id = ProjectId::from_proto(request.project_id);
1946    let db = session.db().await;
1947
1948    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1949    tracing::info!(
1950        %project_id,
1951        "leave project"
1952    );
1953
1954    project_left(project, &session);
1955    if let Some(room) = room {
1956        room_updated(room, &session.peer);
1957    }
1958
1959    Ok(())
1960}
1961
1962/// Updates other participants with changes to the project
1963async fn update_project(
1964    request: proto::UpdateProject,
1965    response: Response<proto::UpdateProject>,
1966    session: Session,
1967) -> Result<()> {
1968    let project_id = ProjectId::from_proto(request.project_id);
1969    let (room, guest_connection_ids) = &*session
1970        .db()
1971        .await
1972        .update_project(project_id, session.connection_id, &request.worktrees)
1973        .await?;
1974    broadcast(
1975        Some(session.connection_id),
1976        guest_connection_ids.iter().copied(),
1977        |connection_id| {
1978            session
1979                .peer
1980                .forward_send(session.connection_id, connection_id, request.clone())
1981        },
1982    );
1983    if let Some(room) = room {
1984        room_updated(room, &session.peer);
1985    }
1986    response.send(proto::Ack {})?;
1987
1988    Ok(())
1989}
1990
1991/// Updates other participants with changes to the worktree
1992async fn update_worktree(
1993    request: proto::UpdateWorktree,
1994    response: Response<proto::UpdateWorktree>,
1995    session: Session,
1996) -> Result<()> {
1997    let guest_connection_ids = session
1998        .db()
1999        .await
2000        .update_worktree(&request, session.connection_id)
2001        .await?;
2002
2003    broadcast(
2004        Some(session.connection_id),
2005        guest_connection_ids.iter().copied(),
2006        |connection_id| {
2007            session
2008                .peer
2009                .forward_send(session.connection_id, connection_id, request.clone())
2010        },
2011    );
2012    response.send(proto::Ack {})?;
2013    Ok(())
2014}
2015
2016/// Updates other participants with changes to the diagnostics
2017async fn update_diagnostic_summary(
2018    message: proto::UpdateDiagnosticSummary,
2019    session: Session,
2020) -> Result<()> {
2021    let guest_connection_ids = session
2022        .db()
2023        .await
2024        .update_diagnostic_summary(&message, session.connection_id)
2025        .await?;
2026
2027    broadcast(
2028        Some(session.connection_id),
2029        guest_connection_ids.iter().copied(),
2030        |connection_id| {
2031            session
2032                .peer
2033                .forward_send(session.connection_id, connection_id, message.clone())
2034        },
2035    );
2036
2037    Ok(())
2038}
2039
2040/// Updates other participants with changes to the worktree settings
2041async fn update_worktree_settings(
2042    message: proto::UpdateWorktreeSettings,
2043    session: Session,
2044) -> Result<()> {
2045    let guest_connection_ids = session
2046        .db()
2047        .await
2048        .update_worktree_settings(&message, session.connection_id)
2049        .await?;
2050
2051    broadcast(
2052        Some(session.connection_id),
2053        guest_connection_ids.iter().copied(),
2054        |connection_id| {
2055            session
2056                .peer
2057                .forward_send(session.connection_id, connection_id, message.clone())
2058        },
2059    );
2060
2061    Ok(())
2062}
2063
2064/// Notify other participants that a  language server has started.
2065async fn start_language_server(
2066    request: proto::StartLanguageServer,
2067    session: Session,
2068) -> Result<()> {
2069    let guest_connection_ids = session
2070        .db()
2071        .await
2072        .start_language_server(&request, session.connection_id)
2073        .await?;
2074
2075    broadcast(
2076        Some(session.connection_id),
2077        guest_connection_ids.iter().copied(),
2078        |connection_id| {
2079            session
2080                .peer
2081                .forward_send(session.connection_id, connection_id, request.clone())
2082        },
2083    );
2084    Ok(())
2085}
2086
2087/// Notify other participants that a language server has changed.
2088async fn update_language_server(
2089    request: proto::UpdateLanguageServer,
2090    session: Session,
2091) -> Result<()> {
2092    let project_id = ProjectId::from_proto(request.project_id);
2093    let project_connection_ids = session
2094        .db()
2095        .await
2096        .project_connection_ids(project_id, session.connection_id, true)
2097        .await?;
2098    broadcast(
2099        Some(session.connection_id),
2100        project_connection_ids.iter().copied(),
2101        |connection_id| {
2102            session
2103                .peer
2104                .forward_send(session.connection_id, connection_id, request.clone())
2105        },
2106    );
2107    Ok(())
2108}
2109
2110/// forward a project request to the host. These requests should be read only
2111/// as guests are allowed to send them.
2112async fn forward_read_only_project_request<T>(
2113    request: T,
2114    response: Response<T>,
2115    session: Session,
2116) -> Result<()>
2117where
2118    T: EntityMessage + RequestMessage,
2119{
2120    let project_id = ProjectId::from_proto(request.remote_entity_id());
2121    let host_connection_id = session
2122        .db()
2123        .await
2124        .host_for_read_only_project_request(project_id, session.connection_id)
2125        .await?;
2126    let payload = session
2127        .peer
2128        .forward_request(session.connection_id, host_connection_id, request)
2129        .await?;
2130    response.send(payload)?;
2131    Ok(())
2132}
2133
2134async fn forward_find_search_candidates_request(
2135    request: proto::FindSearchCandidates,
2136    response: Response<proto::FindSearchCandidates>,
2137    session: Session,
2138) -> Result<()> {
2139    let project_id = ProjectId::from_proto(request.remote_entity_id());
2140    let host_connection_id = session
2141        .db()
2142        .await
2143        .host_for_read_only_project_request(project_id, session.connection_id)
2144        .await?;
2145    let payload = session
2146        .peer
2147        .forward_request(session.connection_id, host_connection_id, request)
2148        .await?;
2149    response.send(payload)?;
2150    Ok(())
2151}
2152
2153/// forward a project request to the host. These requests are disallowed
2154/// for guests.
2155async fn forward_mutating_project_request<T>(
2156    request: T,
2157    response: Response<T>,
2158    session: Session,
2159) -> Result<()>
2160where
2161    T: EntityMessage + RequestMessage,
2162{
2163    let project_id = ProjectId::from_proto(request.remote_entity_id());
2164
2165    let host_connection_id = session
2166        .db()
2167        .await
2168        .host_for_mutating_project_request(project_id, session.connection_id)
2169        .await?;
2170    let payload = session
2171        .peer
2172        .forward_request(session.connection_id, host_connection_id, request)
2173        .await?;
2174    response.send(payload)?;
2175    Ok(())
2176}
2177
2178/// Notify other participants that a new buffer has been created
2179async fn create_buffer_for_peer(
2180    request: proto::CreateBufferForPeer,
2181    session: Session,
2182) -> Result<()> {
2183    session
2184        .db()
2185        .await
2186        .check_user_is_project_host(
2187            ProjectId::from_proto(request.project_id),
2188            session.connection_id,
2189        )
2190        .await?;
2191    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2192    session
2193        .peer
2194        .forward_send(session.connection_id, peer_id.into(), request)?;
2195    Ok(())
2196}
2197
2198/// Notify other participants that a buffer has been updated. This is
2199/// allowed for guests as long as the update is limited to selections.
2200async fn update_buffer(
2201    request: proto::UpdateBuffer,
2202    response: Response<proto::UpdateBuffer>,
2203    session: Session,
2204) -> Result<()> {
2205    let project_id = ProjectId::from_proto(request.project_id);
2206    let mut capability = Capability::ReadOnly;
2207
2208    for op in request.operations.iter() {
2209        match op.variant {
2210            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2211            Some(_) => capability = Capability::ReadWrite,
2212        }
2213    }
2214
2215    let host = {
2216        let guard = session
2217            .db()
2218            .await
2219            .connections_for_buffer_update(project_id, session.connection_id, capability)
2220            .await?;
2221
2222        let (host, guests) = &*guard;
2223
2224        broadcast(
2225            Some(session.connection_id),
2226            guests.clone(),
2227            |connection_id| {
2228                session
2229                    .peer
2230                    .forward_send(session.connection_id, connection_id, request.clone())
2231            },
2232        );
2233
2234        *host
2235    };
2236
2237    if host != session.connection_id {
2238        session
2239            .peer
2240            .forward_request(session.connection_id, host, request.clone())
2241            .await?;
2242    }
2243
2244    response.send(proto::Ack {})?;
2245    Ok(())
2246}
2247
2248async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2249    let project_id = ProjectId::from_proto(message.project_id);
2250
2251    let operation = message.operation.as_ref().context("invalid operation")?;
2252    let capability = match operation.variant.as_ref() {
2253        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2254            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2255                match buffer_op.variant {
2256                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2257                        Capability::ReadOnly
2258                    }
2259                    _ => Capability::ReadWrite,
2260                }
2261            } else {
2262                Capability::ReadWrite
2263            }
2264        }
2265        Some(_) => Capability::ReadWrite,
2266        None => Capability::ReadOnly,
2267    };
2268
2269    let guard = session
2270        .db()
2271        .await
2272        .connections_for_buffer_update(project_id, session.connection_id, capability)
2273        .await?;
2274
2275    let (host, guests) = &*guard;
2276
2277    broadcast(
2278        Some(session.connection_id),
2279        guests.iter().chain([host]).copied(),
2280        |connection_id| {
2281            session
2282                .peer
2283                .forward_send(session.connection_id, connection_id, message.clone())
2284        },
2285    );
2286
2287    Ok(())
2288}
2289
2290/// Notify other participants that a project has been updated.
2291async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2292    request: T,
2293    session: Session,
2294) -> Result<()> {
2295    let project_id = ProjectId::from_proto(request.remote_entity_id());
2296    let project_connection_ids = session
2297        .db()
2298        .await
2299        .project_connection_ids(project_id, session.connection_id, false)
2300        .await?;
2301
2302    broadcast(
2303        Some(session.connection_id),
2304        project_connection_ids.iter().copied(),
2305        |connection_id| {
2306            session
2307                .peer
2308                .forward_send(session.connection_id, connection_id, request.clone())
2309        },
2310    );
2311    Ok(())
2312}
2313
2314/// Start following another user in a call.
2315async fn follow(
2316    request: proto::Follow,
2317    response: Response<proto::Follow>,
2318    session: Session,
2319) -> Result<()> {
2320    let room_id = RoomId::from_proto(request.room_id);
2321    let project_id = request.project_id.map(ProjectId::from_proto);
2322    let leader_id = request
2323        .leader_id
2324        .ok_or_else(|| anyhow!("invalid leader id"))?
2325        .into();
2326    let follower_id = session.connection_id;
2327
2328    session
2329        .db()
2330        .await
2331        .check_room_participants(room_id, leader_id, session.connection_id)
2332        .await?;
2333
2334    let response_payload = session
2335        .peer
2336        .forward_request(session.connection_id, leader_id, request)
2337        .await?;
2338    response.send(response_payload)?;
2339
2340    if let Some(project_id) = project_id {
2341        let room = session
2342            .db()
2343            .await
2344            .follow(room_id, project_id, leader_id, follower_id)
2345            .await?;
2346        room_updated(&room, &session.peer);
2347    }
2348
2349    Ok(())
2350}
2351
2352/// Stop following another user in a call.
2353async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2354    let room_id = RoomId::from_proto(request.room_id);
2355    let project_id = request.project_id.map(ProjectId::from_proto);
2356    let leader_id = request
2357        .leader_id
2358        .ok_or_else(|| anyhow!("invalid leader id"))?
2359        .into();
2360    let follower_id = session.connection_id;
2361
2362    session
2363        .db()
2364        .await
2365        .check_room_participants(room_id, leader_id, session.connection_id)
2366        .await?;
2367
2368    session
2369        .peer
2370        .forward_send(session.connection_id, leader_id, request)?;
2371
2372    if let Some(project_id) = project_id {
2373        let room = session
2374            .db()
2375            .await
2376            .unfollow(room_id, project_id, leader_id, follower_id)
2377            .await?;
2378        room_updated(&room, &session.peer);
2379    }
2380
2381    Ok(())
2382}
2383
2384/// Notify everyone following you of your current location.
2385async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2386    let room_id = RoomId::from_proto(request.room_id);
2387    let database = session.db.lock().await;
2388
2389    let connection_ids = if let Some(project_id) = request.project_id {
2390        let project_id = ProjectId::from_proto(project_id);
2391        database
2392            .project_connection_ids(project_id, session.connection_id, true)
2393            .await?
2394    } else {
2395        database
2396            .room_connection_ids(room_id, session.connection_id)
2397            .await?
2398    };
2399
2400    // For now, don't send view update messages back to that view's current leader.
2401    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2402        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2403        _ => None,
2404    });
2405
2406    for connection_id in connection_ids.iter().cloned() {
2407        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2408            session
2409                .peer
2410                .forward_send(session.connection_id, connection_id, request.clone())?;
2411        }
2412    }
2413    Ok(())
2414}
2415
2416/// Get public data about users.
2417async fn get_users(
2418    request: proto::GetUsers,
2419    response: Response<proto::GetUsers>,
2420    session: Session,
2421) -> Result<()> {
2422    let user_ids = request
2423        .user_ids
2424        .into_iter()
2425        .map(UserId::from_proto)
2426        .collect();
2427    let users = session
2428        .db()
2429        .await
2430        .get_users_by_ids(user_ids)
2431        .await?
2432        .into_iter()
2433        .map(|user| proto::User {
2434            id: user.id.to_proto(),
2435            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2436            github_login: user.github_login,
2437            email: user.email_address,
2438            name: user.name,
2439        })
2440        .collect();
2441    response.send(proto::UsersResponse { users })?;
2442    Ok(())
2443}
2444
2445/// Search for users (to invite) buy Github login
2446async fn fuzzy_search_users(
2447    request: proto::FuzzySearchUsers,
2448    response: Response<proto::FuzzySearchUsers>,
2449    session: Session,
2450) -> Result<()> {
2451    let query = request.query;
2452    let users = match query.len() {
2453        0 => vec![],
2454        1 | 2 => session
2455            .db()
2456            .await
2457            .get_user_by_github_login(&query)
2458            .await?
2459            .into_iter()
2460            .collect(),
2461        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2462    };
2463    let users = users
2464        .into_iter()
2465        .filter(|user| user.id != session.user_id())
2466        .map(|user| proto::User {
2467            id: user.id.to_proto(),
2468            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2469            github_login: user.github_login,
2470            name: user.name,
2471            email: user.email_address,
2472        })
2473        .collect();
2474    response.send(proto::UsersResponse { users })?;
2475    Ok(())
2476}
2477
2478/// Send a contact request to another user.
2479async fn request_contact(
2480    request: proto::RequestContact,
2481    response: Response<proto::RequestContact>,
2482    session: Session,
2483) -> Result<()> {
2484    let requester_id = session.user_id();
2485    let responder_id = UserId::from_proto(request.responder_id);
2486    if requester_id == responder_id {
2487        return Err(anyhow!("cannot add yourself as a contact"))?;
2488    }
2489
2490    let notifications = session
2491        .db()
2492        .await
2493        .send_contact_request(requester_id, responder_id)
2494        .await?;
2495
2496    // Update outgoing contact requests of requester
2497    let mut update = proto::UpdateContacts::default();
2498    update.outgoing_requests.push(responder_id.to_proto());
2499    for connection_id in session
2500        .connection_pool()
2501        .await
2502        .user_connection_ids(requester_id)
2503    {
2504        session.peer.send(connection_id, update.clone())?;
2505    }
2506
2507    // Update incoming contact requests of responder
2508    let mut update = proto::UpdateContacts::default();
2509    update
2510        .incoming_requests
2511        .push(proto::IncomingContactRequest {
2512            requester_id: requester_id.to_proto(),
2513        });
2514    let connection_pool = session.connection_pool().await;
2515    for connection_id in connection_pool.user_connection_ids(responder_id) {
2516        session.peer.send(connection_id, update.clone())?;
2517    }
2518
2519    send_notifications(&connection_pool, &session.peer, notifications);
2520
2521    response.send(proto::Ack {})?;
2522    Ok(())
2523}
2524
/// Accept or decline a contact request, or dismiss its notification.
///
/// The session user is the responder; `request.requester_id` is the user who sent
/// the original contact request.
///
/// - `Dismiss`: only `dismiss_contact_notification` is called. No contact updates are
///   pushed to either user.
/// - Any other response (`Accept`, or anything else, which is treated as a decline):
///   - the database records the response;
///   - the responder's connections get the incoming request removed, plus the new
///     contact when accepted;
///   - the requester's connections get the outgoing request removed, plus the new
///     contact when accepted;
///   - any notifications returned by the database are delivered.
///
/// The caller is acked once the above succeeds.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is bound here and lives until the end of the function,
    // including while the connection pool is held below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than `Accept` (with `Dismiss` handled above) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status feeds into the contact entries built by `contact_for_user` below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2582
2583/// Remove a contact.
2584async fn remove_contact(
2585    request: proto::RemoveContact,
2586    response: Response<proto::RemoveContact>,
2587    session: Session,
2588) -> Result<()> {
2589    let requester_id = session.user_id();
2590    let responder_id = UserId::from_proto(request.user_id);
2591    let db = session.db().await;
2592    let (contact_accepted, deleted_notification_id) =
2593        db.remove_contact(requester_id, responder_id).await?;
2594
2595    let pool = session.connection_pool().await;
2596    // Update outgoing contact requests of requester
2597    let mut update = proto::UpdateContacts::default();
2598    if contact_accepted {
2599        update.remove_contacts.push(responder_id.to_proto());
2600    } else {
2601        update
2602            .remove_outgoing_requests
2603            .push(responder_id.to_proto());
2604    }
2605    for connection_id in pool.user_connection_ids(requester_id) {
2606        session.peer.send(connection_id, update.clone())?;
2607    }
2608
2609    // Update incoming contact requests of responder
2610    let mut update = proto::UpdateContacts::default();
2611    if contact_accepted {
2612        update.remove_contacts.push(requester_id.to_proto());
2613    } else {
2614        update
2615            .remove_incoming_requests
2616            .push(requester_id.to_proto());
2617    }
2618    for connection_id in pool.user_connection_ids(responder_id) {
2619        session.peer.send(connection_id, update.clone())?;
2620        if let Some(notification_id) = deleted_notification_id {
2621            session.peer.send(
2622                connection_id,
2623                proto::DeleteNotification {
2624                    notification_id: notification_id.to_proto(),
2625                },
2626            )?;
2627        }
2628    }
2629
2630    response.send(proto::Ack {})?;
2631    Ok(())
2632}
2633
2634fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2635    version.0.minor() < 139
2636}
2637
2638async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2639    let plan = session.current_plan(&session.db().await).await?;
2640
2641    session
2642        .peer
2643        .send(
2644            session.connection_id,
2645            proto::UpdateUserPlan { plan: plan.into() },
2646        )
2647        .trace_err();
2648
2649    Ok(())
2650}
2651
2652async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2653    subscribe_user_to_channels(session.user_id(), &session).await?;
2654    Ok(())
2655}
2656
2657async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2658    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2659    let mut pool = session.connection_pool().await;
2660    for membership in &channels_for_user.channel_memberships {
2661        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2662    }
2663    session.peer.send(
2664        session.connection_id,
2665        build_update_user_channels(&channels_for_user),
2666    )?;
2667    session.peer.send(
2668        session.connection_id,
2669        build_channels_update(channels_for_user),
2670    )?;
2671    Ok(())
2672}
2673
2674/// Creates a new channel.
2675async fn create_channel(
2676    request: proto::CreateChannel,
2677    response: Response<proto::CreateChannel>,
2678    session: Session,
2679) -> Result<()> {
2680    let db = session.db().await;
2681
2682    let parent_id = request.parent_id.map(ChannelId::from_proto);
2683    let (channel, membership) = db
2684        .create_channel(&request.name, parent_id, session.user_id())
2685        .await?;
2686
2687    let root_id = channel.root_id();
2688    let channel = Channel::from_model(channel);
2689
2690    response.send(proto::CreateChannelResponse {
2691        channel: Some(channel.to_proto()),
2692        parent_id: request.parent_id,
2693    })?;
2694
2695    let mut connection_pool = session.connection_pool().await;
2696    if let Some(membership) = membership {
2697        connection_pool.subscribe_to_channel(
2698            membership.user_id,
2699            membership.channel_id,
2700            membership.role,
2701        );
2702        let update = proto::UpdateUserChannels {
2703            channel_memberships: vec![proto::ChannelMembership {
2704                channel_id: membership.channel_id.to_proto(),
2705                role: membership.role.into(),
2706            }],
2707            ..Default::default()
2708        };
2709        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2710            session.peer.send(connection_id, update.clone())?;
2711        }
2712    }
2713
2714    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2715        if !role.can_see_channel(channel.visibility) {
2716            continue;
2717        }
2718
2719        let update = proto::UpdateChannels {
2720            channels: vec![channel.to_proto()],
2721            ..Default::default()
2722        };
2723        session.peer.send(connection_id, update.clone())?;
2724    }
2725
2726    Ok(())
2727}
2728
2729/// Delete a channel
2730async fn delete_channel(
2731    request: proto::DeleteChannel,
2732    response: Response<proto::DeleteChannel>,
2733    session: Session,
2734) -> Result<()> {
2735    let db = session.db().await;
2736
2737    let channel_id = request.channel_id;
2738    let (root_channel, removed_channels) = db
2739        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2740        .await?;
2741    response.send(proto::Ack {})?;
2742
2743    // Notify members of removed channels
2744    let mut update = proto::UpdateChannels::default();
2745    update
2746        .delete_channels
2747        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2748
2749    let connection_pool = session.connection_pool().await;
2750    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2751        session.peer.send(connection_id, update.clone())?;
2752    }
2753
2754    Ok(())
2755}
2756
2757/// Invite someone to join a channel.
2758async fn invite_channel_member(
2759    request: proto::InviteChannelMember,
2760    response: Response<proto::InviteChannelMember>,
2761    session: Session,
2762) -> Result<()> {
2763    let db = session.db().await;
2764    let channel_id = ChannelId::from_proto(request.channel_id);
2765    let invitee_id = UserId::from_proto(request.user_id);
2766    let InviteMemberResult {
2767        channel,
2768        notifications,
2769    } = db
2770        .invite_channel_member(
2771            channel_id,
2772            invitee_id,
2773            session.user_id(),
2774            request.role().into(),
2775        )
2776        .await?;
2777
2778    let update = proto::UpdateChannels {
2779        channel_invitations: vec![channel.to_proto()],
2780        ..Default::default()
2781    };
2782
2783    let connection_pool = session.connection_pool().await;
2784    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2785        session.peer.send(connection_id, update.clone())?;
2786    }
2787
2788    send_notifications(&connection_pool, &session.peer, notifications);
2789
2790    response.send(proto::Ack {})?;
2791    Ok(())
2792}
2793
2794/// remove someone from a channel
2795async fn remove_channel_member(
2796    request: proto::RemoveChannelMember,
2797    response: Response<proto::RemoveChannelMember>,
2798    session: Session,
2799) -> Result<()> {
2800    let db = session.db().await;
2801    let channel_id = ChannelId::from_proto(request.channel_id);
2802    let member_id = UserId::from_proto(request.user_id);
2803
2804    let RemoveChannelMemberResult {
2805        membership_update,
2806        notification_id,
2807    } = db
2808        .remove_channel_member(channel_id, member_id, session.user_id())
2809        .await?;
2810
2811    let mut connection_pool = session.connection_pool().await;
2812    notify_membership_updated(
2813        &mut connection_pool,
2814        membership_update,
2815        member_id,
2816        &session.peer,
2817    );
2818    for connection_id in connection_pool.user_connection_ids(member_id) {
2819        if let Some(notification_id) = notification_id {
2820            session
2821                .peer
2822                .send(
2823                    connection_id,
2824                    proto::DeleteNotification {
2825                        notification_id: notification_id.to_proto(),
2826                    },
2827                )
2828                .trace_err();
2829        }
2830    }
2831
2832    response.send(proto::Ack {})?;
2833    Ok(())
2834}
2835
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy). Presumably this
/// is enforced inside `db.set_channel_visibility`; nothing in this handler checks it.
///
/// After the database update, every user subscribed to the channel's root is re-evaluated:
/// - users whose role can see the channel are subscribed to it and sent the updated channel;
/// - all others are unsubscribed from it and told to delete it.
///
/// The caller is acked last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collected up front because the loop body mutates `connection_pool`
    // (subscribe/unsubscribe), so it cannot stay borrowed by the iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2882
2883/// Alter the role for a user in the channel.
2884async fn set_channel_member_role(
2885    request: proto::SetChannelMemberRole,
2886    response: Response<proto::SetChannelMemberRole>,
2887    session: Session,
2888) -> Result<()> {
2889    let db = session.db().await;
2890    let channel_id = ChannelId::from_proto(request.channel_id);
2891    let member_id = UserId::from_proto(request.user_id);
2892    let result = db
2893        .set_channel_member_role(
2894            channel_id,
2895            session.user_id(),
2896            member_id,
2897            request.role().into(),
2898        )
2899        .await?;
2900
2901    match result {
2902        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2903            let mut connection_pool = session.connection_pool().await;
2904            notify_membership_updated(
2905                &mut connection_pool,
2906                membership_update,
2907                member_id,
2908                &session.peer,
2909            )
2910        }
2911        db::SetMemberRoleResult::InviteUpdated(channel) => {
2912            let update = proto::UpdateChannels {
2913                channel_invitations: vec![channel.to_proto()],
2914                ..Default::default()
2915            };
2916
2917            for connection_id in session
2918                .connection_pool()
2919                .await
2920                .user_connection_ids(member_id)
2921            {
2922                session.peer.send(connection_id, update.clone())?;
2923            }
2924        }
2925    }
2926
2927    response.send(proto::Ack {})?;
2928    Ok(())
2929}
2930
2931/// Change the name of a channel
2932async fn rename_channel(
2933    request: proto::RenameChannel,
2934    response: Response<proto::RenameChannel>,
2935    session: Session,
2936) -> Result<()> {
2937    let db = session.db().await;
2938    let channel_id = ChannelId::from_proto(request.channel_id);
2939    let channel_model = db
2940        .rename_channel(channel_id, session.user_id(), &request.name)
2941        .await?;
2942    let root_id = channel_model.root_id();
2943    let channel = Channel::from_model(channel_model);
2944
2945    response.send(proto::RenameChannelResponse {
2946        channel: Some(channel.to_proto()),
2947    })?;
2948
2949    let connection_pool = session.connection_pool().await;
2950    let update = proto::UpdateChannels {
2951        channels: vec![channel.to_proto()],
2952        ..Default::default()
2953    };
2954    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2955        if role.can_see_channel(channel.visibility) {
2956            session.peer.send(connection_id, update.clone())?;
2957        }
2958    }
2959
2960    Ok(())
2961}
2962
2963/// Move a channel to a new parent.
2964async fn move_channel(
2965    request: proto::MoveChannel,
2966    response: Response<proto::MoveChannel>,
2967    session: Session,
2968) -> Result<()> {
2969    let channel_id = ChannelId::from_proto(request.channel_id);
2970    let to = ChannelId::from_proto(request.to);
2971
2972    let (root_id, channels) = session
2973        .db()
2974        .await
2975        .move_channel(channel_id, to, session.user_id())
2976        .await?;
2977
2978    let connection_pool = session.connection_pool().await;
2979    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2980        let channels = channels
2981            .iter()
2982            .filter_map(|channel| {
2983                if role.can_see_channel(channel.visibility) {
2984                    Some(channel.to_proto())
2985                } else {
2986                    None
2987                }
2988            })
2989            .collect::<Vec<_>>();
2990        if channels.is_empty() {
2991            continue;
2992        }
2993
2994        let update = proto::UpdateChannels {
2995            channels,
2996            ..Default::default()
2997        };
2998
2999        session.peer.send(connection_id, update.clone())?;
3000    }
3001
3002    response.send(Ack {})?;
3003    Ok(())
3004}
3005
3006/// Get the list of channel members
3007async fn get_channel_members(
3008    request: proto::GetChannelMembers,
3009    response: Response<proto::GetChannelMembers>,
3010    session: Session,
3011) -> Result<()> {
3012    let db = session.db().await;
3013    let channel_id = ChannelId::from_proto(request.channel_id);
3014    let limit = if request.limit == 0 {
3015        u16::MAX as u64
3016    } else {
3017        request.limit
3018    };
3019    let (members, users) = db
3020        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3021        .await?;
3022    response.send(proto::GetChannelMembersResponse { members, users })?;
3023    Ok(())
3024}
3025
3026/// Accept or decline a channel invitation.
3027async fn respond_to_channel_invite(
3028    request: proto::RespondToChannelInvite,
3029    response: Response<proto::RespondToChannelInvite>,
3030    session: Session,
3031) -> Result<()> {
3032    let db = session.db().await;
3033    let channel_id = ChannelId::from_proto(request.channel_id);
3034    let RespondToChannelInvite {
3035        membership_update,
3036        notifications,
3037    } = db
3038        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3039        .await?;
3040
3041    let mut connection_pool = session.connection_pool().await;
3042    if let Some(membership_update) = membership_update {
3043        notify_membership_updated(
3044            &mut connection_pool,
3045            membership_update,
3046            session.user_id(),
3047            &session.peer,
3048        );
3049    } else {
3050        let update = proto::UpdateChannels {
3051            remove_channel_invitations: vec![channel_id.to_proto()],
3052            ..Default::default()
3053        };
3054
3055        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3056            session.peer.send(connection_id, update.clone())?;
3057        }
3058    };
3059
3060    send_notifications(&connection_pool, &session.peer, notifications);
3061
3062    response.send(proto::Ack {})?;
3063
3064    Ok(())
3065}
3066
3067/// Join the channels' room
3068async fn join_channel(
3069    request: proto::JoinChannel,
3070    response: Response<proto::JoinChannel>,
3071    session: Session,
3072) -> Result<()> {
3073    let channel_id = ChannelId::from_proto(request.channel_id);
3074    join_channel_internal(channel_id, Box::new(response), session).await
3075}
3076
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// Lets `join_channel_internal` reply through either a `JoinChannel` or a `JoinRoom`
/// response, since both carry the same `proto::JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` (presumably the inherent method on `Response`).
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` (presumably the inherent method on `Response`).
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3090
/// Shared implementation for joining a channel's room.
///
/// It is used by `join_channel`. The `Response<proto::JoinRoom>` impl of
/// `JoinChannelInternalResponse` suggests a room-join path uses it too (verify).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, the database
///    guard is dropped, that connection leaves its room, and the guard is reacquired.
/// 2. The channel's room is joined in the database. This yields the room, an optional
///    membership update, and the user's role.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a guest token
///    with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. A token error goes through `trace_err` and yields no
///    connection info.
/// 4. The `JoinRoomResponse` is sent, any membership update is passed to
///    `notify_membership_updated`, and `room_updated` is called.
/// 5. After the inner scope ends (dropping the database and pool guards),
///    `channel_updated` and `update_user_contacts` are called.
///
/// Returns an error if the joined room has no channel. Note that by then the response
/// has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish; everyone else gets a publishing token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above have been dropped at this point.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3186
3187/// Start editing the channel notes
3188async fn join_channel_buffer(
3189    request: proto::JoinChannelBuffer,
3190    response: Response<proto::JoinChannelBuffer>,
3191    session: Session,
3192) -> Result<()> {
3193    let db = session.db().await;
3194    let channel_id = ChannelId::from_proto(request.channel_id);
3195
3196    let open_response = db
3197        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3198        .await?;
3199
3200    let collaborators = open_response.collaborators.clone();
3201    response.send(open_response)?;
3202
3203    let update = UpdateChannelBufferCollaborators {
3204        channel_id: channel_id.to_proto(),
3205        collaborators: collaborators.clone(),
3206    };
3207    channel_buffer_updated(
3208        session.connection_id,
3209        collaborators
3210            .iter()
3211            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3212        &update,
3213        &session.peer,
3214    );
3215
3216    Ok(())
3217}
3218
/// Edit the channel notes.
///
/// Applies `request.operations` to the channel's buffer as the session user. Then:
/// - the operations are forwarded to the buffer's collaborators via
///   `channel_buffer_updated`, which passes this connection as the sender, presumably
///   so it is skipped;
/// - every other connection subscribed to the channel (those not collaborating on the
///   buffer) is sent the buffer's latest epoch and version in an `UpdateChannels`.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Subscribers to the channel that are not editing this buffer.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    // No sender is excluded here (`None`); each non-collaborator only gets the new
    // epoch/version, not the operations themselves.
    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3269
3270/// Rejoin the channel notes after a connection blip
3271async fn rejoin_channel_buffers(
3272    request: proto::RejoinChannelBuffers,
3273    response: Response<proto::RejoinChannelBuffers>,
3274    session: Session,
3275) -> Result<()> {
3276    let db = session.db().await;
3277    let buffers = db
3278        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3279        .await?;
3280
3281    for rejoined_buffer in &buffers {
3282        let collaborators_to_notify = rejoined_buffer
3283            .buffer
3284            .collaborators
3285            .iter()
3286            .filter_map(|c| Some(c.peer_id?.into()));
3287        channel_buffer_updated(
3288            session.connection_id,
3289            collaborators_to_notify,
3290            &proto::UpdateChannelBufferCollaborators {
3291                channel_id: rejoined_buffer.buffer.channel_id,
3292                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3293            },
3294            &session.peer,
3295        );
3296    }
3297
3298    response.send(proto::RejoinChannelBuffersResponse {
3299        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3300    })?;
3301
3302    Ok(())
3303}
3304
3305/// Stop editing the channel notes
3306async fn leave_channel_buffer(
3307    request: proto::LeaveChannelBuffer,
3308    response: Response<proto::LeaveChannelBuffer>,
3309    session: Session,
3310) -> Result<()> {
3311    let db = session.db().await;
3312    let channel_id = ChannelId::from_proto(request.channel_id);
3313
3314    let left_buffer = db
3315        .leave_channel_buffer(channel_id, session.connection_id)
3316        .await?;
3317
3318    response.send(Ack {})?;
3319
3320    channel_buffer_updated(
3321        session.connection_id,
3322        left_buffer.connections,
3323        &proto::UpdateChannelBufferCollaborators {
3324            channel_id: channel_id.to_proto(),
3325            collaborators: left_buffer.collaborators,
3326        },
3327        &session.peer,
3328    );
3329
3330    Ok(())
3331}
3332
3333fn channel_buffer_updated<T: EnvelopedMessage>(
3334    sender_id: ConnectionId,
3335    collaborators: impl IntoIterator<Item = ConnectionId>,
3336    message: &T,
3337    peer: &Peer,
3338) {
3339    broadcast(Some(sender_id), collaborators, |peer_id| {
3340        peer.send(peer_id, message.clone())
3341    });
3342}
3343
3344fn send_notifications(
3345    connection_pool: &ConnectionPool,
3346    peer: &Peer,
3347    notifications: db::NotificationBatch,
3348) {
3349    for (user_id, notification) in notifications {
3350        for connection_id in connection_pool.user_connection_ids(user_id) {
3351            if let Err(error) = peer.send(
3352                connection_id,
3353                proto::AddNotification {
3354                    notification: Some(notification.clone()),
3355                },
3356            ) {
3357                tracing::error!(
3358                    "failed to send notification to {:?} {}",
3359                    connection_id,
3360                    error
3361                );
3362            }
3363        }
3364    }
3365}
3366
3367/// Send a message to the channel
3368async fn send_channel_message(
3369    request: proto::SendChannelMessage,
3370    response: Response<proto::SendChannelMessage>,
3371    session: Session,
3372) -> Result<()> {
3373    // Validate the message body.
3374    let body = request.body.trim().to_string();
3375    if body.len() > MAX_MESSAGE_LEN {
3376        return Err(anyhow!("message is too long"))?;
3377    }
3378    if body.is_empty() {
3379        return Err(anyhow!("message can't be blank"))?;
3380    }
3381
3382    // TODO: adjust mentions if body is trimmed
3383
3384    let timestamp = OffsetDateTime::now_utc();
3385    let nonce = request
3386        .nonce
3387        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3388
3389    let channel_id = ChannelId::from_proto(request.channel_id);
3390    let CreatedChannelMessage {
3391        message_id,
3392        participant_connection_ids,
3393        notifications,
3394    } = session
3395        .db()
3396        .await
3397        .create_channel_message(
3398            channel_id,
3399            session.user_id(),
3400            &body,
3401            &request.mentions,
3402            timestamp,
3403            nonce.clone().into(),
3404            request.reply_to_message_id.map(MessageId::from_proto),
3405        )
3406        .await?;
3407
3408    let message = proto::ChannelMessage {
3409        sender_id: session.user_id().to_proto(),
3410        id: message_id.to_proto(),
3411        body,
3412        mentions: request.mentions,
3413        timestamp: timestamp.unix_timestamp() as u64,
3414        nonce: Some(nonce),
3415        reply_to_message_id: request.reply_to_message_id,
3416        edited_at: None,
3417    };
3418    broadcast(
3419        Some(session.connection_id),
3420        participant_connection_ids.clone(),
3421        |connection| {
3422            session.peer.send(
3423                connection,
3424                proto::ChannelMessageSent {
3425                    channel_id: channel_id.to_proto(),
3426                    message: Some(message.clone()),
3427                },
3428            )
3429        },
3430    );
3431    response.send(proto::SendChannelMessageResponse {
3432        message: Some(message),
3433    })?;
3434
3435    let pool = &*session.connection_pool().await;
3436    let non_participants =
3437        pool.channel_connection_ids(channel_id)
3438            .filter_map(|(connection_id, _)| {
3439                if participant_connection_ids.contains(&connection_id) {
3440                    None
3441                } else {
3442                    Some(connection_id)
3443                }
3444            });
3445    broadcast(None, non_participants, |peer_id| {
3446        session.peer.send(
3447            peer_id,
3448            proto::UpdateChannels {
3449                latest_channel_message_ids: vec![proto::ChannelMessageId {
3450                    channel_id: channel_id.to_proto(),
3451                    message_id: message_id.to_proto(),
3452                }],
3453                ..Default::default()
3454            },
3455        )
3456    });
3457    send_notifications(pool, &session.peer, notifications);
3458
3459    Ok(())
3460}
3461
3462/// Delete a channel message
3463async fn remove_channel_message(
3464    request: proto::RemoveChannelMessage,
3465    response: Response<proto::RemoveChannelMessage>,
3466    session: Session,
3467) -> Result<()> {
3468    let channel_id = ChannelId::from_proto(request.channel_id);
3469    let message_id = MessageId::from_proto(request.message_id);
3470    let (connection_ids, existing_notification_ids) = session
3471        .db()
3472        .await
3473        .remove_channel_message(channel_id, message_id, session.user_id())
3474        .await?;
3475
3476    broadcast(
3477        Some(session.connection_id),
3478        connection_ids,
3479        move |connection| {
3480            session.peer.send(connection, request.clone())?;
3481
3482            for notification_id in &existing_notification_ids {
3483                session.peer.send(
3484                    connection,
3485                    proto::DeleteNotification {
3486                        notification_id: (*notification_id).to_proto(),
3487                    },
3488                )?;
3489            }
3490
3491            Ok(())
3492        },
3493    );
3494    response.send(proto::Ack {})?;
3495    Ok(())
3496}
3497
3498async fn update_channel_message(
3499    request: proto::UpdateChannelMessage,
3500    response: Response<proto::UpdateChannelMessage>,
3501    session: Session,
3502) -> Result<()> {
3503    let channel_id = ChannelId::from_proto(request.channel_id);
3504    let message_id = MessageId::from_proto(request.message_id);
3505    let updated_at = OffsetDateTime::now_utc();
3506    let UpdatedChannelMessage {
3507        message_id,
3508        participant_connection_ids,
3509        notifications,
3510        reply_to_message_id,
3511        timestamp,
3512        deleted_mention_notification_ids,
3513        updated_mention_notifications,
3514    } = session
3515        .db()
3516        .await
3517        .update_channel_message(
3518            channel_id,
3519            message_id,
3520            session.user_id(),
3521            request.body.as_str(),
3522            &request.mentions,
3523            updated_at,
3524        )
3525        .await?;
3526
3527    let nonce = request
3528        .nonce
3529        .clone()
3530        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3531
3532    let message = proto::ChannelMessage {
3533        sender_id: session.user_id().to_proto(),
3534        id: message_id.to_proto(),
3535        body: request.body.clone(),
3536        mentions: request.mentions.clone(),
3537        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3538        nonce: Some(nonce),
3539        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3540        edited_at: Some(updated_at.unix_timestamp() as u64),
3541    };
3542
3543    response.send(proto::Ack {})?;
3544
3545    let pool = &*session.connection_pool().await;
3546    broadcast(
3547        Some(session.connection_id),
3548        participant_connection_ids,
3549        |connection| {
3550            session.peer.send(
3551                connection,
3552                proto::ChannelMessageUpdate {
3553                    channel_id: channel_id.to_proto(),
3554                    message: Some(message.clone()),
3555                },
3556            )?;
3557
3558            for notification_id in &deleted_mention_notification_ids {
3559                session.peer.send(
3560                    connection,
3561                    proto::DeleteNotification {
3562                        notification_id: (*notification_id).to_proto(),
3563                    },
3564                )?;
3565            }
3566
3567            for notification in &updated_mention_notifications {
3568                session.peer.send(
3569                    connection,
3570                    proto::UpdateNotification {
3571                        notification: Some(notification.clone()),
3572                    },
3573                )?;
3574            }
3575
3576            Ok(())
3577        },
3578    );
3579
3580    send_notifications(pool, &session.peer, notifications);
3581
3582    Ok(())
3583}
3584
3585/// Mark a channel message as read
3586async fn acknowledge_channel_message(
3587    request: proto::AckChannelMessage,
3588    session: Session,
3589) -> Result<()> {
3590    let channel_id = ChannelId::from_proto(request.channel_id);
3591    let message_id = MessageId::from_proto(request.message_id);
3592    let notifications = session
3593        .db()
3594        .await
3595        .observe_channel_message(channel_id, session.user_id(), message_id)
3596        .await?;
3597    send_notifications(
3598        &*session.connection_pool().await,
3599        &session.peer,
3600        notifications,
3601    );
3602    Ok(())
3603}
3604
3605/// Mark a buffer version as synced
3606async fn acknowledge_buffer_version(
3607    request: proto::AckBufferOperation,
3608    session: Session,
3609) -> Result<()> {
3610    let buffer_id = BufferId::from_proto(request.buffer_id);
3611    session
3612        .db()
3613        .await
3614        .observe_buffer_version(
3615            buffer_id,
3616            session.user_id(),
3617            request.epoch as i32,
3618            &request.version,
3619        )
3620        .await?;
3621    Ok(())
3622}
3623
3624async fn count_language_model_tokens(
3625    request: proto::CountLanguageModelTokens,
3626    response: Response<proto::CountLanguageModelTokens>,
3627    session: Session,
3628    config: &Config,
3629) -> Result<()> {
3630    authorize_access_to_legacy_llm_endpoints(&session).await?;
3631
3632    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3633        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3634        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3635    };
3636
3637    session
3638        .app_state
3639        .rate_limiter
3640        .check(&*rate_limit, session.user_id())
3641        .await?;
3642
3643    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3644        Some(proto::LanguageModelProvider::Google) => {
3645            let api_key = config
3646                .google_ai_api_key
3647                .as_ref()
3648                .context("no Google AI API key configured on the server")?;
3649            google_ai::count_tokens(
3650                session.http_client.as_ref(),
3651                google_ai::API_URL,
3652                api_key,
3653                serde_json::from_str(&request.request)?,
3654            )
3655            .await?
3656        }
3657        _ => return Err(anyhow!("unsupported provider"))?,
3658    };
3659
3660    response.send(proto::CountLanguageModelTokensResponse {
3661        token_count: result.total_tokens as u32,
3662    })?;
3663
3664    Ok(())
3665}
3666
3667struct ZedProCountLanguageModelTokensRateLimit;
3668
3669impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3670    fn capacity(&self) -> usize {
3671        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3672            .ok()
3673            .and_then(|v| v.parse().ok())
3674            .unwrap_or(600) // Picked arbitrarily
3675    }
3676
3677    fn refill_duration(&self) -> chrono::Duration {
3678        chrono::Duration::hours(1)
3679    }
3680
3681    fn db_name(&self) -> &'static str {
3682        "zed-pro:count-language-model-tokens"
3683    }
3684}
3685
3686struct FreeCountLanguageModelTokensRateLimit;
3687
3688impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3689    fn capacity(&self) -> usize {
3690        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3691            .ok()
3692            .and_then(|v| v.parse().ok())
3693            .unwrap_or(600 / 10) // Picked arbitrarily
3694    }
3695
3696    fn refill_duration(&self) -> chrono::Duration {
3697        chrono::Duration::hours(1)
3698    }
3699
3700    fn db_name(&self) -> &'static str {
3701        "free:count-language-model-tokens"
3702    }
3703}
3704
3705struct ZedProComputeEmbeddingsRateLimit;
3706
3707impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3708    fn capacity(&self) -> usize {
3709        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3710            .ok()
3711            .and_then(|v| v.parse().ok())
3712            .unwrap_or(5000) // Picked arbitrarily
3713    }
3714
3715    fn refill_duration(&self) -> chrono::Duration {
3716        chrono::Duration::hours(1)
3717    }
3718
3719    fn db_name(&self) -> &'static str {
3720        "zed-pro:compute-embeddings"
3721    }
3722}
3723
3724struct FreeComputeEmbeddingsRateLimit;
3725
3726impl RateLimit for FreeComputeEmbeddingsRateLimit {
3727    fn capacity(&self) -> usize {
3728        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3729            .ok()
3730            .and_then(|v| v.parse().ok())
3731            .unwrap_or(5000 / 10) // Picked arbitrarily
3732    }
3733
3734    fn refill_duration(&self) -> chrono::Duration {
3735        chrono::Duration::hours(1)
3736    }
3737
3738    fn db_name(&self) -> &'static str {
3739        "free:compute-embeddings"
3740    }
3741}
3742
3743async fn compute_embeddings(
3744    request: proto::ComputeEmbeddings,
3745    response: Response<proto::ComputeEmbeddings>,
3746    session: Session,
3747    api_key: Option<Arc<str>>,
3748) -> Result<()> {
3749    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3750    authorize_access_to_legacy_llm_endpoints(&session).await?;
3751
3752    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3753        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3754        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3755    };
3756
3757    session
3758        .app_state
3759        .rate_limiter
3760        .check(&*rate_limit, session.user_id())
3761        .await?;
3762
3763    let embeddings = match request.model.as_str() {
3764        "openai/text-embedding-3-small" => {
3765            open_ai::embed(
3766                session.http_client.as_ref(),
3767                OPEN_AI_API_URL,
3768                &api_key,
3769                OpenAiEmbeddingModel::TextEmbedding3Small,
3770                request.texts.iter().map(|text| text.as_str()),
3771            )
3772            .await?
3773        }
3774        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3775    };
3776
3777    let embeddings = request
3778        .texts
3779        .iter()
3780        .map(|text| {
3781            let mut hasher = sha2::Sha256::new();
3782            hasher.update(text.as_bytes());
3783            let result = hasher.finalize();
3784            result.to_vec()
3785        })
3786        .zip(
3787            embeddings
3788                .data
3789                .into_iter()
3790                .map(|embedding| embedding.embedding),
3791        )
3792        .collect::<HashMap<_, _>>();
3793
3794    let db = session.db().await;
3795    db.save_embeddings(&request.model, &embeddings)
3796        .await
3797        .context("failed to save embeddings")
3798        .trace_err();
3799
3800    response.send(proto::ComputeEmbeddingsResponse {
3801        embeddings: embeddings
3802            .into_iter()
3803            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3804            .collect(),
3805    })?;
3806    Ok(())
3807}
3808
3809async fn get_cached_embeddings(
3810    request: proto::GetCachedEmbeddings,
3811    response: Response<proto::GetCachedEmbeddings>,
3812    session: Session,
3813) -> Result<()> {
3814    authorize_access_to_legacy_llm_endpoints(&session).await?;
3815
3816    let db = session.db().await;
3817    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3818
3819    response.send(proto::GetCachedEmbeddingsResponse {
3820        embeddings: embeddings
3821            .into_iter()
3822            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3823            .collect(),
3824    })?;
3825    Ok(())
3826}
3827
3828/// This is leftover from before the LLM service.
3829///
3830/// The endpoints protected by this check will be moved there eventually.
3831async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3832    if session.is_staff() {
3833        Ok(())
3834    } else {
3835        Err(anyhow!("permission denied"))?
3836    }
3837}
3838
3839/// Get a Supermaven API key for the user
3840async fn get_supermaven_api_key(
3841    _request: proto::GetSupermavenApiKey,
3842    response: Response<proto::GetSupermavenApiKey>,
3843    session: Session,
3844) -> Result<()> {
3845    let user_id: String = session.user_id().to_string();
3846    if !session.is_staff() {
3847        return Err(anyhow!("supermaven not enabled for this account"))?;
3848    }
3849
3850    let email = session
3851        .email()
3852        .ok_or_else(|| anyhow!("user must have an email"))?;
3853
3854    let supermaven_admin_api = session
3855        .supermaven_client
3856        .as_ref()
3857        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3858
3859    let result = supermaven_admin_api
3860        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3861        .await?;
3862
3863    response.send(proto::GetSupermavenApiKeyResponse {
3864        api_key: result.api_key,
3865    })?;
3866
3867    Ok(())
3868}
3869
3870/// Start receiving chat updates for a channel
3871async fn join_channel_chat(
3872    request: proto::JoinChannelChat,
3873    response: Response<proto::JoinChannelChat>,
3874    session: Session,
3875) -> Result<()> {
3876    let channel_id = ChannelId::from_proto(request.channel_id);
3877
3878    let db = session.db().await;
3879    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3880        .await?;
3881    let messages = db
3882        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3883        .await?;
3884    response.send(proto::JoinChannelChatResponse {
3885        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3886        messages,
3887    })?;
3888    Ok(())
3889}
3890
3891/// Stop receiving chat updates for a channel
3892async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3893    let channel_id = ChannelId::from_proto(request.channel_id);
3894    session
3895        .db()
3896        .await
3897        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3898        .await?;
3899    Ok(())
3900}
3901
3902/// Retrieve the chat history for a channel
3903async fn get_channel_messages(
3904    request: proto::GetChannelMessages,
3905    response: Response<proto::GetChannelMessages>,
3906    session: Session,
3907) -> Result<()> {
3908    let channel_id = ChannelId::from_proto(request.channel_id);
3909    let messages = session
3910        .db()
3911        .await
3912        .get_channel_messages(
3913            channel_id,
3914            session.user_id(),
3915            MESSAGE_COUNT_PER_PAGE,
3916            Some(MessageId::from_proto(request.before_message_id)),
3917        )
3918        .await?;
3919    response.send(proto::GetChannelMessagesResponse {
3920        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921        messages,
3922    })?;
3923    Ok(())
3924}
3925
3926/// Retrieve specific chat messages
3927async fn get_channel_messages_by_id(
3928    request: proto::GetChannelMessagesById,
3929    response: Response<proto::GetChannelMessagesById>,
3930    session: Session,
3931) -> Result<()> {
3932    let message_ids = request
3933        .message_ids
3934        .iter()
3935        .map(|id| MessageId::from_proto(*id))
3936        .collect::<Vec<_>>();
3937    let messages = session
3938        .db()
3939        .await
3940        .get_channel_messages_by_id(session.user_id(), &message_ids)
3941        .await?;
3942    response.send(proto::GetChannelMessagesResponse {
3943        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3944        messages,
3945    })?;
3946    Ok(())
3947}
3948
3949/// Retrieve the current users notifications
3950async fn get_notifications(
3951    request: proto::GetNotifications,
3952    response: Response<proto::GetNotifications>,
3953    session: Session,
3954) -> Result<()> {
3955    let notifications = session
3956        .db()
3957        .await
3958        .get_notifications(
3959            session.user_id(),
3960            NOTIFICATION_COUNT_PER_PAGE,
3961            request.before_id.map(db::NotificationId::from_proto),
3962        )
3963        .await?;
3964    response.send(proto::GetNotificationsResponse {
3965        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3966        notifications,
3967    })?;
3968    Ok(())
3969}
3970
3971/// Mark notifications as read
3972async fn mark_notification_as_read(
3973    request: proto::MarkNotificationRead,
3974    response: Response<proto::MarkNotificationRead>,
3975    session: Session,
3976) -> Result<()> {
3977    let database = &session.db().await;
3978    let notifications = database
3979        .mark_notification_as_read_by_id(
3980            session.user_id(),
3981            NotificationId::from_proto(request.notification_id),
3982        )
3983        .await?;
3984    send_notifications(
3985        &*session.connection_pool().await,
3986        &session.peer,
3987        notifications,
3988    );
3989    response.send(proto::Ack {})?;
3990    Ok(())
3991}
3992
3993/// Get the current users information
3994async fn get_private_user_info(
3995    _request: proto::GetPrivateUserInfo,
3996    response: Response<proto::GetPrivateUserInfo>,
3997    session: Session,
3998) -> Result<()> {
3999    let db = session.db().await;
4000
4001    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4002    let user = db
4003        .get_user_by_id(session.user_id())
4004        .await?
4005        .ok_or_else(|| anyhow!("user not found"))?;
4006    let flags = db.get_user_flags(session.user_id()).await?;
4007
4008    response.send(proto::GetPrivateUserInfoResponse {
4009        metrics_id,
4010        staff: user.admin,
4011        flags,
4012        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4013    })?;
4014    Ok(())
4015}
4016
4017/// Accept the terms of service (tos) on behalf of the current user
4018async fn accept_terms_of_service(
4019    _request: proto::AcceptTermsOfService,
4020    response: Response<proto::AcceptTermsOfService>,
4021    session: Session,
4022) -> Result<()> {
4023    let db = session.db().await;
4024
4025    let accepted_tos_at = Utc::now();
4026    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4027        .await?;
4028
4029    response.send(proto::AcceptTermsOfServiceResponse {
4030        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4031    })?;
4032    Ok(())
4033}
4034
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of the user's Zed account
/// creation time and GitHub account creation time. It skips this check for
/// users with an LLM subscription or the `bypass-account-age-check` flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4037
4038async fn get_llm_api_token(
4039    _request: proto::GetLlmToken,
4040    response: Response<proto::GetLlmToken>,
4041    session: Session,
4042) -> Result<()> {
4043    let db = session.db().await;
4044
4045    let flags = db.get_user_flags(session.user_id()).await?;
4046    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4047    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4048    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4049
4050    if !session.is_staff() && !has_language_models_feature_flag {
4051        Err(anyhow!("permission denied"))?
4052    }
4053
4054    let user_id = session.user_id();
4055    let user = db
4056        .get_user_by_id(user_id)
4057        .await?
4058        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4059
4060    if user.accepted_tos_at.is_none() {
4061        Err(anyhow!("terms of service not accepted"))?
4062    }
4063
4064    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4065
4066    let bypass_account_age_check =
4067        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4068    if !bypass_account_age_check {
4069        let mut account_created_at = user.created_at;
4070        if let Some(github_created_at) = user.github_user_created_at {
4071            account_created_at = account_created_at.min(github_created_at);
4072        }
4073        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4074            Err(anyhow!("account too young"))?
4075        }
4076    }
4077
4078    let billing_preferences = db.get_billing_preferences(user.id).await?;
4079
4080    let token = LlmTokenClaims::create(
4081        &user,
4082        session.is_staff(),
4083        billing_preferences,
4084        has_llm_closed_beta_feature_flag,
4085        has_predict_edits_feature_flag,
4086        has_llm_subscription,
4087        session.current_plan(&db).await?,
4088        session.system_id.clone(),
4089        &session.app_state.config,
4090    )?;
4091    response.send(proto::GetLlmTokenResponse { token })?;
4092    Ok(())
4093}
4094
4095fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4096    let message = match message {
4097        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4098        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4099        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4100        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4101        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4102            code: frame.code.into(),
4103            reason: frame.reason,
4104        })),
4105        // We should never receive a frame while reading the message, according
4106        // to the `tungstenite` maintainers:
4107        //
4108        // > It cannot occur when you read messages from the WebSocket, but it
4109        // > can be used when you want to send the raw frames (e.g. you want to
4110        // > send the frames to the WebSocket without composing the full message first).
4111        // >
4112        // > — https://github.com/snapview/tungstenite-rs/issues/268
4113        TungsteniteMessage::Frame(_) => {
4114            bail!("received an unexpected frame while reading the message")
4115        }
4116    };
4117
4118    Ok(message)
4119}
4120
4121fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4122    match message {
4123        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4124        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4125        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4126        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4127        AxumMessage::Close(frame) => {
4128            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4129                code: frame.code.into(),
4130                reason: frame.reason,
4131            }))
4132        }
4133    }
4134}
4135
4136fn notify_membership_updated(
4137    connection_pool: &mut ConnectionPool,
4138    result: MembershipUpdated,
4139    user_id: UserId,
4140    peer: &Peer,
4141) {
4142    for membership in &result.new_channels.channel_memberships {
4143        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4144    }
4145    for channel_id in &result.removed_channels {
4146        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4147    }
4148
4149    let user_channels_update = proto::UpdateUserChannels {
4150        channel_memberships: result
4151            .new_channels
4152            .channel_memberships
4153            .iter()
4154            .map(|cm| proto::ChannelMembership {
4155                channel_id: cm.channel_id.to_proto(),
4156                role: cm.role.into(),
4157            })
4158            .collect(),
4159        ..Default::default()
4160    };
4161
4162    let mut update = build_channels_update(result.new_channels);
4163    update.delete_channels = result
4164        .removed_channels
4165        .into_iter()
4166        .map(|id| id.to_proto())
4167        .collect();
4168    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4169
4170    for connection_id in connection_pool.user_connection_ids(user_id) {
4171        peer.send(connection_id, user_channels_update.clone())
4172            .trace_err();
4173        peer.send(connection_id, update.clone()).trace_err();
4174    }
4175}
4176
4177fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4178    proto::UpdateUserChannels {
4179        channel_memberships: channels
4180            .channel_memberships
4181            .iter()
4182            .map(|m| proto::ChannelMembership {
4183                channel_id: m.channel_id.to_proto(),
4184                role: m.role.into(),
4185            })
4186            .collect(),
4187        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4188        observed_channel_message_id: channels.observed_channel_messages.clone(),
4189    }
4190}
4191
4192fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4193    let mut update = proto::UpdateChannels::default();
4194
4195    for channel in channels.channels {
4196        update.channels.push(channel.to_proto());
4197    }
4198
4199    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4200    update.latest_channel_message_ids = channels.latest_channel_messages;
4201
4202    for (channel_id, participants) in channels.channel_participants {
4203        update
4204            .channel_participants
4205            .push(proto::ChannelParticipants {
4206                channel_id: channel_id.to_proto(),
4207                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4208            });
4209    }
4210
4211    for channel in channels.invited_channels {
4212        update.channel_invitations.push(channel.to_proto());
4213    }
4214
4215    update
4216}
4217
4218fn build_initial_contacts_update(
4219    contacts: Vec<db::Contact>,
4220    pool: &ConnectionPool,
4221) -> proto::UpdateContacts {
4222    let mut update = proto::UpdateContacts::default();
4223
4224    for contact in contacts {
4225        match contact {
4226            db::Contact::Accepted { user_id, busy } => {
4227                update.contacts.push(contact_for_user(user_id, busy, pool));
4228            }
4229            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4230            db::Contact::Incoming { user_id } => {
4231                update
4232                    .incoming_requests
4233                    .push(proto::IncomingContactRequest {
4234                        requester_id: user_id.to_proto(),
4235                    })
4236            }
4237        }
4238    }
4239
4240    update
4241}
4242
4243fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4244    proto::Contact {
4245        user_id: user_id.to_proto(),
4246        online: pool.is_user_online(user_id),
4247        busy,
4248    }
4249}
4250
4251fn room_updated(room: &proto::Room, peer: &Peer) {
4252    broadcast(
4253        None,
4254        room.participants
4255            .iter()
4256            .filter_map(|participant| Some(participant.peer_id?.into())),
4257        |peer_id| {
4258            peer.send(
4259                peer_id,
4260                proto::RoomUpdated {
4261                    room: Some(room.clone()),
4262                },
4263            )
4264        },
4265    );
4266}
4267
4268fn channel_updated(
4269    channel: &db::channel::Model,
4270    room: &proto::Room,
4271    peer: &Peer,
4272    pool: &ConnectionPool,
4273) {
4274    let participants = room
4275        .participants
4276        .iter()
4277        .map(|p| p.user_id)
4278        .collect::<Vec<_>>();
4279
4280    broadcast(
4281        None,
4282        pool.channel_connection_ids(channel.root_id())
4283            .filter_map(|(channel_id, role)| {
4284                role.can_see_channel(channel.visibility)
4285                    .then_some(channel_id)
4286            }),
4287        |peer_id| {
4288            peer.send(
4289                peer_id,
4290                proto::UpdateChannels {
4291                    channel_participants: vec![proto::ChannelParticipants {
4292                        channel_id: channel.id.to_proto(),
4293                        participant_user_ids: participants.clone(),
4294                    }],
4295                    ..Default::default()
4296                },
4297            )
4298        },
4299    );
4300}
4301
4302async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4303    let db = session.db().await;
4304
4305    let contacts = db.get_contacts(user_id).await?;
4306    let busy = db.is_user_busy(user_id).await?;
4307
4308    let pool = session.connection_pool().await;
4309    let updated_contact = contact_for_user(user_id, busy, &pool);
4310    for contact in contacts {
4311        if let db::Contact::Accepted {
4312            user_id: contact_user_id,
4313            ..
4314        } = contact
4315        {
4316            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4317                session
4318                    .peer
4319                    .send(
4320                        contact_conn_id,
4321                        proto::UpdateContacts {
4322                            contacts: vec![updated_contact.clone()],
4323                            remove_contacts: Default::default(),
4324                            incoming_requests: Default::default(),
4325                            remove_incoming_requests: Default::default(),
4326                            outgoing_requests: Default::default(),
4327                            remove_outgoing_requests: Default::default(),
4328                        },
4329                    )
4330                    .trace_err();
4331            }
4332        }
4333    }
4334    Ok(())
4335}
4336
/// Removes `connection_id` from the room it is in and notifies everyone
/// affected.
///
/// If `leave_room` returns `None` (presumably the connection was not in a
/// room), this returns `Ok(())` without doing anything else. Otherwise it:
/// - notifies collaborators of every project left along with the room
///   (see `project_left`),
/// - broadcasts the updated room to its participants and, when the room has a
///   channel, the channel's new participant list via `channel_updated`,
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled,
/// - refreshes contact status for this session's user and the canceled
///   callees,
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room and deletes that room when `left_room.deleted` is set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. An error
/// from `update_user_contacts` returns early, so the LiveKit cleanup after it
/// is skipped. Message-send and LiveKit failures are only logged via
/// `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact status is re-broadcast once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Declared uninitialized: each is assigned inside the `if let` below (the
    // `else` branch returns), then used for the rest of the function.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the guard from `session.db()` is a temporary in the
    // `if let` scrutinee, so it stays alive through the body of this `if let`
    // (including the `room_updated` broadcast).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the room passed to
        // `room_updated` has `livekit_room` reset to its default value.
        // NOTE(review): confirm clients don't need `livekit_room` in this update.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool guard is a temporary released at the end of
        // this call statement.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool guard is dropped before `update_user_contacts`
    // below, which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4410
4411async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4412    let left_channel_buffers = session
4413        .db()
4414        .await
4415        .leave_channel_buffers(session.connection_id)
4416        .await?;
4417
4418    for left_buffer in left_channel_buffers {
4419        channel_buffer_updated(
4420            session.connection_id,
4421            left_buffer.connections,
4422            &proto::UpdateChannelBufferCollaborators {
4423                channel_id: left_buffer.channel_id.to_proto(),
4424                collaborators: left_buffer.collaborators,
4425            },
4426            &session.peer,
4427        );
4428    }
4429
4430    Ok(())
4431}
4432
4433fn project_left(project: &db::LeftProject, session: &Session) {
4434    for connection_id in &project.connection_ids {
4435        if project.should_unshare {
4436            session
4437                .peer
4438                .send(
4439                    *connection_id,
4440                    proto::UnshareProject {
4441                        project_id: project.id.to_proto(),
4442                    },
4443                )
4444                .trace_err();
4445        } else {
4446            session
4447                .peer
4448                .send(
4449                    *connection_id,
4450                    proto::RemoveProjectCollaborator {
4451                        project_id: project.id.to_proto(),
4452                        peer_id: Some(session.connection_id.into()),
4453                    },
4454                )
4455                .trace_err();
4456        }
4457    }
4458}
4459
4460pub trait ResultExt {
4461    type Ok;
4462
4463    fn trace_err(self) -> Option<Self::Ok>;
4464}
4465
4466impl<T, E> ResultExt for Result<T, E>
4467where
4468    E: std::fmt::Debug,
4469{
4470    type Ok = T;
4471
4472    #[track_caller]
4473    fn trace_err(self) -> Option<T> {
4474        match self {
4475            Ok(value) => Some(value),
4476            Err(error) => {
4477                tracing::error!("{:?}", error);
4478                None
4479            }
4480        }
4481    }
4482}