rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Reconnection grace period.
///
/// NOTE(review): not referenced in the visible part of this file; presumably how long a
/// disconnected client's state is kept around for it to reconnect — confirm at its use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources, so this is set comfortably above that window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file; their
// names suggest page sizes for channel-message / notification fetches and a cap on message
// length — confirm against the handlers that read them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: takes a boxed envelope (downcast back to
/// its concrete `TypedEnvelope<M>` inside the handler) plus the connection's `Session`, and
/// returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
 250#[derive(Serialize)]
 251pub struct ServerSnapshot<'a> {
 252    peer: &'a Peer,
 253    #[serde(serialize_with = "serialize_deref")]
 254    connection_pool: ConnectionPoolGuard<'a>,
 255}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 332            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 333            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 342            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 345            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 347            .add_message_handler(create_buffer_for_peer)
 348            .add_request_handler(update_buffer)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 354            .add_request_handler(get_users)
 355            .add_request_handler(fuzzy_search_users)
 356            .add_request_handler(request_contact)
 357            .add_request_handler(remove_contact)
 358            .add_request_handler(respond_to_contact_request)
 359            .add_message_handler(subscribe_to_channels)
 360            .add_request_handler(create_channel)
 361            .add_request_handler(delete_channel)
 362            .add_request_handler(invite_channel_member)
 363            .add_request_handler(remove_channel_member)
 364            .add_request_handler(set_channel_member_role)
 365            .add_request_handler(set_channel_visibility)
 366            .add_request_handler(rename_channel)
 367            .add_request_handler(join_channel_buffer)
 368            .add_request_handler(leave_channel_buffer)
 369            .add_message_handler(update_channel_buffer)
 370            .add_request_handler(rejoin_channel_buffers)
 371            .add_request_handler(get_channel_members)
 372            .add_request_handler(respond_to_channel_invite)
 373            .add_request_handler(join_channel)
 374            .add_request_handler(join_channel_chat)
 375            .add_message_handler(leave_channel_chat)
 376            .add_request_handler(send_channel_message)
 377            .add_request_handler(remove_channel_message)
 378            .add_request_handler(update_channel_message)
 379            .add_request_handler(get_channel_messages)
 380            .add_request_handler(get_channel_messages_by_id)
 381            .add_request_handler(get_notifications)
 382            .add_request_handler(mark_notification_as_read)
 383            .add_request_handler(move_channel)
 384            .add_request_handler(follow)
 385            .add_message_handler(unfollow)
 386            .add_message_handler(update_followers)
 387            .add_request_handler(get_private_user_info)
 388            .add_request_handler(get_llm_api_token)
 389            .add_request_handler(accept_terms_of_service)
 390            .add_message_handler(acknowledge_channel_message)
 391            .add_message_handler(acknowledge_buffer_version)
 392            .add_request_handler(get_supermaven_api_key)
 393            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 395            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Push>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Pull>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Fetch>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 406            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 407            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 408            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 409            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 410            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 412            .add_message_handler(update_context)
 413            .add_request_handler({
 414                let app_state = app_state.clone();
 415                move |request, response, session| {
 416                    let app_state = app_state.clone();
 417                    async move {
 418                        count_language_model_tokens(request, response, session, &app_state.config)
 419                            .await
 420                    }
 421                }
 422            })
 423            .add_request_handler(get_cached_embeddings)
 424            .add_request_handler({
 425                let app_state = app_state.clone();
 426                move |request, response, session| {
 427                    compute_embeddings(
 428                        request,
 429                        response,
 430                        session,
 431                        app_state.config.openai_api_key.clone(),
 432                    )
 433                }
 434            });
 435
 436        Arc::new(server)
 437    }
 438
    /// Schedules cleanup of resources left behind by stale servers, then returns immediately.
    ///
    /// A detached task (instrumented with the `start server` span) runs the cleanup:
    /// 1. Sleeps for [`CLEANUP_TIMEOUT`].
    /// 2. Fetches the room and channel-buffer ids reported by `stale_server_resource_ids`
    ///    for this `zed_environment` / `server_id`. NOTE(review): presumably resources still
    ///    owned by *other*, dead servers — confirm in `Database`.
    /// 3. For each channel buffer, clears stale collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the remaining connections.
    /// 4. For each room, clears stale participants, broadcasts the room/channel update, sends
    ///    `CallCanceled` to users whose calls were canceled, refreshes contacts for the affected
    ///    users, and deletes the LiveKit room if no participants remain.
    /// 5. Deletes the stale server records.
    ///
    /// Every fallible call inside the task goes through `trace_err()`, which yields `None` or is
    /// discarded on error. The task therefore never propagates failures, and this function
    /// always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now but only awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell everyone still in the buffer about the reduced collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults for the case where clearing the room fails: nothing to notify,
                        // and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled need
                            // their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only tear down the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool lock is released before the `.await`s
                        // in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after both awaits above and released at the end of
                                // this `if let`, so it is never held across an `.await`.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are sent the updated status.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 584
 585    pub fn teardown(&self) {
 586        self.peer.teardown();
 587        self.connection_pool.lock().reset();
 588        let _ = self.teardown.send(true);
 589    }
 590
 591    #[cfg(test)]
 592    pub fn reset(&self, id: ServerId) {
 593        self.teardown();
 594        *self.id.lock() = id;
 595        self.peer.reset(id.0 as u32);
 596        let _ = self.teardown.send(false);
 597    }
 598
 599    #[cfg(test)]
 600    pub fn id(&self) -> ServerId {
 601        *self.id.lock()
 602    }
 603
 604    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 605    where
 606        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 607        Fut: 'static + Send + Future<Output = Result<()>>,
 608        M: EnvelopedMessage,
 609    {
 610        let prev_handler = self.handlers.insert(
 611            TypeId::of::<M>(),
 612            Box::new(move |envelope, session| {
 613                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 614                let received_at = envelope.received_at;
 615                tracing::info!("message received");
 616                let start_time = Instant::now();
 617                let future = (handler)(*envelope, session);
 618                async move {
 619                    let result = future.await;
 620                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 621                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 622                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 623                    let payload_type = M::NAME;
 624
 625                    match result {
 626                        Err(error) => {
 627                            tracing::error!(
 628                                ?error,
 629                                total_duration_ms,
 630                                processing_duration_ms,
 631                                queue_duration_ms,
 632                                payload_type,
 633                                "error handling message"
 634                            )
 635                        }
 636                        Ok(()) => tracing::info!(
 637                            total_duration_ms,
 638                            processing_duration_ms,
 639                            queue_duration_ms,
 640                            "finished handling message"
 641                        ),
 642                    }
 643                }
 644                .boxed()
 645            }),
 646        );
 647        if prev_handler.is_some() {
 648            panic!("registered a handler for the same message twice");
 649        }
 650        self
 651    }
 652
 653    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 654    where
 655        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 656        Fut: 'static + Send + Future<Output = Result<()>>,
 657        M: EnvelopedMessage,
 658    {
 659        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 660        self
 661    }
 662
 663    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 664    where
 665        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 666        Fut: Send + Future<Output = Result<()>>,
 667        M: RequestMessage,
 668    {
 669        let handler = Arc::new(handler);
 670        self.add_handler(move |envelope, session| {
 671            let receipt = envelope.receipt();
 672            let handler = handler.clone();
 673            async move {
 674                let peer = session.peer.clone();
 675                let responded = Arc::new(AtomicBool::default());
 676                let response = Response {
 677                    peer: peer.clone(),
 678                    responded: responded.clone(),
 679                    receipt,
 680                };
 681                match (handler)(envelope.payload, response, session).await {
 682                    Ok(()) => {
 683                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 684                            Ok(())
 685                        } else {
 686                            Err(anyhow!("handler did not send a response"))?
 687                        }
 688                    }
 689                    Err(error) => {
 690                        let proto_err = match &error {
 691                            Error::Internal(err) => err.to_proto(),
 692                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 693                        };
 694                        peer.respond_with_error(receipt, proto_err)?;
 695                        Err(error)
 696                    }
 697                }
 698            }
 699        })
 700    }
 701
    /// Builds the future that serves one client connection for its whole lifetime.
    ///
    /// The returned future is lazy (an `async move` block) and is instrumented with a
    /// "handle connection" span. When polled it:
    /// 1. bails out if the server is already tearing down;
    /// 2. registers `connection` with the peer and records its id on the span;
    /// 3. builds the per-connection [`Session`] (HTTP client, optional Supermaven client);
    /// 4. sends the initial client update (`send_initial_client_update`);
    /// 5. dispatches incoming messages to registered handlers until the I/O ends, the
    ///    client closes the stream, or teardown is signalled;
    /// 6. on a normal exit from the loop, runs `connection_lost`.
    ///
    /// `send_connection_id`, when provided, is passed through to
    /// `send_initial_client_update`, which reports the assigned connection id on it.
    ///
    /// NOTE(review): the early `return`s after `add_connection` (HTTP client creation
    /// failure, initial-update failure, and teardown inside the loop) skip
    /// `connection_lost`, so `peer.disconnect` / pool removal are not performed on
    /// those paths — confirm whether this is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later (by `update_span`, or once
        // the connection id is known).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future runs so a teardown signalled in between is seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure hands the peer a way to sleep on this executor (presumably
            // for its internal timers — confirm in `Peer::add_connection`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only created when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) may be in flight for this
            // connection: a permit is taken before reading each message and released
            // when that message's handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Rebuilt every iteration; if another branch wins, this future is dropped
                // together with any permit it had already acquired.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then progress on running
                // foreground handlers, and only then new messages.
                futures::select_biased! {
                    // Server teardown: stop immediately (no `connection_lost`).
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is constructed;
                            // afterwards the future itself is instrumented with it.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // ones are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended: the client closed the connection.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that have not finished yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 851
 852    async fn send_initial_client_update(
 853        &self,
 854        connection_id: ConnectionId,
 855        principal: &Principal,
 856        zed_version: ZedVersion,
 857        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 858        session: &Session,
 859    ) -> Result<()> {
 860        self.peer.send(
 861            connection_id,
 862            proto::Hello {
 863                peer_id: Some(connection_id.into()),
 864            },
 865        )?;
 866        tracing::info!("sent hello message");
 867        if let Some(send_connection_id) = send_connection_id.take() {
 868            let _ = send_connection_id.send(connection_id);
 869        }
 870
 871        match principal {
 872            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 873                if !user.connected_once {
 874                    self.peer.send(connection_id, proto::ShowContacts {})?;
 875                    self.app_state
 876                        .db
 877                        .set_user_connected_once(user.id, true)
 878                        .await?;
 879                }
 880
 881                update_user_plan(user.id, session).await?;
 882
 883                let contacts = self.app_state.db.get_contacts(user.id).await?;
 884
 885                {
 886                    let mut pool = self.connection_pool.lock();
 887                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 888                    self.peer.send(
 889                        connection_id,
 890                        build_initial_contacts_update(contacts, &pool),
 891                    )?;
 892                }
 893
 894                if should_auto_subscribe_to_channels(zed_version) {
 895                    subscribe_user_to_channels(user.id, session).await?;
 896                }
 897
 898                if let Some(incoming_call) =
 899                    self.app_state.db.incoming_call_for_user(user.id).await?
 900                {
 901                    self.peer.send(connection_id, incoming_call)?;
 902                }
 903
 904                update_user_contacts(user.id, session).await?;
 905            }
 906        }
 907
 908        Ok(())
 909    }
 910
 911    pub async fn invite_code_redeemed(
 912        self: &Arc<Self>,
 913        inviter_id: UserId,
 914        invitee_id: UserId,
 915    ) -> Result<()> {
 916        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 917            if let Some(code) = &user.invite_code {
 918                let pool = self.connection_pool.lock();
 919                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 920                for connection_id in pool.user_connection_ids(inviter_id) {
 921                    self.peer.send(
 922                        connection_id,
 923                        proto::UpdateContacts {
 924                            contacts: vec![invitee_contact.clone()],
 925                            ..Default::default()
 926                        },
 927                    )?;
 928                    self.peer.send(
 929                        connection_id,
 930                        proto::UpdateInviteInfo {
 931                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 932                            count: user.invite_count as u32,
 933                        },
 934                    )?;
 935                }
 936            }
 937        }
 938        Ok(())
 939    }
 940
 941    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 942        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 943            if let Some(invite_code) = &user.invite_code {
 944                let pool = self.connection_pool.lock();
 945                for connection_id in pool.user_connection_ids(user_id) {
 946                    self.peer.send(
 947                        connection_id,
 948                        proto::UpdateInviteInfo {
 949                            url: format!(
 950                                "{}{}",
 951                                self.app_state.config.invite_link_prefix, invite_code
 952                            ),
 953                            count: user.invite_count as u32,
 954                        },
 955                    )?;
 956                }
 957            }
 958        }
 959        Ok(())
 960    }
 961
 962    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 963        let pool = self.connection_pool.lock();
 964        for connection_id in pool.user_connection_ids(user_id) {
 965            self.peer
 966                .send(connection_id, proto::RefreshLlmToken {})
 967                .trace_err();
 968        }
 969    }
 970
 971    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 972        ServerSnapshot {
 973            connection_pool: ConnectionPoolGuard {
 974                guard: self.connection_pool.lock(),
 975                _not_send: PhantomData,
 976            },
 977            peer: &self.peer,
 978        }
 979    }
 980}
 981
 982impl Deref for ConnectionPoolGuard<'_> {
 983    type Target = ConnectionPool;
 984
 985    fn deref(&self) -> &Self::Target {
 986        &self.guard
 987    }
 988}
 989
 990impl DerefMut for ConnectionPoolGuard<'_> {
 991    fn deref_mut(&mut self) -> &mut Self::Target {
 992        &mut self.guard
 993    }
 994}
 995
 996impl Drop for ConnectionPoolGuard<'_> {
 997    fn drop(&mut self) {
 998        #[cfg(test)]
 999        self.check_invariants();
1000    }
1001}
1002
1003fn broadcast<F>(
1004    sender_id: Option<ConnectionId>,
1005    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1006    mut f: F,
1007) where
1008    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1009{
1010    for receiver_id in receiver_ids {
1011        if Some(receiver_id) != sender_id {
1012            if let Err(error) = f(receiver_id) {
1013                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1014            }
1015        }
1016    }
1017}
1018
1019pub struct ProtocolVersion(u32);
1020
1021impl Header for ProtocolVersion {
1022    fn name() -> &'static HeaderName {
1023        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1025    }
1026
1027    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028    where
1029        Self: Sized,
1030        I: Iterator<Item = &'i axum::http::HeaderValue>,
1031    {
1032        let version = values
1033            .next()
1034            .ok_or_else(axum::headers::Error::invalid)?
1035            .to_str()
1036            .map_err(|_| axum::headers::Error::invalid())?
1037            .parse()
1038            .map_err(|_| axum::headers::Error::invalid())?;
1039        Ok(Self(version))
1040    }
1041
1042    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043        values.extend([self.0.to_string().parse().unwrap()]);
1044    }
1045}
1046
1047pub struct AppVersionHeader(SemanticVersion);
1048impl Header for AppVersionHeader {
1049    fn name() -> &'static HeaderName {
1050        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1051        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1052    }
1053
1054    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1055    where
1056        Self: Sized,
1057        I: Iterator<Item = &'i axum::http::HeaderValue>,
1058    {
1059        let version = values
1060            .next()
1061            .ok_or_else(axum::headers::Error::invalid)?
1062            .to_str()
1063            .map_err(|_| axum::headers::Error::invalid())?
1064            .parse()
1065            .map_err(|_| axum::headers::Error::invalid())?;
1066        Ok(Self(version))
1067    }
1068
1069    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1070        values.extend([self.0.to_string().parse().unwrap()]);
1071    }
1072}
1073
1074pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1075    Router::new()
1076        .route("/rpc", get(handle_websocket_request))
1077        .layer(
1078            ServiceBuilder::new()
1079                .layer(Extension(server.app_state.clone()))
1080                .layer(middleware::from_fn(auth::validate_header)),
1081        )
1082        .route("/metrics", get(handle_metrics))
1083        .layer(Extension(server))
1084}
1085
1086#[allow(clippy::too_many_arguments)]
1087pub async fn handle_websocket_request(
1088    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1089    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1090    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1091    Extension(server): Extension<Arc<Server>>,
1092    Extension(principal): Extension<Principal>,
1093    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1094    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1095    ws: WebSocketUpgrade,
1096) -> axum::response::Response {
1097    if protocol_version != rpc::PROTOCOL_VERSION {
1098        return (
1099            StatusCode::UPGRADE_REQUIRED,
1100            "client must be upgraded".to_string(),
1101        )
1102            .into_response();
1103    }
1104
1105    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1106        return (
1107            StatusCode::UPGRADE_REQUIRED,
1108            "no version header found".to_string(),
1109        )
1110            .into_response();
1111    };
1112
1113    if !version.can_collaborate() {
1114        return (
1115            StatusCode::UPGRADE_REQUIRED,
1116            "client must be upgraded".to_string(),
1117        )
1118            .into_response();
1119    }
1120
1121    let socket_address = socket_address.to_string();
1122    ws.on_upgrade(move |socket| {
1123        let socket = socket
1124            .map_ok(to_tungstenite_message)
1125            .err_into()
1126            .with(|message| async move { to_axum_message(message) });
1127        let connection = Connection::new(Box::pin(socket));
1128        async move {
1129            server
1130                .handle_connection(
1131                    connection,
1132                    socket_address,
1133                    principal,
1134                    version,
1135                    country_code_header.map(|header| header.to_string()),
1136                    system_id_header.map(|header| header.to_string()),
1137                    None,
1138                    Executor::Production,
1139                )
1140                .await;
1141        }
1142    })
1143}
1144
1145pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1146    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1147    let connections_metric = CONNECTIONS_METRIC
1148        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1149
1150    let connections = server
1151        .connection_pool
1152        .lock()
1153        .connections()
1154        .filter(|connection| !connection.admin)
1155        .count();
1156    connections_metric.set(connections as _);
1157
1158    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1159    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1160        register_int_gauge!(
1161            "shared_projects",
1162            "number of open projects with one or more guests"
1163        )
1164        .unwrap()
1165    });
1166
1167    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1168    shared_projects_metric.set(shared_projects as _);
1169
1170    let encoder = prometheus::TextEncoder::new();
1171    let metric_families = prometheus::gather();
1172    let encoded_metrics = encoder
1173        .encode_to_string(&metric_families)
1174        .map_err(|err| anyhow!("{}", err))?;
1175    Ok(encoded_metrics)
1176}
1177
1178#[instrument(err, skip(executor))]
1179async fn connection_lost(
1180    session: Session,
1181    mut teardown: watch::Receiver<bool>,
1182    executor: Executor,
1183) -> Result<()> {
1184    session.peer.disconnect(session.connection_id);
1185    session
1186        .connection_pool()
1187        .await
1188        .remove_connection(session.connection_id)?;
1189
1190    session
1191        .db()
1192        .await
1193        .connection_lost(session.connection_id)
1194        .await
1195        .trace_err();
1196
1197    futures::select_biased! {
1198        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1199
1200            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1201            leave_room_for_session(&session, session.connection_id).await.trace_err();
1202            leave_channel_buffers_for_session(&session)
1203                .await
1204                .trace_err();
1205
1206            if !session
1207                .connection_pool()
1208                .await
1209                .is_user_online(session.user_id())
1210            {
1211                let db = session.db().await;
1212                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1213                    room_updated(&room, &session.peer);
1214                }
1215            }
1216
1217            update_user_contacts(session.user_id(), &session).await?;
1218        },
1219        _ = teardown.changed().fuse() => {}
1220    }
1221
1222    Ok(())
1223}
1224
1225/// Acknowledges a ping from a client, used to keep the connection alive.
1226async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1227    response.send(proto::Ack {})?;
1228    Ok(())
1229}
1230
1231/// Creates a new room for calling (outside of channels)
1232async fn create_room(
1233    _request: proto::CreateRoom,
1234    response: Response<proto::CreateRoom>,
1235    session: Session,
1236) -> Result<()> {
1237    let livekit_room = nanoid::nanoid!(30);
1238
1239    let live_kit_connection_info = util::maybe!(async {
1240        let live_kit = session.app_state.livekit_client.as_ref();
1241        let live_kit = live_kit?;
1242        let user_id = session.user_id().to_string();
1243
1244        let token = live_kit
1245            .room_token(&livekit_room, &user_id.to_string())
1246            .trace_err()?;
1247
1248        Some(proto::LiveKitConnectionInfo {
1249            server_url: live_kit.url().into(),
1250            token,
1251            can_publish: true,
1252        })
1253    })
1254    .await;
1255
1256    let room = session
1257        .db()
1258        .await
1259        .create_room(session.user_id(), session.connection_id, &livekit_room)
1260        .await?;
1261
1262    response.send(proto::CreateRoomResponse {
1263        room: Some(room.clone()),
1264        live_kit_connection_info,
1265    })?;
1266
1267    update_user_contacts(session.user_id(), &session).await?;
1268    Ok(())
1269}
1270
1271/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1272async fn join_room(
1273    request: proto::JoinRoom,
1274    response: Response<proto::JoinRoom>,
1275    session: Session,
1276) -> Result<()> {
1277    let room_id = RoomId::from_proto(request.id);
1278
1279    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1280
1281    if let Some(channel_id) = channel_id {
1282        return join_channel_internal(channel_id, Box::new(response), session).await;
1283    }
1284
1285    let joined_room = {
1286        let room = session
1287            .db()
1288            .await
1289            .join_room(room_id, session.user_id(), session.connection_id)
1290            .await?;
1291        room_updated(&room.room, &session.peer);
1292        room.into_inner()
1293    };
1294
1295    for connection_id in session
1296        .connection_pool()
1297        .await
1298        .user_connection_ids(session.user_id())
1299    {
1300        session
1301            .peer
1302            .send(
1303                connection_id,
1304                proto::CallCanceled {
1305                    room_id: room_id.to_proto(),
1306                },
1307            )
1308            .trace_err();
1309    }
1310
1311    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1312    {
1313        live_kit
1314            .room_token(
1315                &joined_room.room.livekit_room,
1316                &session.user_id().to_string(),
1317            )
1318            .trace_err()
1319            .map(|token| proto::LiveKitConnectionInfo {
1320                server_url: live_kit.url().into(),
1321                token,
1322                can_publish: true,
1323            })
1324    } else {
1325        None
1326    };
1327
1328    response.send(proto::JoinRoomResponse {
1329        room: Some(joined_room.room),
1330        channel_id: None,
1331        live_kit_connection_info,
1332    })?;
1333
1334    update_user_contacts(session.user_id(), &session).await?;
1335    Ok(())
1336}
1337
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-attaches the session to its room in the database and replies with the room
/// plus the projects it reshares and rejoins. It then notifies the room, tells the
/// collaborators of each reshared project about the peer-id change and forwards the
/// project's worktrees to them, and streams rejoined-project state back to the
/// client (`notify_rejoined_projects`). Finally, if the room belongs to a channel,
/// channel members are updated, and the user's contacts are refreshed.
///
/// # Errors
/// Fails on database errors, on failure to send the response, or on errors from
/// `notify_rejoined_projects` / `update_user_contacts`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below, then used after it has ended.
    let room;
    let channel;
    {
        // `rejoined_room` exposes `into_inner()`, so it appears to be a guard around
        // the rejoin result; this block confines its lifetime to the notifications
        // below — TODO confirm what it locks.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id changed from the old
            // connection to the current one; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared project's worktrees to every collaborator except
            // this session itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1429
1430fn notify_rejoined_projects(
1431    rejoined_projects: &mut Vec<RejoinedProject>,
1432    session: &Session,
1433) -> Result<()> {
1434    for project in rejoined_projects.iter() {
1435        for collaborator in &project.collaborators {
1436            session
1437                .peer
1438                .send(
1439                    collaborator.connection_id,
1440                    proto::UpdateProjectCollaborator {
1441                        project_id: project.id.to_proto(),
1442                        old_peer_id: Some(project.old_connection_id.into()),
1443                        new_peer_id: Some(session.connection_id.into()),
1444                    },
1445                )
1446                .trace_err();
1447        }
1448    }
1449
1450    for project in rejoined_projects {
1451        for worktree in mem::take(&mut project.worktrees) {
1452            // Stream this worktree's entries.
1453            let message = proto::UpdateWorktree {
1454                project_id: project.id.to_proto(),
1455                worktree_id: worktree.id,
1456                abs_path: worktree.abs_path.clone(),
1457                root_name: worktree.root_name,
1458                updated_entries: worktree.updated_entries,
1459                removed_entries: worktree.removed_entries,
1460                scan_id: worktree.scan_id,
1461                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1462                updated_repositories: worktree.updated_repositories,
1463                removed_repositories: worktree.removed_repositories,
1464            };
1465            for update in proto::split_worktree_update(message) {
1466                session.peer.send(session.connection_id, update.clone())?;
1467            }
1468
1469            // Stream this worktree's diagnostics.
1470            for summary in worktree.diagnostic_summaries {
1471                session.peer.send(
1472                    session.connection_id,
1473                    proto::UpdateDiagnosticSummary {
1474                        project_id: project.id.to_proto(),
1475                        worktree_id: worktree.id,
1476                        summary: Some(summary),
1477                    },
1478                )?;
1479            }
1480
1481            for settings_file in worktree.settings_files {
1482                session.peer.send(
1483                    session.connection_id,
1484                    proto::UpdateWorktreeSettings {
1485                        project_id: project.id.to_proto(),
1486                        worktree_id: worktree.id,
1487                        path: settings_file.path,
1488                        content: Some(settings_file.content),
1489                        kind: Some(settings_file.kind.to_proto().into()),
1490                    },
1491                )?;
1492            }
1493        }
1494
1495        for language_server in &project.language_servers {
1496            session.peer.send(
1497                session.connection_id,
1498                proto::UpdateLanguageServer {
1499                    project_id: project.id.to_proto(),
1500                    language_server_id: language_server.id,
1501                    variant: Some(
1502                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1503                            proto::LspDiskBasedDiagnosticsUpdated {},
1504                        ),
1505                    ),
1506                },
1507            )?;
1508        }
1509    }
1510    Ok(())
1511}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515    _: proto::LeaveRoom,
1516    response: Response<proto::LeaveRoom>,
1517    session: Session,
1518) -> Result<()> {
1519    leave_room_for_session(&session, session.connection_id).await?;
1520    response.send(proto::Ack {})?;
1521    Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526    request: proto::SetRoomParticipantRole,
1527    response: Response<proto::SetRoomParticipantRole>,
1528    session: Session,
1529) -> Result<()> {
1530    let user_id = UserId::from_proto(request.user_id);
1531    let role = ChannelRole::from(request.role());
1532
1533    let (livekit_room, can_publish) = {
1534        let room = session
1535            .db()
1536            .await
1537            .set_room_participant_role(
1538                session.user_id(),
1539                RoomId::from_proto(request.room_id),
1540                user_id,
1541                role,
1542            )
1543            .await?;
1544
1545        let livekit_room = room.livekit_room.clone();
1546        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547        room_updated(&room, &session.peer);
1548        (livekit_room, can_publish)
1549    };
1550
1551    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552        live_kit
1553            .update_participant(
1554                livekit_room.clone(),
1555                request.user_id.to_string(),
1556                livekit_api::proto::ParticipantPermission {
1557                    can_subscribe: true,
1558                    can_publish,
1559                    can_publish_data: can_publish,
1560                    hidden: false,
1561                    recorder: false,
1562                },
1563            )
1564            .await
1565            .trace_err();
1566    }
1567
1568    response.send(proto::Ack {})?;
1569    Ok(())
1570}
1571
1572/// Call someone else into the current room
1573async fn call(
1574    request: proto::Call,
1575    response: Response<proto::Call>,
1576    session: Session,
1577) -> Result<()> {
1578    let room_id = RoomId::from_proto(request.room_id);
1579    let calling_user_id = session.user_id();
1580    let calling_connection_id = session.connection_id;
1581    let called_user_id = UserId::from_proto(request.called_user_id);
1582    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1583    if !session
1584        .db()
1585        .await
1586        .has_contact(calling_user_id, called_user_id)
1587        .await?
1588    {
1589        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1590    }
1591
1592    let incoming_call = {
1593        let (room, incoming_call) = &mut *session
1594            .db()
1595            .await
1596            .call(
1597                room_id,
1598                calling_user_id,
1599                calling_connection_id,
1600                called_user_id,
1601                initial_project_id,
1602            )
1603            .await?;
1604        room_updated(room, &session.peer);
1605        mem::take(incoming_call)
1606    };
1607    update_user_contacts(called_user_id, &session).await?;
1608
1609    let mut calls = session
1610        .connection_pool()
1611        .await
1612        .user_connection_ids(called_user_id)
1613        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1614        .collect::<FuturesUnordered<_>>();
1615
1616    while let Some(call_response) = calls.next().await {
1617        match call_response.as_ref() {
1618            Ok(_) => {
1619                response.send(proto::Ack {})?;
1620                return Ok(());
1621            }
1622            Err(_) => {
1623                call_response.trace_err();
1624            }
1625        }
1626    }
1627
1628    {
1629        let room = session
1630            .db()
1631            .await
1632            .call_failed(room_id, called_user_id)
1633            .await?;
1634        room_updated(&room, &session.peer);
1635    }
1636    update_user_contacts(called_user_id, &session).await?;
1637
1638    Err(anyhow!("failed to ring user"))?
1639}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643    request: proto::CancelCall,
1644    response: Response<proto::CancelCall>,
1645    session: Session,
1646) -> Result<()> {
1647    let called_user_id = UserId::from_proto(request.called_user_id);
1648    let room_id = RoomId::from_proto(request.room_id);
1649    {
1650        let room = session
1651            .db()
1652            .await
1653            .cancel_call(room_id, session.connection_id, called_user_id)
1654            .await?;
1655        room_updated(&room, &session.peer);
1656    }
1657
1658    for connection_id in session
1659        .connection_pool()
1660        .await
1661        .user_connection_ids(called_user_id)
1662    {
1663        session
1664            .peer
1665            .send(
1666                connection_id,
1667                proto::CallCanceled {
1668                    room_id: room_id.to_proto(),
1669                },
1670            )
1671            .trace_err();
1672    }
1673    response.send(proto::Ack {})?;
1674
1675    update_user_contacts(called_user_id, &session).await?;
1676    Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681    let room_id = RoomId::from_proto(message.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .decline_call(Some(room_id), session.user_id())
1687            .await?
1688            .ok_or_else(|| anyhow!("failed to decline call"))?;
1689        room_updated(&room, &session.peer);
1690    }
1691
1692    for connection_id in session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(session.user_id())
1696    {
1697        session
1698            .peer
1699            .send(
1700                connection_id,
1701                proto::CallCanceled {
1702                    room_id: room_id.to_proto(),
1703                },
1704            )
1705            .trace_err();
1706    }
1707    update_user_contacts(session.user_id(), &session).await?;
1708    Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713    request: proto::UpdateParticipantLocation,
1714    response: Response<proto::UpdateParticipantLocation>,
1715    session: Session,
1716) -> Result<()> {
1717    let room_id = RoomId::from_proto(request.room_id);
1718    let location = request
1719        .location
1720        .ok_or_else(|| anyhow!("invalid location"))?;
1721
1722    let db = session.db().await;
1723    let room = db
1724        .update_room_participant_location(room_id, session.connection_id, location)
1725        .await?;
1726
1727    room_updated(&room, &session.peer);
1728    response.send(proto::Ack {})?;
1729    Ok(())
1730}
1731
1732/// Share a project into the room.
1733async fn share_project(
1734    request: proto::ShareProject,
1735    response: Response<proto::ShareProject>,
1736    session: Session,
1737) -> Result<()> {
1738    let (project_id, room) = &*session
1739        .db()
1740        .await
1741        .share_project(
1742            RoomId::from_proto(request.room_id),
1743            session.connection_id,
1744            &request.worktrees,
1745            request.is_ssh_project,
1746        )
1747        .await?;
1748    response.send(proto::ShareProjectResponse {
1749        project_id: project_id.to_proto(),
1750    })?;
1751    room_updated(room, &session.peer);
1752
1753    Ok(())
1754}
1755
1756/// Unshare a project from the room.
1757async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1758    let project_id = ProjectId::from_proto(message.project_id);
1759    unshare_project_internal(project_id, session.connection_id, &session).await
1760}
1761
1762async fn unshare_project_internal(
1763    project_id: ProjectId,
1764    connection_id: ConnectionId,
1765    session: &Session,
1766) -> Result<()> {
1767    let delete = {
1768        let room_guard = session
1769            .db()
1770            .await
1771            .unshare_project(project_id, connection_id)
1772            .await?;
1773
1774        let (delete, room, guest_connection_ids) = &*room_guard;
1775
1776        let message = proto::UnshareProject {
1777            project_id: project_id.to_proto(),
1778        };
1779
1780        broadcast(
1781            Some(connection_id),
1782            guest_connection_ids.iter().copied(),
1783            |conn_id| session.peer.send(conn_id, message.clone()),
1784        );
1785        if let Some(room) = room {
1786            room_updated(room, &session.peer);
1787        }
1788
1789        *delete
1790    };
1791
1792    if delete {
1793        let db = session.db().await;
1794        db.delete_project(project_id).await?;
1795    }
1796
1797    Ok(())
1798}
1799
1800/// Join someone elses shared project.
1801async fn join_project(
1802    request: proto::JoinProject,
1803    response: Response<proto::JoinProject>,
1804    session: Session,
1805) -> Result<()> {
1806    let project_id = ProjectId::from_proto(request.project_id);
1807
1808    tracing::info!(%project_id, "join project");
1809
1810    let db = session.db().await;
1811    let (project, replica_id) = &mut *db
1812        .join_project(project_id, session.connection_id, session.user_id())
1813        .await?;
1814    drop(db);
1815    tracing::info!(%project_id, "join remote project");
1816    join_project_internal(response, session, project, replica_id)
1817}
1818
/// Sends the `JoinProjectResponse` back to whoever asked to join.
///
/// `join_project_internal` accepts any responder implementing this trait. In
/// this part of the file only `Response<proto::JoinProject>` implements it.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The body uses a fully-qualified path to call `Response`'s own `send`, which
// has the same name as this trait's method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1827
1828fn join_project_internal(
1829    response: impl JoinProjectInternalResponse,
1830    session: Session,
1831    project: &mut Project,
1832    replica_id: &ReplicaId,
1833) -> Result<()> {
1834    let collaborators = project
1835        .collaborators
1836        .iter()
1837        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1838        .map(|collaborator| collaborator.to_proto())
1839        .collect::<Vec<_>>();
1840    let project_id = project.id;
1841    let guest_user_id = session.user_id();
1842
1843    let worktrees = project
1844        .worktrees
1845        .iter()
1846        .map(|(id, worktree)| proto::WorktreeMetadata {
1847            id: *id,
1848            root_name: worktree.root_name.clone(),
1849            visible: worktree.visible,
1850            abs_path: worktree.abs_path.clone(),
1851        })
1852        .collect::<Vec<_>>();
1853
1854    let add_project_collaborator = proto::AddProjectCollaborator {
1855        project_id: project_id.to_proto(),
1856        collaborator: Some(proto::Collaborator {
1857            peer_id: Some(session.connection_id.into()),
1858            replica_id: replica_id.0 as u32,
1859            user_id: guest_user_id.to_proto(),
1860            is_host: false,
1861        }),
1862    };
1863
1864    for collaborator in &collaborators {
1865        session
1866            .peer
1867            .send(
1868                collaborator.peer_id.unwrap().into(),
1869                add_project_collaborator.clone(),
1870            )
1871            .trace_err();
1872    }
1873
1874    // First, we send the metadata associated with each worktree.
1875    response.send(proto::JoinProjectResponse {
1876        project_id: project.id.0 as u64,
1877        worktrees: worktrees.clone(),
1878        replica_id: replica_id.0 as u32,
1879        collaborators: collaborators.clone(),
1880        language_servers: project.language_servers.clone(),
1881        role: project.role.into(),
1882    })?;
1883
1884    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885        // Stream this worktree's entries.
1886        let message = proto::UpdateWorktree {
1887            project_id: project_id.to_proto(),
1888            worktree_id,
1889            abs_path: worktree.abs_path.clone(),
1890            root_name: worktree.root_name,
1891            updated_entries: worktree.entries,
1892            removed_entries: Default::default(),
1893            scan_id: worktree.scan_id,
1894            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1895            updated_repositories: worktree.repository_entries.into_values().collect(),
1896            removed_repositories: Default::default(),
1897        };
1898        for update in proto::split_worktree_update(message) {
1899            session.peer.send(session.connection_id, update.clone())?;
1900        }
1901
1902        // Stream this worktree's diagnostics.
1903        for summary in worktree.diagnostic_summaries {
1904            session.peer.send(
1905                session.connection_id,
1906                proto::UpdateDiagnosticSummary {
1907                    project_id: project_id.to_proto(),
1908                    worktree_id: worktree.id,
1909                    summary: Some(summary),
1910                },
1911            )?;
1912        }
1913
1914        for settings_file in worktree.settings_files {
1915            session.peer.send(
1916                session.connection_id,
1917                proto::UpdateWorktreeSettings {
1918                    project_id: project_id.to_proto(),
1919                    worktree_id: worktree.id,
1920                    path: settings_file.path,
1921                    content: Some(settings_file.content),
1922                    kind: Some(settings_file.kind.to_proto() as i32),
1923                },
1924            )?;
1925        }
1926    }
1927
1928    for language_server in &project.language_servers {
1929        session.peer.send(
1930            session.connection_id,
1931            proto::UpdateLanguageServer {
1932                project_id: project_id.to_proto(),
1933                language_server_id: language_server.id,
1934                variant: Some(
1935                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1936                        proto::LspDiskBasedDiagnosticsUpdated {},
1937                    ),
1938                ),
1939            },
1940        )?;
1941    }
1942
1943    Ok(())
1944}
1945
1946/// Leave someone elses shared project.
1947async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1948    let sender_id = session.connection_id;
1949    let project_id = ProjectId::from_proto(request.project_id);
1950    let db = session.db().await;
1951
1952    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1953    tracing::info!(
1954        %project_id,
1955        "leave project"
1956    );
1957
1958    project_left(project, &session);
1959    if let Some(room) = room {
1960        room_updated(room, &session.peer);
1961    }
1962
1963    Ok(())
1964}
1965
1966/// Updates other participants with changes to the project
1967async fn update_project(
1968    request: proto::UpdateProject,
1969    response: Response<proto::UpdateProject>,
1970    session: Session,
1971) -> Result<()> {
1972    let project_id = ProjectId::from_proto(request.project_id);
1973    let (room, guest_connection_ids) = &*session
1974        .db()
1975        .await
1976        .update_project(project_id, session.connection_id, &request.worktrees)
1977        .await?;
1978    broadcast(
1979        Some(session.connection_id),
1980        guest_connection_ids.iter().copied(),
1981        |connection_id| {
1982            session
1983                .peer
1984                .forward_send(session.connection_id, connection_id, request.clone())
1985        },
1986    );
1987    if let Some(room) = room {
1988        room_updated(room, &session.peer);
1989    }
1990    response.send(proto::Ack {})?;
1991
1992    Ok(())
1993}
1994
1995/// Updates other participants with changes to the worktree
1996async fn update_worktree(
1997    request: proto::UpdateWorktree,
1998    response: Response<proto::UpdateWorktree>,
1999    session: Session,
2000) -> Result<()> {
2001    let guest_connection_ids = session
2002        .db()
2003        .await
2004        .update_worktree(&request, session.connection_id)
2005        .await?;
2006
2007    broadcast(
2008        Some(session.connection_id),
2009        guest_connection_ids.iter().copied(),
2010        |connection_id| {
2011            session
2012                .peer
2013                .forward_send(session.connection_id, connection_id, request.clone())
2014        },
2015    );
2016    response.send(proto::Ack {})?;
2017    Ok(())
2018}
2019
2020/// Updates other participants with changes to the diagnostics
2021async fn update_diagnostic_summary(
2022    message: proto::UpdateDiagnosticSummary,
2023    session: Session,
2024) -> Result<()> {
2025    let guest_connection_ids = session
2026        .db()
2027        .await
2028        .update_diagnostic_summary(&message, session.connection_id)
2029        .await?;
2030
2031    broadcast(
2032        Some(session.connection_id),
2033        guest_connection_ids.iter().copied(),
2034        |connection_id| {
2035            session
2036                .peer
2037                .forward_send(session.connection_id, connection_id, message.clone())
2038        },
2039    );
2040
2041    Ok(())
2042}
2043
2044/// Updates other participants with changes to the worktree settings
2045async fn update_worktree_settings(
2046    message: proto::UpdateWorktreeSettings,
2047    session: Session,
2048) -> Result<()> {
2049    let guest_connection_ids = session
2050        .db()
2051        .await
2052        .update_worktree_settings(&message, session.connection_id)
2053        .await?;
2054
2055    broadcast(
2056        Some(session.connection_id),
2057        guest_connection_ids.iter().copied(),
2058        |connection_id| {
2059            session
2060                .peer
2061                .forward_send(session.connection_id, connection_id, message.clone())
2062        },
2063    );
2064
2065    Ok(())
2066}
2067
2068/// Notify other participants that a  language server has started.
2069async fn start_language_server(
2070    request: proto::StartLanguageServer,
2071    session: Session,
2072) -> Result<()> {
2073    let guest_connection_ids = session
2074        .db()
2075        .await
2076        .start_language_server(&request, session.connection_id)
2077        .await?;
2078
2079    broadcast(
2080        Some(session.connection_id),
2081        guest_connection_ids.iter().copied(),
2082        |connection_id| {
2083            session
2084                .peer
2085                .forward_send(session.connection_id, connection_id, request.clone())
2086        },
2087    );
2088    Ok(())
2089}
2090
2091/// Notify other participants that a language server has changed.
2092async fn update_language_server(
2093    request: proto::UpdateLanguageServer,
2094    session: Session,
2095) -> Result<()> {
2096    let project_id = ProjectId::from_proto(request.project_id);
2097    let project_connection_ids = session
2098        .db()
2099        .await
2100        .project_connection_ids(project_id, session.connection_id, true)
2101        .await?;
2102    broadcast(
2103        Some(session.connection_id),
2104        project_connection_ids.iter().copied(),
2105        |connection_id| {
2106            session
2107                .peer
2108                .forward_send(session.connection_id, connection_id, request.clone())
2109        },
2110    );
2111    Ok(())
2112}
2113
2114/// forward a project request to the host. These requests should be read only
2115/// as guests are allowed to send them.
2116async fn forward_read_only_project_request<T>(
2117    request: T,
2118    response: Response<T>,
2119    session: Session,
2120) -> Result<()>
2121where
2122    T: EntityMessage + RequestMessage,
2123{
2124    let project_id = ProjectId::from_proto(request.remote_entity_id());
2125    let host_connection_id = session
2126        .db()
2127        .await
2128        .host_for_read_only_project_request(project_id, session.connection_id)
2129        .await?;
2130    let payload = session
2131        .peer
2132        .forward_request(session.connection_id, host_connection_id, request)
2133        .await?;
2134    response.send(payload)?;
2135    Ok(())
2136}
2137
2138async fn forward_find_search_candidates_request(
2139    request: proto::FindSearchCandidates,
2140    response: Response<proto::FindSearchCandidates>,
2141    session: Session,
2142) -> Result<()> {
2143    let project_id = ProjectId::from_proto(request.remote_entity_id());
2144    let host_connection_id = session
2145        .db()
2146        .await
2147        .host_for_read_only_project_request(project_id, session.connection_id)
2148        .await?;
2149    let payload = session
2150        .peer
2151        .forward_request(session.connection_id, host_connection_id, request)
2152        .await?;
2153    response.send(payload)?;
2154    Ok(())
2155}
2156
2157/// forward a project request to the host. These requests are disallowed
2158/// for guests.
2159async fn forward_mutating_project_request<T>(
2160    request: T,
2161    response: Response<T>,
2162    session: Session,
2163) -> Result<()>
2164where
2165    T: EntityMessage + RequestMessage,
2166{
2167    let project_id = ProjectId::from_proto(request.remote_entity_id());
2168
2169    let host_connection_id = session
2170        .db()
2171        .await
2172        .host_for_mutating_project_request(project_id, session.connection_id)
2173        .await?;
2174    let payload = session
2175        .peer
2176        .forward_request(session.connection_id, host_connection_id, request)
2177        .await?;
2178    response.send(payload)?;
2179    Ok(())
2180}
2181
2182/// Notify other participants that a new buffer has been created
2183async fn create_buffer_for_peer(
2184    request: proto::CreateBufferForPeer,
2185    session: Session,
2186) -> Result<()> {
2187    session
2188        .db()
2189        .await
2190        .check_user_is_project_host(
2191            ProjectId::from_proto(request.project_id),
2192            session.connection_id,
2193        )
2194        .await?;
2195    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2196    session
2197        .peer
2198        .forward_send(session.connection_id, peer_id.into(), request)?;
2199    Ok(())
2200}
2201
2202/// Notify other participants that a buffer has been updated. This is
2203/// allowed for guests as long as the update is limited to selections.
2204async fn update_buffer(
2205    request: proto::UpdateBuffer,
2206    response: Response<proto::UpdateBuffer>,
2207    session: Session,
2208) -> Result<()> {
2209    let project_id = ProjectId::from_proto(request.project_id);
2210    let mut capability = Capability::ReadOnly;
2211
2212    for op in request.operations.iter() {
2213        match op.variant {
2214            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2215            Some(_) => capability = Capability::ReadWrite,
2216        }
2217    }
2218
2219    let host = {
2220        let guard = session
2221            .db()
2222            .await
2223            .connections_for_buffer_update(project_id, session.connection_id, capability)
2224            .await?;
2225
2226        let (host, guests) = &*guard;
2227
2228        broadcast(
2229            Some(session.connection_id),
2230            guests.clone(),
2231            |connection_id| {
2232                session
2233                    .peer
2234                    .forward_send(session.connection_id, connection_id, request.clone())
2235            },
2236        );
2237
2238        *host
2239    };
2240
2241    if host != session.connection_id {
2242        session
2243            .peer
2244            .forward_request(session.connection_id, host, request.clone())
2245            .await?;
2246    }
2247
2248    response.send(proto::Ack {})?;
2249    Ok(())
2250}
2251
2252async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2253    let project_id = ProjectId::from_proto(message.project_id);
2254
2255    let operation = message.operation.as_ref().context("invalid operation")?;
2256    let capability = match operation.variant.as_ref() {
2257        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2258            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2259                match buffer_op.variant {
2260                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2261                        Capability::ReadOnly
2262                    }
2263                    _ => Capability::ReadWrite,
2264                }
2265            } else {
2266                Capability::ReadWrite
2267            }
2268        }
2269        Some(_) => Capability::ReadWrite,
2270        None => Capability::ReadOnly,
2271    };
2272
2273    let guard = session
2274        .db()
2275        .await
2276        .connections_for_buffer_update(project_id, session.connection_id, capability)
2277        .await?;
2278
2279    let (host, guests) = &*guard;
2280
2281    broadcast(
2282        Some(session.connection_id),
2283        guests.iter().chain([host]).copied(),
2284        |connection_id| {
2285            session
2286                .peer
2287                .forward_send(session.connection_id, connection_id, message.clone())
2288        },
2289    );
2290
2291    Ok(())
2292}
2293
2294/// Notify other participants that a project has been updated.
2295async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2296    request: T,
2297    session: Session,
2298) -> Result<()> {
2299    let project_id = ProjectId::from_proto(request.remote_entity_id());
2300    let project_connection_ids = session
2301        .db()
2302        .await
2303        .project_connection_ids(project_id, session.connection_id, false)
2304        .await?;
2305
2306    broadcast(
2307        Some(session.connection_id),
2308        project_connection_ids.iter().copied(),
2309        |connection_id| {
2310            session
2311                .peer
2312                .forward_send(session.connection_id, connection_id, request.clone())
2313        },
2314    );
2315    Ok(())
2316}
2317
2318/// Start following another user in a call.
2319async fn follow(
2320    request: proto::Follow,
2321    response: Response<proto::Follow>,
2322    session: Session,
2323) -> Result<()> {
2324    let room_id = RoomId::from_proto(request.room_id);
2325    let project_id = request.project_id.map(ProjectId::from_proto);
2326    let leader_id = request
2327        .leader_id
2328        .ok_or_else(|| anyhow!("invalid leader id"))?
2329        .into();
2330    let follower_id = session.connection_id;
2331
2332    session
2333        .db()
2334        .await
2335        .check_room_participants(room_id, leader_id, session.connection_id)
2336        .await?;
2337
2338    let response_payload = session
2339        .peer
2340        .forward_request(session.connection_id, leader_id, request)
2341        .await?;
2342    response.send(response_payload)?;
2343
2344    if let Some(project_id) = project_id {
2345        let room = session
2346            .db()
2347            .await
2348            .follow(room_id, project_id, leader_id, follower_id)
2349            .await?;
2350        room_updated(&room, &session.peer);
2351    }
2352
2353    Ok(())
2354}
2355
2356/// Stop following another user in a call.
2357async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2358    let room_id = RoomId::from_proto(request.room_id);
2359    let project_id = request.project_id.map(ProjectId::from_proto);
2360    let leader_id = request
2361        .leader_id
2362        .ok_or_else(|| anyhow!("invalid leader id"))?
2363        .into();
2364    let follower_id = session.connection_id;
2365
2366    session
2367        .db()
2368        .await
2369        .check_room_participants(room_id, leader_id, session.connection_id)
2370        .await?;
2371
2372    session
2373        .peer
2374        .forward_send(session.connection_id, leader_id, request)?;
2375
2376    if let Some(project_id) = project_id {
2377        let room = session
2378            .db()
2379            .await
2380            .unfollow(room_id, project_id, leader_id, follower_id)
2381            .await?;
2382        room_updated(&room, &session.peer);
2383    }
2384
2385    Ok(())
2386}
2387
2388/// Notify everyone following you of your current location.
2389async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2390    let room_id = RoomId::from_proto(request.room_id);
2391    let database = session.db.lock().await;
2392
2393    let connection_ids = if let Some(project_id) = request.project_id {
2394        let project_id = ProjectId::from_proto(project_id);
2395        database
2396            .project_connection_ids(project_id, session.connection_id, true)
2397            .await?
2398    } else {
2399        database
2400            .room_connection_ids(room_id, session.connection_id)
2401            .await?
2402    };
2403
2404    // For now, don't send view update messages back to that view's current leader.
2405    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2406        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2407        _ => None,
2408    });
2409
2410    for connection_id in connection_ids.iter().cloned() {
2411        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2412            session
2413                .peer
2414                .forward_send(session.connection_id, connection_id, request.clone())?;
2415        }
2416    }
2417    Ok(())
2418}
2419
2420/// Get public data about users.
2421async fn get_users(
2422    request: proto::GetUsers,
2423    response: Response<proto::GetUsers>,
2424    session: Session,
2425) -> Result<()> {
2426    let user_ids = request
2427        .user_ids
2428        .into_iter()
2429        .map(UserId::from_proto)
2430        .collect();
2431    let users = session
2432        .db()
2433        .await
2434        .get_users_by_ids(user_ids)
2435        .await?
2436        .into_iter()
2437        .map(|user| proto::User {
2438            id: user.id.to_proto(),
2439            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2440            github_login: user.github_login,
2441            email: user.email_address,
2442            name: user.name,
2443        })
2444        .collect();
2445    response.send(proto::UsersResponse { users })?;
2446    Ok(())
2447}
2448
2449/// Search for users (to invite) buy Github login
2450async fn fuzzy_search_users(
2451    request: proto::FuzzySearchUsers,
2452    response: Response<proto::FuzzySearchUsers>,
2453    session: Session,
2454) -> Result<()> {
2455    let query = request.query;
2456    let users = match query.len() {
2457        0 => vec![],
2458        1 | 2 => session
2459            .db()
2460            .await
2461            .get_user_by_github_login(&query)
2462            .await?
2463            .into_iter()
2464            .collect(),
2465        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2466    };
2467    let users = users
2468        .into_iter()
2469        .filter(|user| user.id != session.user_id())
2470        .map(|user| proto::User {
2471            id: user.id.to_proto(),
2472            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2473            github_login: user.github_login,
2474            name: user.name,
2475            email: user.email_address,
2476        })
2477        .collect();
2478    response.send(proto::UsersResponse { users })?;
2479    Ok(())
2480}
2481
2482/// Send a contact request to another user.
2483async fn request_contact(
2484    request: proto::RequestContact,
2485    response: Response<proto::RequestContact>,
2486    session: Session,
2487) -> Result<()> {
2488    let requester_id = session.user_id();
2489    let responder_id = UserId::from_proto(request.responder_id);
2490    if requester_id == responder_id {
2491        return Err(anyhow!("cannot add yourself as a contact"))?;
2492    }
2493
2494    let notifications = session
2495        .db()
2496        .await
2497        .send_contact_request(requester_id, responder_id)
2498        .await?;
2499
2500    // Update outgoing contact requests of requester
2501    let mut update = proto::UpdateContacts::default();
2502    update.outgoing_requests.push(responder_id.to_proto());
2503    for connection_id in session
2504        .connection_pool()
2505        .await
2506        .user_connection_ids(requester_id)
2507    {
2508        session.peer.send(connection_id, update.clone())?;
2509    }
2510
2511    // Update incoming contact requests of responder
2512    let mut update = proto::UpdateContacts::default();
2513    update
2514        .incoming_requests
2515        .push(proto::IncomingContactRequest {
2516            requester_id: requester_id.to_proto(),
2517        });
2518    let connection_pool = session.connection_pool().await;
2519    for connection_id in connection_pool.user_connection_ids(responder_id) {
2520        session.peer.send(connection_id, update.clone())?;
2521    }
2522
2523    send_notifications(&connection_pool, &session.peer, notifications);
2524
2525    response.send(proto::Ack {})?;
2526    Ok(())
2527}
2528
2529/// Accept or decline a contact request
2530async fn respond_to_contact_request(
2531    request: proto::RespondToContactRequest,
2532    response: Response<proto::RespondToContactRequest>,
2533    session: Session,
2534) -> Result<()> {
2535    let responder_id = session.user_id();
2536    let requester_id = UserId::from_proto(request.requester_id);
2537    let db = session.db().await;
2538    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2539        db.dismiss_contact_notification(responder_id, requester_id)
2540            .await?;
2541    } else {
2542        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2543
2544        let notifications = db
2545            .respond_to_contact_request(responder_id, requester_id, accept)
2546            .await?;
2547        let requester_busy = db.is_user_busy(requester_id).await?;
2548        let responder_busy = db.is_user_busy(responder_id).await?;
2549
2550        let pool = session.connection_pool().await;
2551        // Update responder with new contact
2552        let mut update = proto::UpdateContacts::default();
2553        if accept {
2554            update
2555                .contacts
2556                .push(contact_for_user(requester_id, requester_busy, &pool));
2557        }
2558        update
2559            .remove_incoming_requests
2560            .push(requester_id.to_proto());
2561        for connection_id in pool.user_connection_ids(responder_id) {
2562            session.peer.send(connection_id, update.clone())?;
2563        }
2564
2565        // Update requester with new contact
2566        let mut update = proto::UpdateContacts::default();
2567        if accept {
2568            update
2569                .contacts
2570                .push(contact_for_user(responder_id, responder_busy, &pool));
2571        }
2572        update
2573            .remove_outgoing_requests
2574            .push(responder_id.to_proto());
2575
2576        for connection_id in pool.user_connection_ids(requester_id) {
2577            session.peer.send(connection_id, update.clone())?;
2578        }
2579
2580        send_notifications(&pool, &session.peer, notifications);
2581    }
2582
2583    response.send(proto::Ack {})?;
2584    Ok(())
2585}
2586
2587/// Remove a contact.
2588async fn remove_contact(
2589    request: proto::RemoveContact,
2590    response: Response<proto::RemoveContact>,
2591    session: Session,
2592) -> Result<()> {
2593    let requester_id = session.user_id();
2594    let responder_id = UserId::from_proto(request.user_id);
2595    let db = session.db().await;
2596    let (contact_accepted, deleted_notification_id) =
2597        db.remove_contact(requester_id, responder_id).await?;
2598
2599    let pool = session.connection_pool().await;
2600    // Update outgoing contact requests of requester
2601    let mut update = proto::UpdateContacts::default();
2602    if contact_accepted {
2603        update.remove_contacts.push(responder_id.to_proto());
2604    } else {
2605        update
2606            .remove_outgoing_requests
2607            .push(responder_id.to_proto());
2608    }
2609    for connection_id in pool.user_connection_ids(requester_id) {
2610        session.peer.send(connection_id, update.clone())?;
2611    }
2612
2613    // Update incoming contact requests of responder
2614    let mut update = proto::UpdateContacts::default();
2615    if contact_accepted {
2616        update.remove_contacts.push(requester_id.to_proto());
2617    } else {
2618        update
2619            .remove_incoming_requests
2620            .push(requester_id.to_proto());
2621    }
2622    for connection_id in pool.user_connection_ids(responder_id) {
2623        session.peer.send(connection_id, update.clone())?;
2624        if let Some(notification_id) = deleted_notification_id {
2625            session.peer.send(
2626                connection_id,
2627                proto::DeleteNotification {
2628                    notification_id: notification_id.to_proto(),
2629                },
2630            )?;
2631        }
2632    }
2633
2634    response.send(proto::Ack {})?;
2635    Ok(())
2636}
2637
2638fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2639    version.0.minor() < 139
2640}
2641
2642async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2643    let plan = session.current_plan(&session.db().await).await?;
2644
2645    session
2646        .peer
2647        .send(
2648            session.connection_id,
2649            proto::UpdateUserPlan { plan: plan.into() },
2650        )
2651        .trace_err();
2652
2653    Ok(())
2654}
2655
2656async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2657    subscribe_user_to_channels(session.user_id(), &session).await?;
2658    Ok(())
2659}
2660
2661async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2662    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2663    let mut pool = session.connection_pool().await;
2664    for membership in &channels_for_user.channel_memberships {
2665        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2666    }
2667    session.peer.send(
2668        session.connection_id,
2669        build_update_user_channels(&channels_for_user),
2670    )?;
2671    session.peer.send(
2672        session.connection_id,
2673        build_channels_update(channels_for_user),
2674    )?;
2675    Ok(())
2676}
2677
2678/// Creates a new channel.
2679async fn create_channel(
2680    request: proto::CreateChannel,
2681    response: Response<proto::CreateChannel>,
2682    session: Session,
2683) -> Result<()> {
2684    let db = session.db().await;
2685
2686    let parent_id = request.parent_id.map(ChannelId::from_proto);
2687    let (channel, membership) = db
2688        .create_channel(&request.name, parent_id, session.user_id())
2689        .await?;
2690
2691    let root_id = channel.root_id();
2692    let channel = Channel::from_model(channel);
2693
2694    response.send(proto::CreateChannelResponse {
2695        channel: Some(channel.to_proto()),
2696        parent_id: request.parent_id,
2697    })?;
2698
2699    let mut connection_pool = session.connection_pool().await;
2700    if let Some(membership) = membership {
2701        connection_pool.subscribe_to_channel(
2702            membership.user_id,
2703            membership.channel_id,
2704            membership.role,
2705        );
2706        let update = proto::UpdateUserChannels {
2707            channel_memberships: vec![proto::ChannelMembership {
2708                channel_id: membership.channel_id.to_proto(),
2709                role: membership.role.into(),
2710            }],
2711            ..Default::default()
2712        };
2713        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2714            session.peer.send(connection_id, update.clone())?;
2715        }
2716    }
2717
2718    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2719        if !role.can_see_channel(channel.visibility) {
2720            continue;
2721        }
2722
2723        let update = proto::UpdateChannels {
2724            channels: vec![channel.to_proto()],
2725            ..Default::default()
2726        };
2727        session.peer.send(connection_id, update.clone())?;
2728    }
2729
2730    Ok(())
2731}
2732
2733/// Delete a channel
2734async fn delete_channel(
2735    request: proto::DeleteChannel,
2736    response: Response<proto::DeleteChannel>,
2737    session: Session,
2738) -> Result<()> {
2739    let db = session.db().await;
2740
2741    let channel_id = request.channel_id;
2742    let (root_channel, removed_channels) = db
2743        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2744        .await?;
2745    response.send(proto::Ack {})?;
2746
2747    // Notify members of removed channels
2748    let mut update = proto::UpdateChannels::default();
2749    update
2750        .delete_channels
2751        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2752
2753    let connection_pool = session.connection_pool().await;
2754    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2755        session.peer.send(connection_id, update.clone())?;
2756    }
2757
2758    Ok(())
2759}
2760
2761/// Invite someone to join a channel.
2762async fn invite_channel_member(
2763    request: proto::InviteChannelMember,
2764    response: Response<proto::InviteChannelMember>,
2765    session: Session,
2766) -> Result<()> {
2767    let db = session.db().await;
2768    let channel_id = ChannelId::from_proto(request.channel_id);
2769    let invitee_id = UserId::from_proto(request.user_id);
2770    let InviteMemberResult {
2771        channel,
2772        notifications,
2773    } = db
2774        .invite_channel_member(
2775            channel_id,
2776            invitee_id,
2777            session.user_id(),
2778            request.role().into(),
2779        )
2780        .await?;
2781
2782    let update = proto::UpdateChannels {
2783        channel_invitations: vec![channel.to_proto()],
2784        ..Default::default()
2785    };
2786
2787    let connection_pool = session.connection_pool().await;
2788    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2789        session.peer.send(connection_id, update.clone())?;
2790    }
2791
2792    send_notifications(&connection_pool, &session.peer, notifications);
2793
2794    response.send(proto::Ack {})?;
2795    Ok(())
2796}
2797
2798/// remove someone from a channel
2799async fn remove_channel_member(
2800    request: proto::RemoveChannelMember,
2801    response: Response<proto::RemoveChannelMember>,
2802    session: Session,
2803) -> Result<()> {
2804    let db = session.db().await;
2805    let channel_id = ChannelId::from_proto(request.channel_id);
2806    let member_id = UserId::from_proto(request.user_id);
2807
2808    let RemoveChannelMemberResult {
2809        membership_update,
2810        notification_id,
2811    } = db
2812        .remove_channel_member(channel_id, member_id, session.user_id())
2813        .await?;
2814
2815    let mut connection_pool = session.connection_pool().await;
2816    notify_membership_updated(
2817        &mut connection_pool,
2818        membership_update,
2819        member_id,
2820        &session.peer,
2821    );
2822    for connection_id in connection_pool.user_connection_ids(member_id) {
2823        if let Some(notification_id) = notification_id {
2824            session
2825                .peer
2826                .send(
2827                    connection_id,
2828                    proto::DeleteNotification {
2829                        notification_id: notification_id.to_proto(),
2830                    },
2831                )
2832                .trace_err();
2833        }
2834    }
2835
2836    response.send(proto::Ack {})?;
2837    Ok(())
2838}
2839
2840/// Toggle the channel between public and private.
2841/// Care is taken to maintain the invariant that public channels only descend from public channels,
2842/// (though members-only channels can appear at any point in the hierarchy).
2843async fn set_channel_visibility(
2844    request: proto::SetChannelVisibility,
2845    response: Response<proto::SetChannelVisibility>,
2846    session: Session,
2847) -> Result<()> {
2848    let db = session.db().await;
2849    let channel_id = ChannelId::from_proto(request.channel_id);
2850    let visibility = request.visibility().into();
2851
2852    let channel_model = db
2853        .set_channel_visibility(channel_id, visibility, session.user_id())
2854        .await?;
2855    let root_id = channel_model.root_id();
2856    let channel = Channel::from_model(channel_model);
2857
2858    let mut connection_pool = session.connection_pool().await;
2859    for (user_id, role) in connection_pool
2860        .channel_user_ids(root_id)
2861        .collect::<Vec<_>>()
2862        .into_iter()
2863    {
2864        let update = if role.can_see_channel(channel.visibility) {
2865            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2866            proto::UpdateChannels {
2867                channels: vec![channel.to_proto()],
2868                ..Default::default()
2869            }
2870        } else {
2871            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2872            proto::UpdateChannels {
2873                delete_channels: vec![channel.id.to_proto()],
2874                ..Default::default()
2875            }
2876        };
2877
2878        for connection_id in connection_pool.user_connection_ids(user_id) {
2879            session.peer.send(connection_id, update.clone())?;
2880        }
2881    }
2882
2883    response.send(proto::Ack {})?;
2884    Ok(())
2885}
2886
2887/// Alter the role for a user in the channel.
2888async fn set_channel_member_role(
2889    request: proto::SetChannelMemberRole,
2890    response: Response<proto::SetChannelMemberRole>,
2891    session: Session,
2892) -> Result<()> {
2893    let db = session.db().await;
2894    let channel_id = ChannelId::from_proto(request.channel_id);
2895    let member_id = UserId::from_proto(request.user_id);
2896    let result = db
2897        .set_channel_member_role(
2898            channel_id,
2899            session.user_id(),
2900            member_id,
2901            request.role().into(),
2902        )
2903        .await?;
2904
2905    match result {
2906        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2907            let mut connection_pool = session.connection_pool().await;
2908            notify_membership_updated(
2909                &mut connection_pool,
2910                membership_update,
2911                member_id,
2912                &session.peer,
2913            )
2914        }
2915        db::SetMemberRoleResult::InviteUpdated(channel) => {
2916            let update = proto::UpdateChannels {
2917                channel_invitations: vec![channel.to_proto()],
2918                ..Default::default()
2919            };
2920
2921            for connection_id in session
2922                .connection_pool()
2923                .await
2924                .user_connection_ids(member_id)
2925            {
2926                session.peer.send(connection_id, update.clone())?;
2927            }
2928        }
2929    }
2930
2931    response.send(proto::Ack {})?;
2932    Ok(())
2933}
2934
2935/// Change the name of a channel
2936async fn rename_channel(
2937    request: proto::RenameChannel,
2938    response: Response<proto::RenameChannel>,
2939    session: Session,
2940) -> Result<()> {
2941    let db = session.db().await;
2942    let channel_id = ChannelId::from_proto(request.channel_id);
2943    let channel_model = db
2944        .rename_channel(channel_id, session.user_id(), &request.name)
2945        .await?;
2946    let root_id = channel_model.root_id();
2947    let channel = Channel::from_model(channel_model);
2948
2949    response.send(proto::RenameChannelResponse {
2950        channel: Some(channel.to_proto()),
2951    })?;
2952
2953    let connection_pool = session.connection_pool().await;
2954    let update = proto::UpdateChannels {
2955        channels: vec![channel.to_proto()],
2956        ..Default::default()
2957    };
2958    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2959        if role.can_see_channel(channel.visibility) {
2960            session.peer.send(connection_id, update.clone())?;
2961        }
2962    }
2963
2964    Ok(())
2965}
2966
2967/// Move a channel to a new parent.
2968async fn move_channel(
2969    request: proto::MoveChannel,
2970    response: Response<proto::MoveChannel>,
2971    session: Session,
2972) -> Result<()> {
2973    let channel_id = ChannelId::from_proto(request.channel_id);
2974    let to = ChannelId::from_proto(request.to);
2975
2976    let (root_id, channels) = session
2977        .db()
2978        .await
2979        .move_channel(channel_id, to, session.user_id())
2980        .await?;
2981
2982    let connection_pool = session.connection_pool().await;
2983    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2984        let channels = channels
2985            .iter()
2986            .filter_map(|channel| {
2987                if role.can_see_channel(channel.visibility) {
2988                    Some(channel.to_proto())
2989                } else {
2990                    None
2991                }
2992            })
2993            .collect::<Vec<_>>();
2994        if channels.is_empty() {
2995            continue;
2996        }
2997
2998        let update = proto::UpdateChannels {
2999            channels,
3000            ..Default::default()
3001        };
3002
3003        session.peer.send(connection_id, update.clone())?;
3004    }
3005
3006    response.send(Ack {})?;
3007    Ok(())
3008}
3009
3010/// Get the list of channel members
3011async fn get_channel_members(
3012    request: proto::GetChannelMembers,
3013    response: Response<proto::GetChannelMembers>,
3014    session: Session,
3015) -> Result<()> {
3016    let db = session.db().await;
3017    let channel_id = ChannelId::from_proto(request.channel_id);
3018    let limit = if request.limit == 0 {
3019        u16::MAX as u64
3020    } else {
3021        request.limit
3022    };
3023    let (members, users) = db
3024        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3025        .await?;
3026    response.send(proto::GetChannelMembersResponse { members, users })?;
3027    Ok(())
3028}
3029
3030/// Accept or decline a channel invitation.
3031async fn respond_to_channel_invite(
3032    request: proto::RespondToChannelInvite,
3033    response: Response<proto::RespondToChannelInvite>,
3034    session: Session,
3035) -> Result<()> {
3036    let db = session.db().await;
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038    let RespondToChannelInvite {
3039        membership_update,
3040        notifications,
3041    } = db
3042        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3043        .await?;
3044
3045    let mut connection_pool = session.connection_pool().await;
3046    if let Some(membership_update) = membership_update {
3047        notify_membership_updated(
3048            &mut connection_pool,
3049            membership_update,
3050            session.user_id(),
3051            &session.peer,
3052        );
3053    } else {
3054        let update = proto::UpdateChannels {
3055            remove_channel_invitations: vec![channel_id.to_proto()],
3056            ..Default::default()
3057        };
3058
3059        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3060            session.peer.send(connection_id, update.clone())?;
3061        }
3062    };
3063
3064    send_notifications(&connection_pool, &session.peer, notifications);
3065
3066    response.send(proto::Ack {})?;
3067
3068    Ok(())
3069}
3070
3071/// Join the channels' room
3072async fn join_channel(
3073    request: proto::JoinChannel,
3074    response: Response<proto::JoinChannel>,
3075    session: Session,
3076) -> Result<()> {
3077    let channel_id = ChannelId::from_proto(request.channel_id);
3078    join_channel_internal(channel_id, Box::new(response), session).await
3079}
3080
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// kind of request.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request via `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request via `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3094
/// Join the room of `channel_id` for the current session and reply through
/// `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, the db
///    guard is released, that connection leaves its room via
///    `leave_room_for_session`, and the guard is re-acquired.
/// 2. The channel's room is joined in the database, yielding the joined room,
///    an optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, a token is minted: guests get a guest
///    token with `can_publish: false`, everyone else a room token with
///    `can_publish: true`. A token error is logged (`trace_err`) and leaves the
///    connection info as `None`.
/// 4. While still holding the db guard: send the `JoinRoomResponse`, broadcast
///    any membership update, and call `room_updated`.
/// 5. After the db guard is dropped (end of the inner block): call
///    `channel_updated` and then `update_user_contacts`.
///
/// # Errors
/// Returns an error if any database call fails, if leaving the stale room
/// fails, if sending the response fails, if `update_user_contacts` fails, or
/// if the joined room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before leaving the stale room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (the `?` on `trace_err()` short-circuits the closure to `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // The db guard and the connection pool guard are dropped here.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3190
3191/// Start editing the channel notes
3192async fn join_channel_buffer(
3193    request: proto::JoinChannelBuffer,
3194    response: Response<proto::JoinChannelBuffer>,
3195    session: Session,
3196) -> Result<()> {
3197    let db = session.db().await;
3198    let channel_id = ChannelId::from_proto(request.channel_id);
3199
3200    let open_response = db
3201        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3202        .await?;
3203
3204    let collaborators = open_response.collaborators.clone();
3205    response.send(open_response)?;
3206
3207    let update = UpdateChannelBufferCollaborators {
3208        channel_id: channel_id.to_proto(),
3209        collaborators: collaborators.clone(),
3210    };
3211    channel_buffer_updated(
3212        session.connection_id,
3213        collaborators
3214            .iter()
3215            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3216        &update,
3217        &session.peer,
3218    );
3219
3220    Ok(())
3221}
3222
3223/// Edit the channel notes
3224async fn update_channel_buffer(
3225    request: proto::UpdateChannelBuffer,
3226    session: Session,
3227) -> Result<()> {
3228    let db = session.db().await;
3229    let channel_id = ChannelId::from_proto(request.channel_id);
3230
3231    let (collaborators, epoch, version) = db
3232        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3233        .await?;
3234
3235    channel_buffer_updated(
3236        session.connection_id,
3237        collaborators.clone(),
3238        &proto::UpdateChannelBuffer {
3239            channel_id: channel_id.to_proto(),
3240            operations: request.operations,
3241        },
3242        &session.peer,
3243    );
3244
3245    let pool = &*session.connection_pool().await;
3246
3247    let non_collaborators =
3248        pool.channel_connection_ids(channel_id)
3249            .filter_map(|(connection_id, _)| {
3250                if collaborators.contains(&connection_id) {
3251                    None
3252                } else {
3253                    Some(connection_id)
3254                }
3255            });
3256
3257    broadcast(None, non_collaborators, |peer_id| {
3258        session.peer.send(
3259            peer_id,
3260            proto::UpdateChannels {
3261                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3262                    channel_id: channel_id.to_proto(),
3263                    epoch: epoch as u64,
3264                    version: version.clone(),
3265                }],
3266                ..Default::default()
3267            },
3268        )
3269    });
3270
3271    Ok(())
3272}
3273
3274/// Rejoin the channel notes after a connection blip
3275async fn rejoin_channel_buffers(
3276    request: proto::RejoinChannelBuffers,
3277    response: Response<proto::RejoinChannelBuffers>,
3278    session: Session,
3279) -> Result<()> {
3280    let db = session.db().await;
3281    let buffers = db
3282        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3283        .await?;
3284
3285    for rejoined_buffer in &buffers {
3286        let collaborators_to_notify = rejoined_buffer
3287            .buffer
3288            .collaborators
3289            .iter()
3290            .filter_map(|c| Some(c.peer_id?.into()));
3291        channel_buffer_updated(
3292            session.connection_id,
3293            collaborators_to_notify,
3294            &proto::UpdateChannelBufferCollaborators {
3295                channel_id: rejoined_buffer.buffer.channel_id,
3296                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3297            },
3298            &session.peer,
3299        );
3300    }
3301
3302    response.send(proto::RejoinChannelBuffersResponse {
3303        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3304    })?;
3305
3306    Ok(())
3307}
3308
3309/// Stop editing the channel notes
3310async fn leave_channel_buffer(
3311    request: proto::LeaveChannelBuffer,
3312    response: Response<proto::LeaveChannelBuffer>,
3313    session: Session,
3314) -> Result<()> {
3315    let db = session.db().await;
3316    let channel_id = ChannelId::from_proto(request.channel_id);
3317
3318    let left_buffer = db
3319        .leave_channel_buffer(channel_id, session.connection_id)
3320        .await?;
3321
3322    response.send(Ack {})?;
3323
3324    channel_buffer_updated(
3325        session.connection_id,
3326        left_buffer.connections,
3327        &proto::UpdateChannelBufferCollaborators {
3328            channel_id: channel_id.to_proto(),
3329            collaborators: left_buffer.collaborators,
3330        },
3331        &session.peer,
3332    );
3333
3334    Ok(())
3335}
3336
3337fn channel_buffer_updated<T: EnvelopedMessage>(
3338    sender_id: ConnectionId,
3339    collaborators: impl IntoIterator<Item = ConnectionId>,
3340    message: &T,
3341    peer: &Peer,
3342) {
3343    broadcast(Some(sender_id), collaborators, |peer_id| {
3344        peer.send(peer_id, message.clone())
3345    });
3346}
3347
3348fn send_notifications(
3349    connection_pool: &ConnectionPool,
3350    peer: &Peer,
3351    notifications: db::NotificationBatch,
3352) {
3353    for (user_id, notification) in notifications {
3354        for connection_id in connection_pool.user_connection_ids(user_id) {
3355            if let Err(error) = peer.send(
3356                connection_id,
3357                proto::AddNotification {
3358                    notification: Some(notification.clone()),
3359                },
3360            ) {
3361                tracing::error!(
3362                    "failed to send notification to {:?} {}",
3363                    connection_id,
3364                    error
3365                );
3366            }
3367        }
3368    }
3369}
3370
3371/// Send a message to the channel
3372async fn send_channel_message(
3373    request: proto::SendChannelMessage,
3374    response: Response<proto::SendChannelMessage>,
3375    session: Session,
3376) -> Result<()> {
3377    // Validate the message body.
3378    let body = request.body.trim().to_string();
3379    if body.len() > MAX_MESSAGE_LEN {
3380        return Err(anyhow!("message is too long"))?;
3381    }
3382    if body.is_empty() {
3383        return Err(anyhow!("message can't be blank"))?;
3384    }
3385
3386    // TODO: adjust mentions if body is trimmed
3387
3388    let timestamp = OffsetDateTime::now_utc();
3389    let nonce = request
3390        .nonce
3391        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3392
3393    let channel_id = ChannelId::from_proto(request.channel_id);
3394    let CreatedChannelMessage {
3395        message_id,
3396        participant_connection_ids,
3397        notifications,
3398    } = session
3399        .db()
3400        .await
3401        .create_channel_message(
3402            channel_id,
3403            session.user_id(),
3404            &body,
3405            &request.mentions,
3406            timestamp,
3407            nonce.clone().into(),
3408            request.reply_to_message_id.map(MessageId::from_proto),
3409        )
3410        .await?;
3411
3412    let message = proto::ChannelMessage {
3413        sender_id: session.user_id().to_proto(),
3414        id: message_id.to_proto(),
3415        body,
3416        mentions: request.mentions,
3417        timestamp: timestamp.unix_timestamp() as u64,
3418        nonce: Some(nonce),
3419        reply_to_message_id: request.reply_to_message_id,
3420        edited_at: None,
3421    };
3422    broadcast(
3423        Some(session.connection_id),
3424        participant_connection_ids.clone(),
3425        |connection| {
3426            session.peer.send(
3427                connection,
3428                proto::ChannelMessageSent {
3429                    channel_id: channel_id.to_proto(),
3430                    message: Some(message.clone()),
3431                },
3432            )
3433        },
3434    );
3435    response.send(proto::SendChannelMessageResponse {
3436        message: Some(message),
3437    })?;
3438
3439    let pool = &*session.connection_pool().await;
3440    let non_participants =
3441        pool.channel_connection_ids(channel_id)
3442            .filter_map(|(connection_id, _)| {
3443                if participant_connection_ids.contains(&connection_id) {
3444                    None
3445                } else {
3446                    Some(connection_id)
3447                }
3448            });
3449    broadcast(None, non_participants, |peer_id| {
3450        session.peer.send(
3451            peer_id,
3452            proto::UpdateChannels {
3453                latest_channel_message_ids: vec![proto::ChannelMessageId {
3454                    channel_id: channel_id.to_proto(),
3455                    message_id: message_id.to_proto(),
3456                }],
3457                ..Default::default()
3458            },
3459        )
3460    });
3461    send_notifications(pool, &session.peer, notifications);
3462
3463    Ok(())
3464}
3465
3466/// Delete a channel message
3467async fn remove_channel_message(
3468    request: proto::RemoveChannelMessage,
3469    response: Response<proto::RemoveChannelMessage>,
3470    session: Session,
3471) -> Result<()> {
3472    let channel_id = ChannelId::from_proto(request.channel_id);
3473    let message_id = MessageId::from_proto(request.message_id);
3474    let (connection_ids, existing_notification_ids) = session
3475        .db()
3476        .await
3477        .remove_channel_message(channel_id, message_id, session.user_id())
3478        .await?;
3479
3480    broadcast(
3481        Some(session.connection_id),
3482        connection_ids,
3483        move |connection| {
3484            session.peer.send(connection, request.clone())?;
3485
3486            for notification_id in &existing_notification_ids {
3487                session.peer.send(
3488                    connection,
3489                    proto::DeleteNotification {
3490                        notification_id: (*notification_id).to_proto(),
3491                    },
3492                )?;
3493            }
3494
3495            Ok(())
3496        },
3497    );
3498    response.send(proto::Ack {})?;
3499    Ok(())
3500}
3501
3502async fn update_channel_message(
3503    request: proto::UpdateChannelMessage,
3504    response: Response<proto::UpdateChannelMessage>,
3505    session: Session,
3506) -> Result<()> {
3507    let channel_id = ChannelId::from_proto(request.channel_id);
3508    let message_id = MessageId::from_proto(request.message_id);
3509    let updated_at = OffsetDateTime::now_utc();
3510    let UpdatedChannelMessage {
3511        message_id,
3512        participant_connection_ids,
3513        notifications,
3514        reply_to_message_id,
3515        timestamp,
3516        deleted_mention_notification_ids,
3517        updated_mention_notifications,
3518    } = session
3519        .db()
3520        .await
3521        .update_channel_message(
3522            channel_id,
3523            message_id,
3524            session.user_id(),
3525            request.body.as_str(),
3526            &request.mentions,
3527            updated_at,
3528        )
3529        .await?;
3530
3531    let nonce = request
3532        .nonce
3533        .clone()
3534        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3535
3536    let message = proto::ChannelMessage {
3537        sender_id: session.user_id().to_proto(),
3538        id: message_id.to_proto(),
3539        body: request.body.clone(),
3540        mentions: request.mentions.clone(),
3541        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3542        nonce: Some(nonce),
3543        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3544        edited_at: Some(updated_at.unix_timestamp() as u64),
3545    };
3546
3547    response.send(proto::Ack {})?;
3548
3549    let pool = &*session.connection_pool().await;
3550    broadcast(
3551        Some(session.connection_id),
3552        participant_connection_ids,
3553        |connection| {
3554            session.peer.send(
3555                connection,
3556                proto::ChannelMessageUpdate {
3557                    channel_id: channel_id.to_proto(),
3558                    message: Some(message.clone()),
3559                },
3560            )?;
3561
3562            for notification_id in &deleted_mention_notification_ids {
3563                session.peer.send(
3564                    connection,
3565                    proto::DeleteNotification {
3566                        notification_id: (*notification_id).to_proto(),
3567                    },
3568                )?;
3569            }
3570
3571            for notification in &updated_mention_notifications {
3572                session.peer.send(
3573                    connection,
3574                    proto::UpdateNotification {
3575                        notification: Some(notification.clone()),
3576                    },
3577                )?;
3578            }
3579
3580            Ok(())
3581        },
3582    );
3583
3584    send_notifications(pool, &session.peer, notifications);
3585
3586    Ok(())
3587}
3588
3589/// Mark a channel message as read
3590async fn acknowledge_channel_message(
3591    request: proto::AckChannelMessage,
3592    session: Session,
3593) -> Result<()> {
3594    let channel_id = ChannelId::from_proto(request.channel_id);
3595    let message_id = MessageId::from_proto(request.message_id);
3596    let notifications = session
3597        .db()
3598        .await
3599        .observe_channel_message(channel_id, session.user_id(), message_id)
3600        .await?;
3601    send_notifications(
3602        &*session.connection_pool().await,
3603        &session.peer,
3604        notifications,
3605    );
3606    Ok(())
3607}
3608
3609/// Mark a buffer version as synced
3610async fn acknowledge_buffer_version(
3611    request: proto::AckBufferOperation,
3612    session: Session,
3613) -> Result<()> {
3614    let buffer_id = BufferId::from_proto(request.buffer_id);
3615    session
3616        .db()
3617        .await
3618        .observe_buffer_version(
3619            buffer_id,
3620            session.user_id(),
3621            request.epoch as i32,
3622            &request.version,
3623        )
3624        .await?;
3625    Ok(())
3626}
3627
3628async fn count_language_model_tokens(
3629    request: proto::CountLanguageModelTokens,
3630    response: Response<proto::CountLanguageModelTokens>,
3631    session: Session,
3632    config: &Config,
3633) -> Result<()> {
3634    authorize_access_to_legacy_llm_endpoints(&session).await?;
3635
3636    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3637        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3638        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3639    };
3640
3641    session
3642        .app_state
3643        .rate_limiter
3644        .check(&*rate_limit, session.user_id())
3645        .await?;
3646
3647    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3648        Some(proto::LanguageModelProvider::Google) => {
3649            let api_key = config
3650                .google_ai_api_key
3651                .as_ref()
3652                .context("no Google AI API key configured on the server")?;
3653            google_ai::count_tokens(
3654                session.http_client.as_ref(),
3655                google_ai::API_URL,
3656                api_key,
3657                serde_json::from_str(&request.request)?,
3658            )
3659            .await?
3660        }
3661        _ => return Err(anyhow!("unsupported provider"))?,
3662    };
3663
3664    response.send(proto::CountLanguageModelTokensResponse {
3665        token_count: result.total_tokens as u32,
3666    })?;
3667
3668    Ok(())
3669}
3670
3671struct ZedProCountLanguageModelTokensRateLimit;
3672
3673impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3674    fn capacity(&self) -> usize {
3675        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3676            .ok()
3677            .and_then(|v| v.parse().ok())
3678            .unwrap_or(600) // Picked arbitrarily
3679    }
3680
3681    fn refill_duration(&self) -> chrono::Duration {
3682        chrono::Duration::hours(1)
3683    }
3684
3685    fn db_name(&self) -> &'static str {
3686        "zed-pro:count-language-model-tokens"
3687    }
3688}
3689
3690struct FreeCountLanguageModelTokensRateLimit;
3691
3692impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3693    fn capacity(&self) -> usize {
3694        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3695            .ok()
3696            .and_then(|v| v.parse().ok())
3697            .unwrap_or(600 / 10) // Picked arbitrarily
3698    }
3699
3700    fn refill_duration(&self) -> chrono::Duration {
3701        chrono::Duration::hours(1)
3702    }
3703
3704    fn db_name(&self) -> &'static str {
3705        "free:count-language-model-tokens"
3706    }
3707}
3708
3709struct ZedProComputeEmbeddingsRateLimit;
3710
3711impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3712    fn capacity(&self) -> usize {
3713        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3714            .ok()
3715            .and_then(|v| v.parse().ok())
3716            .unwrap_or(5000) // Picked arbitrarily
3717    }
3718
3719    fn refill_duration(&self) -> chrono::Duration {
3720        chrono::Duration::hours(1)
3721    }
3722
3723    fn db_name(&self) -> &'static str {
3724        "zed-pro:compute-embeddings"
3725    }
3726}
3727
3728struct FreeComputeEmbeddingsRateLimit;
3729
3730impl RateLimit for FreeComputeEmbeddingsRateLimit {
3731    fn capacity(&self) -> usize {
3732        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3733            .ok()
3734            .and_then(|v| v.parse().ok())
3735            .unwrap_or(5000 / 10) // Picked arbitrarily
3736    }
3737
3738    fn refill_duration(&self) -> chrono::Duration {
3739        chrono::Duration::hours(1)
3740    }
3741
3742    fn db_name(&self) -> &'static str {
3743        "free:compute-embeddings"
3744    }
3745}
3746
3747async fn compute_embeddings(
3748    request: proto::ComputeEmbeddings,
3749    response: Response<proto::ComputeEmbeddings>,
3750    session: Session,
3751    api_key: Option<Arc<str>>,
3752) -> Result<()> {
3753    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3754    authorize_access_to_legacy_llm_endpoints(&session).await?;
3755
3756    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3757        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3758        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3759    };
3760
3761    session
3762        .app_state
3763        .rate_limiter
3764        .check(&*rate_limit, session.user_id())
3765        .await?;
3766
3767    let embeddings = match request.model.as_str() {
3768        "openai/text-embedding-3-small" => {
3769            open_ai::embed(
3770                session.http_client.as_ref(),
3771                OPEN_AI_API_URL,
3772                &api_key,
3773                OpenAiEmbeddingModel::TextEmbedding3Small,
3774                request.texts.iter().map(|text| text.as_str()),
3775            )
3776            .await?
3777        }
3778        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3779    };
3780
3781    let embeddings = request
3782        .texts
3783        .iter()
3784        .map(|text| {
3785            let mut hasher = sha2::Sha256::new();
3786            hasher.update(text.as_bytes());
3787            let result = hasher.finalize();
3788            result.to_vec()
3789        })
3790        .zip(
3791            embeddings
3792                .data
3793                .into_iter()
3794                .map(|embedding| embedding.embedding),
3795        )
3796        .collect::<HashMap<_, _>>();
3797
3798    let db = session.db().await;
3799    db.save_embeddings(&request.model, &embeddings)
3800        .await
3801        .context("failed to save embeddings")
3802        .trace_err();
3803
3804    response.send(proto::ComputeEmbeddingsResponse {
3805        embeddings: embeddings
3806            .into_iter()
3807            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3808            .collect(),
3809    })?;
3810    Ok(())
3811}
3812
3813async fn get_cached_embeddings(
3814    request: proto::GetCachedEmbeddings,
3815    response: Response<proto::GetCachedEmbeddings>,
3816    session: Session,
3817) -> Result<()> {
3818    authorize_access_to_legacy_llm_endpoints(&session).await?;
3819
3820    let db = session.db().await;
3821    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3822
3823    response.send(proto::GetCachedEmbeddingsResponse {
3824        embeddings: embeddings
3825            .into_iter()
3826            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3827            .collect(),
3828    })?;
3829    Ok(())
3830}
3831
3832/// This is leftover from before the LLM service.
3833///
3834/// The endpoints protected by this check will be moved there eventually.
3835async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3836    if session.is_staff() {
3837        Ok(())
3838    } else {
3839        Err(anyhow!("permission denied"))?
3840    }
3841}
3842
3843/// Get a Supermaven API key for the user
3844async fn get_supermaven_api_key(
3845    _request: proto::GetSupermavenApiKey,
3846    response: Response<proto::GetSupermavenApiKey>,
3847    session: Session,
3848) -> Result<()> {
3849    let user_id: String = session.user_id().to_string();
3850    if !session.is_staff() {
3851        return Err(anyhow!("supermaven not enabled for this account"))?;
3852    }
3853
3854    let email = session
3855        .email()
3856        .ok_or_else(|| anyhow!("user must have an email"))?;
3857
3858    let supermaven_admin_api = session
3859        .supermaven_client
3860        .as_ref()
3861        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3862
3863    let result = supermaven_admin_api
3864        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3865        .await?;
3866
3867    response.send(proto::GetSupermavenApiKeyResponse {
3868        api_key: result.api_key,
3869    })?;
3870
3871    Ok(())
3872}
3873
3874/// Start receiving chat updates for a channel
3875async fn join_channel_chat(
3876    request: proto::JoinChannelChat,
3877    response: Response<proto::JoinChannelChat>,
3878    session: Session,
3879) -> Result<()> {
3880    let channel_id = ChannelId::from_proto(request.channel_id);
3881
3882    let db = session.db().await;
3883    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3884        .await?;
3885    let messages = db
3886        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3887        .await?;
3888    response.send(proto::JoinChannelChatResponse {
3889        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3890        messages,
3891    })?;
3892    Ok(())
3893}
3894
3895/// Stop receiving chat updates for a channel
3896async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3897    let channel_id = ChannelId::from_proto(request.channel_id);
3898    session
3899        .db()
3900        .await
3901        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3902        .await?;
3903    Ok(())
3904}
3905
3906/// Retrieve the chat history for a channel
3907async fn get_channel_messages(
3908    request: proto::GetChannelMessages,
3909    response: Response<proto::GetChannelMessages>,
3910    session: Session,
3911) -> Result<()> {
3912    let channel_id = ChannelId::from_proto(request.channel_id);
3913    let messages = session
3914        .db()
3915        .await
3916        .get_channel_messages(
3917            channel_id,
3918            session.user_id(),
3919            MESSAGE_COUNT_PER_PAGE,
3920            Some(MessageId::from_proto(request.before_message_id)),
3921        )
3922        .await?;
3923    response.send(proto::GetChannelMessagesResponse {
3924        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925        messages,
3926    })?;
3927    Ok(())
3928}
3929
3930/// Retrieve specific chat messages
3931async fn get_channel_messages_by_id(
3932    request: proto::GetChannelMessagesById,
3933    response: Response<proto::GetChannelMessagesById>,
3934    session: Session,
3935) -> Result<()> {
3936    let message_ids = request
3937        .message_ids
3938        .iter()
3939        .map(|id| MessageId::from_proto(*id))
3940        .collect::<Vec<_>>();
3941    let messages = session
3942        .db()
3943        .await
3944        .get_channel_messages_by_id(session.user_id(), &message_ids)
3945        .await?;
3946    response.send(proto::GetChannelMessagesResponse {
3947        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3948        messages,
3949    })?;
3950    Ok(())
3951}
3952
3953/// Retrieve the current users notifications
3954async fn get_notifications(
3955    request: proto::GetNotifications,
3956    response: Response<proto::GetNotifications>,
3957    session: Session,
3958) -> Result<()> {
3959    let notifications = session
3960        .db()
3961        .await
3962        .get_notifications(
3963            session.user_id(),
3964            NOTIFICATION_COUNT_PER_PAGE,
3965            request.before_id.map(db::NotificationId::from_proto),
3966        )
3967        .await?;
3968    response.send(proto::GetNotificationsResponse {
3969        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3970        notifications,
3971    })?;
3972    Ok(())
3973}
3974
3975/// Mark notifications as read
3976async fn mark_notification_as_read(
3977    request: proto::MarkNotificationRead,
3978    response: Response<proto::MarkNotificationRead>,
3979    session: Session,
3980) -> Result<()> {
3981    let database = &session.db().await;
3982    let notifications = database
3983        .mark_notification_as_read_by_id(
3984            session.user_id(),
3985            NotificationId::from_proto(request.notification_id),
3986        )
3987        .await?;
3988    send_notifications(
3989        &*session.connection_pool().await,
3990        &session.peer,
3991        notifications,
3992    );
3993    response.send(proto::Ack {})?;
3994    Ok(())
3995}
3996
3997/// Get the current users information
3998async fn get_private_user_info(
3999    _request: proto::GetPrivateUserInfo,
4000    response: Response<proto::GetPrivateUserInfo>,
4001    session: Session,
4002) -> Result<()> {
4003    let db = session.db().await;
4004
4005    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4006    let user = db
4007        .get_user_by_id(session.user_id())
4008        .await?
4009        .ok_or_else(|| anyhow!("user not found"))?;
4010    let flags = db.get_user_flags(session.user_id()).await?;
4011
4012    response.send(proto::GetPrivateUserInfoResponse {
4013        metrics_id,
4014        staff: user.admin,
4015        flags,
4016        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4017    })?;
4018    Ok(())
4019}
4020
4021/// Accept the terms of service (tos) on behalf of the current user
4022async fn accept_terms_of_service(
4023    _request: proto::AcceptTermsOfService,
4024    response: Response<proto::AcceptTermsOfService>,
4025    session: Session,
4026) -> Result<()> {
4027    let db = session.db().await;
4028
4029    let accepted_tos_at = Utc::now();
4030    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4031        .await?;
4032
4033    response.send(proto::AcceptTermsOfServiceResponse {
4034        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4035    })?;
4036    Ok(())
4037}
4038
4039/// The minimum account age an account must have in order to use the LLM service.
4040const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4041
4042async fn get_llm_api_token(
4043    _request: proto::GetLlmToken,
4044    response: Response<proto::GetLlmToken>,
4045    session: Session,
4046) -> Result<()> {
4047    let db = session.db().await;
4048
4049    let flags = db.get_user_flags(session.user_id()).await?;
4050    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4051    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4052    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4053
4054    if !session.is_staff() && !has_language_models_feature_flag {
4055        Err(anyhow!("permission denied"))?
4056    }
4057
4058    let user_id = session.user_id();
4059    let user = db
4060        .get_user_by_id(user_id)
4061        .await?
4062        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4063
4064    if user.accepted_tos_at.is_none() {
4065        Err(anyhow!("terms of service not accepted"))?
4066    }
4067
4068    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4069
4070    let bypass_account_age_check =
4071        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4072    if !bypass_account_age_check {
4073        let mut account_created_at = user.created_at;
4074        if let Some(github_created_at) = user.github_user_created_at {
4075            account_created_at = account_created_at.min(github_created_at);
4076        }
4077        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4078            Err(anyhow!("account too young"))?
4079        }
4080    }
4081
4082    let billing_preferences = db.get_billing_preferences(user.id).await?;
4083
4084    let token = LlmTokenClaims::create(
4085        &user,
4086        session.is_staff(),
4087        billing_preferences,
4088        has_llm_closed_beta_feature_flag,
4089        has_predict_edits_feature_flag,
4090        has_llm_subscription,
4091        session.current_plan(&db).await?,
4092        session.system_id.clone(),
4093        &session.app_state.config,
4094    )?;
4095    response.send(proto::GetLlmTokenResponse { token })?;
4096    Ok(())
4097}
4098
4099fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4100    let message = match message {
4101        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4102        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4103        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4104        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4105        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4106            code: frame.code.into(),
4107            reason: frame.reason,
4108        })),
4109        // We should never receive a frame while reading the message, according
4110        // to the `tungstenite` maintainers:
4111        //
4112        // > It cannot occur when you read messages from the WebSocket, but it
4113        // > can be used when you want to send the raw frames (e.g. you want to
4114        // > send the frames to the WebSocket without composing the full message first).
4115        // >
4116        // > — https://github.com/snapview/tungstenite-rs/issues/268
4117        TungsteniteMessage::Frame(_) => {
4118            bail!("received an unexpected frame while reading the message")
4119        }
4120    };
4121
4122    Ok(message)
4123}
4124
4125fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4126    match message {
4127        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4128        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4129        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4130        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4131        AxumMessage::Close(frame) => {
4132            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4133                code: frame.code.into(),
4134                reason: frame.reason,
4135            }))
4136        }
4137    }
4138}
4139
4140fn notify_membership_updated(
4141    connection_pool: &mut ConnectionPool,
4142    result: MembershipUpdated,
4143    user_id: UserId,
4144    peer: &Peer,
4145) {
4146    for membership in &result.new_channels.channel_memberships {
4147        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4148    }
4149    for channel_id in &result.removed_channels {
4150        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4151    }
4152
4153    let user_channels_update = proto::UpdateUserChannels {
4154        channel_memberships: result
4155            .new_channels
4156            .channel_memberships
4157            .iter()
4158            .map(|cm| proto::ChannelMembership {
4159                channel_id: cm.channel_id.to_proto(),
4160                role: cm.role.into(),
4161            })
4162            .collect(),
4163        ..Default::default()
4164    };
4165
4166    let mut update = build_channels_update(result.new_channels);
4167    update.delete_channels = result
4168        .removed_channels
4169        .into_iter()
4170        .map(|id| id.to_proto())
4171        .collect();
4172    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4173
4174    for connection_id in connection_pool.user_connection_ids(user_id) {
4175        peer.send(connection_id, user_channels_update.clone())
4176            .trace_err();
4177        peer.send(connection_id, update.clone()).trace_err();
4178    }
4179}
4180
4181fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4182    proto::UpdateUserChannels {
4183        channel_memberships: channels
4184            .channel_memberships
4185            .iter()
4186            .map(|m| proto::ChannelMembership {
4187                channel_id: m.channel_id.to_proto(),
4188                role: m.role.into(),
4189            })
4190            .collect(),
4191        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4192        observed_channel_message_id: channels.observed_channel_messages.clone(),
4193    }
4194}
4195
4196fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4197    let mut update = proto::UpdateChannels::default();
4198
4199    for channel in channels.channels {
4200        update.channels.push(channel.to_proto());
4201    }
4202
4203    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4204    update.latest_channel_message_ids = channels.latest_channel_messages;
4205
4206    for (channel_id, participants) in channels.channel_participants {
4207        update
4208            .channel_participants
4209            .push(proto::ChannelParticipants {
4210                channel_id: channel_id.to_proto(),
4211                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4212            });
4213    }
4214
4215    for channel in channels.invited_channels {
4216        update.channel_invitations.push(channel.to_proto());
4217    }
4218
4219    update
4220}
4221
4222fn build_initial_contacts_update(
4223    contacts: Vec<db::Contact>,
4224    pool: &ConnectionPool,
4225) -> proto::UpdateContacts {
4226    let mut update = proto::UpdateContacts::default();
4227
4228    for contact in contacts {
4229        match contact {
4230            db::Contact::Accepted { user_id, busy } => {
4231                update.contacts.push(contact_for_user(user_id, busy, pool));
4232            }
4233            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4234            db::Contact::Incoming { user_id } => {
4235                update
4236                    .incoming_requests
4237                    .push(proto::IncomingContactRequest {
4238                        requester_id: user_id.to_proto(),
4239                    })
4240            }
4241        }
4242    }
4243
4244    update
4245}
4246
4247fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4248    proto::Contact {
4249        user_id: user_id.to_proto(),
4250        online: pool.is_user_online(user_id),
4251        busy,
4252    }
4253}
4254
4255fn room_updated(room: &proto::Room, peer: &Peer) {
4256    broadcast(
4257        None,
4258        room.participants
4259            .iter()
4260            .filter_map(|participant| Some(participant.peer_id?.into())),
4261        |peer_id| {
4262            peer.send(
4263                peer_id,
4264                proto::RoomUpdated {
4265                    room: Some(room.clone()),
4266                },
4267            )
4268        },
4269    );
4270}
4271
4272fn channel_updated(
4273    channel: &db::channel::Model,
4274    room: &proto::Room,
4275    peer: &Peer,
4276    pool: &ConnectionPool,
4277) {
4278    let participants = room
4279        .participants
4280        .iter()
4281        .map(|p| p.user_id)
4282        .collect::<Vec<_>>();
4283
4284    broadcast(
4285        None,
4286        pool.channel_connection_ids(channel.root_id())
4287            .filter_map(|(channel_id, role)| {
4288                role.can_see_channel(channel.visibility)
4289                    .then_some(channel_id)
4290            }),
4291        |peer_id| {
4292            peer.send(
4293                peer_id,
4294                proto::UpdateChannels {
4295                    channel_participants: vec![proto::ChannelParticipants {
4296                        channel_id: channel.id.to_proto(),
4297                        participant_user_ids: participants.clone(),
4298                    }],
4299                    ..Default::default()
4300                },
4301            )
4302        },
4303    );
4304}
4305
4306async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4307    let db = session.db().await;
4308
4309    let contacts = db.get_contacts(user_id).await?;
4310    let busy = db.is_user_busy(user_id).await?;
4311
4312    let pool = session.connection_pool().await;
4313    let updated_contact = contact_for_user(user_id, busy, &pool);
4314    for contact in contacts {
4315        if let db::Contact::Accepted {
4316            user_id: contact_user_id,
4317            ..
4318        } = contact
4319        {
4320            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4321                session
4322                    .peer
4323                    .send(
4324                        contact_conn_id,
4325                        proto::UpdateContacts {
4326                            contacts: vec![updated_contact.clone()],
4327                            remove_contacts: Default::default(),
4328                            incoming_requests: Default::default(),
4329                            remove_incoming_requests: Default::default(),
4330                            outgoing_requests: Default::default(),
4331                            remove_outgoing_requests: Default::default(),
4332                        },
4333                    )
4334                    .trace_err();
4335            }
4336        }
4337    }
4338    Ok(())
4339}
4340
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// If `leave_room` reports no room for the connection, this returns `Ok(())` immediately.
/// Otherwise it performs these steps, in order:
/// 1. runs `project_left` for every project in `left_projects`;
/// 2. broadcasts the updated room via `room_updated`;
/// 3. if the room belongs to a channel, broadcasts the channel's participants via `channel_updated`;
/// 4. sends `CallCanceled` to every connection of each user in `canceled_calls_to_user_ids`;
/// 5. refreshes contact status for the leaving user and for those users;
/// 6. if a LiveKit client is configured, removes the user from the LiveKit room, and deletes
///    the room when `leave_room` reported it as `deleted`.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`. An error in step 5
/// returns early and skips step 6. Send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` so the values taken out of
    // `left_room` remain available after that block ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the guard from `session.db().await` is a temporary of the `if let`
    // scrutinee. Under Rust 2021 temporary-scope rules it lives until the end of the if/else.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this take happens before `room` is taken below. The `room` broadcast
        // by `room_updated` therefore carries a defaulted (presumably empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the awaits below.
    // Presumably this matters because `update_user_contacts` acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and do not fail the call.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4414
4415async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4416    let left_channel_buffers = session
4417        .db()
4418        .await
4419        .leave_channel_buffers(session.connection_id)
4420        .await?;
4421
4422    for left_buffer in left_channel_buffers {
4423        channel_buffer_updated(
4424            session.connection_id,
4425            left_buffer.connections,
4426            &proto::UpdateChannelBufferCollaborators {
4427                channel_id: left_buffer.channel_id.to_proto(),
4428                collaborators: left_buffer.collaborators,
4429            },
4430            &session.peer,
4431        );
4432    }
4433
4434    Ok(())
4435}
4436
4437fn project_left(project: &db::LeftProject, session: &Session) {
4438    for connection_id in &project.connection_ids {
4439        if project.should_unshare {
4440            session
4441                .peer
4442                .send(
4443                    *connection_id,
4444                    proto::UnshareProject {
4445                        project_id: project.id.to_proto(),
4446                    },
4447                )
4448                .trace_err();
4449        } else {
4450            session
4451                .peer
4452                .send(
4453                    *connection_id,
4454                    proto::RemoveProjectCollaborator {
4455                        project_id: project.id.to_proto(),
4456                        peer_id: Some(session.connection_id.into()),
4457                    },
4458                )
4459                .trace_err();
4460        }
4461    }
4462}
4463
4464pub trait ResultExt {
4465    type Ok;
4466
4467    fn trace_err(self) -> Option<Self::Ok>;
4468}
4469
4470impl<T, E> ResultExt for Result<T, E>
4471where
4472    E: std::fmt::Debug,
4473{
4474    type Ok = T;
4475
4476    #[track_caller]
4477    fn trace_err(self) -> Option<T> {
4478        match self {
4479            Ok(value) => Some(value),
4480            Err(error) => {
4481                tracing::error!("{:?}", error);
4482                None
4483            }
4484        }
4485    }
4486}