rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::{
   9    AppState, Error, Result, auth,
  10    db::{
  11        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  12        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  13        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  14        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  15    },
  16    executor::Executor,
  17};
  18use anyhow::{Context as _, anyhow, bail};
  19use async_tungstenite::tungstenite::{
  20    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  21};
  22use axum::{
  23    Extension, Router, TypedHeader,
  24    body::Body,
  25    extract::{
  26        ConnectInfo, WebSocketUpgrade,
  27        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  28    },
  29    headers::{Header, HeaderName},
  30    http::StatusCode,
  31    middleware,
  32    response::IntoResponse,
  33    routing::get,
  34};
  35use chrono::Utc;
  36use collections::{HashMap, HashSet};
  37pub use connection_pool::{ConnectionPool, ZedVersion};
  38use core::fmt::{self, Debug, Formatter};
  39use reqwest_client::ReqwestClient;
  40use rpc::proto::split_repository_update;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
/// NOTE(review): presumably how long a disconnected client may take to reconnect before its
/// state is released — not used in the visible part of this file; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file; their
// exact meaning (and, for MAX_MESSAGE_LEN, the unit — bytes vs. chars) should be confirmed.
/// Page size, presumably for channel message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length, presumably of a channel chat message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size, presumably for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the boxed envelope (downcast back to its concrete `TypedEnvelope<M>` inside
/// the closure built by `Server::add_handler`) together with the connection's `Session`, and
/// returns a boxed `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
/// Per-connection state passed to every message handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who this connection acts as: a user, or an admin impersonating a user.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Acquire through `Session::connection_pool`. NOTE(review): presumably the same pool as
    /// `Server::connection_pool` — the wiring happens outside this view.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // `allow(unused)`: presumably not read anywhere yet.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    fn user_id(&self) -> UserId {
 172        match &self.principal {
 173            Principal::User(user) => user.id,
 174            Principal::Impersonated { user, .. } => user.id,
 175        }
 176    }
 177
 178    pub fn email(&self) -> Option<String> {
 179        match &self.principal {
 180            Principal::User(user) => user.email_address.clone(),
 181            Principal::Impersonated { user, .. } => user.email_address.clone(),
 182        }
 183    }
 184}
 185
 186impl Debug for Session {
 187    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 188        let mut result = f.debug_struct("Session");
 189        match &self.principal {
 190            Principal::User(user) => {
 191                result.field("user", &user.github_login);
 192            }
 193            Principal::Impersonated { user, admin } => {
 194                result.field("user", &user.github_login);
 195                result.field("impersonator", &admin.github_login);
 196            }
 197        }
 198        result.field("connection_id", &self.connection_id).finish()
 199    }
 200}
 201
 202struct DbHandle(Arc<Database>);
 203
 204impl Deref for DbHandle {
 205    type Target = Database;
 206
 207    fn deref(&self) -> &Self::Target {
 208        self.0.as_ref()
 209    }
 210}
 211
/// The collaboration RPC server. It owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// Replaced by the test-only `Server::reset`, hence the mutex.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `Server::teardown` and back to `false` by `Server::reset`.
    teardown: watch::Sender<bool>,
}
 220
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`. Handler futures must be
    /// `Send`, so they cannot hold the pool lock across an `.await`.
    _not_send: PhantomData<Rc<()>>,
}
 225
/// Serializable snapshot of the server's peer and connection-pool state.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Holds the pool lock for as long as the snapshot exists. It is serialized as the
    /// underlying `ConnectionPool` via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 232
 233pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 234where
 235    S: Serializer,
 236    T: Deref<Target = U>,
 237    U: Serialize,
 238{
 239    Serialize::serialize(value.deref(), serializer)
 240}
 241
 242impl Server {
 243    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 244        let mut server = Self {
 245            id: parking_lot::Mutex::new(id),
 246            peer: Peer::new(id.0 as u32),
 247            app_state: app_state.clone(),
 248            connection_pool: Default::default(),
 249            handlers: Default::default(),
 250            teardown: watch::channel(false).0,
 251        };
 252
 253        server
 254            .add_request_handler(ping)
 255            .add_request_handler(create_room)
 256            .add_request_handler(join_room)
 257            .add_request_handler(rejoin_room)
 258            .add_request_handler(leave_room)
 259            .add_request_handler(set_room_participant_role)
 260            .add_request_handler(call)
 261            .add_request_handler(cancel_call)
 262            .add_message_handler(decline_call)
 263            .add_request_handler(update_participant_location)
 264            .add_request_handler(share_project)
 265            .add_message_handler(unshare_project)
 266            .add_request_handler(join_project)
 267            .add_message_handler(leave_project)
 268            .add_request_handler(update_project)
 269            .add_request_handler(update_worktree)
 270            .add_request_handler(update_repository)
 271            .add_request_handler(remove_repository)
 272            .add_message_handler(start_language_server)
 273            .add_message_handler(update_language_server)
 274            .add_message_handler(update_diagnostic_summary)
 275            .add_message_handler(update_worktree_settings)
 276            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 277            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 278            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 279            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 280            .add_request_handler(forward_find_search_candidates_request)
 281            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 282            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 283            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 284            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 285            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 286            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 287            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 288            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 289            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 290            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 291            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 293            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 294            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 295            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 296            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 297            .add_request_handler(
 298                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 299            )
 300            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 301            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 303            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 304            .add_request_handler(
 305                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 306            )
 307            .add_request_handler(
 308                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 309            )
 310            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 311            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 312            .add_request_handler(
 313                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 314            )
 315            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 316            .add_request_handler(
 317                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 318            )
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 320            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 321            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 322            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 323            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 324            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 325            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 326            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 327            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 328            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 329            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 331            .add_request_handler(
 332                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 333            )
 334            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 335            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 336            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 337            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 338            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 339            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 340            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 341            .add_message_handler(create_buffer_for_peer)
 342            .add_request_handler(update_buffer)
 343            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 344            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 345            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 346            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 347            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 348            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 349            .add_request_handler(get_users)
 350            .add_request_handler(fuzzy_search_users)
 351            .add_request_handler(request_contact)
 352            .add_request_handler(remove_contact)
 353            .add_request_handler(respond_to_contact_request)
 354            .add_message_handler(subscribe_to_channels)
 355            .add_request_handler(create_channel)
 356            .add_request_handler(delete_channel)
 357            .add_request_handler(invite_channel_member)
 358            .add_request_handler(remove_channel_member)
 359            .add_request_handler(set_channel_member_role)
 360            .add_request_handler(set_channel_visibility)
 361            .add_request_handler(rename_channel)
 362            .add_request_handler(join_channel_buffer)
 363            .add_request_handler(leave_channel_buffer)
 364            .add_message_handler(update_channel_buffer)
 365            .add_request_handler(rejoin_channel_buffers)
 366            .add_request_handler(get_channel_members)
 367            .add_request_handler(respond_to_channel_invite)
 368            .add_request_handler(join_channel)
 369            .add_request_handler(join_channel_chat)
 370            .add_message_handler(leave_channel_chat)
 371            .add_request_handler(send_channel_message)
 372            .add_request_handler(remove_channel_message)
 373            .add_request_handler(update_channel_message)
 374            .add_request_handler(get_channel_messages)
 375            .add_request_handler(get_channel_messages_by_id)
 376            .add_request_handler(get_notifications)
 377            .add_request_handler(mark_notification_as_read)
 378            .add_request_handler(move_channel)
 379            .add_request_handler(follow)
 380            .add_message_handler(unfollow)
 381            .add_message_handler(update_followers)
 382            .add_request_handler(get_private_user_info)
 383            .add_request_handler(get_llm_api_token)
 384            .add_request_handler(accept_terms_of_service)
 385            .add_message_handler(acknowledge_channel_message)
 386            .add_message_handler(acknowledge_buffer_version)
 387            .add_request_handler(get_supermaven_api_key)
 388            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 389            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 390            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 391            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 392            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 393            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 394            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 395            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 396            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 397            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 398            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 401            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 403            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 404            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 405            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 406            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 407            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 409            .add_message_handler(update_context);
 410
 411        Arc::new(server)
 412    }
 413
    /// Schedules cleanup of resources left behind by stale instances of this server.
    ///
    /// Spawns a detached task, instrumented with a `"start server"` span, that:
    /// 1. waits for `CLEANUP_TIMEOUT`;
    /// 2. fetches the stale room ids and channel ids for this server id in the configured
    ///    `zed_environment` (`stale_server_resource_ids`);
    /// 3. for each stale channel buffer, clears stale collaborators and sends the refreshed
    ///    collaborator list to every remaining connection of that buffer;
    /// 4. for each stale room, clears stale participants and broadcasts the updated room
    ///    (and its channel, if any). It also sends `CallCanceled` to every connection of
    ///    users whose pending call was canceled, and pushes refreshed contact info to the
    ///    accepted contacts of each affected user. Finally it deletes the LiveKit room once
    ///    the room has no participants left;
    /// 5. calls `delete_stale_servers` for this environment and server id. This runs even if
    ///    step 2 failed, in which case steps 3–4 are skipped.
    ///
    /// Database and send failures are traced (`trace_err`) and skipped rather than aborting
    /// the cleanup. This function returns `Ok(())` right after spawning the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop collaborators that belonged to stale servers and
                    // tell the remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the room is refreshed; they stay at these empty
                        // defaults if clearing the room's stale participants fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees have a changed
                            // busy status that their contacts should learn about.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every connection
                        // of each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // refreshed room ended up with no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 559
 560    pub fn teardown(&self) {
 561        self.peer.teardown();
 562        self.connection_pool.lock().reset();
 563        let _ = self.teardown.send(true);
 564    }
 565
 566    #[cfg(test)]
 567    pub fn reset(&self, id: ServerId) {
 568        self.teardown();
 569        *self.id.lock() = id;
 570        self.peer.reset(id.0 as u32);
 571        let _ = self.teardown.send(false);
 572    }
 573
 574    #[cfg(test)]
 575    pub fn id(&self) -> ServerId {
 576        *self.id.lock()
 577    }
 578
 579    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 580    where
 581        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 582        Fut: 'static + Send + Future<Output = Result<()>>,
 583        M: EnvelopedMessage,
 584    {
 585        let prev_handler = self.handlers.insert(
 586            TypeId::of::<M>(),
 587            Box::new(move |envelope, session| {
 588                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 589                let received_at = envelope.received_at;
 590                tracing::info!("message received");
 591                let start_time = Instant::now();
 592                let future = (handler)(*envelope, session);
 593                async move {
 594                    let result = future.await;
 595                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 596                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 597                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 598                    let payload_type = M::NAME;
 599
 600                    match result {
 601                        Err(error) => {
 602                            tracing::error!(
 603                                ?error,
 604                                total_duration_ms,
 605                                processing_duration_ms,
 606                                queue_duration_ms,
 607                                payload_type,
 608                                "error handling message"
 609                            )
 610                        }
 611                        Ok(()) => tracing::info!(
 612                            total_duration_ms,
 613                            processing_duration_ms,
 614                            queue_duration_ms,
 615                            "finished handling message"
 616                        ),
 617                    }
 618                }
 619                .boxed()
 620            }),
 621        );
 622        if prev_handler.is_some() {
 623            panic!("registered a handler for the same message twice");
 624        }
 625        self
 626    }
 627
 628    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 629    where
 630        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 631        Fut: 'static + Send + Future<Output = Result<()>>,
 632        M: EnvelopedMessage,
 633    {
 634        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 635        self
 636    }
 637
    /// Registers `handler` for request messages of type `M`.
    ///
    /// The handler receives the request payload, a [`Response`] through which it
    /// must reply, and the [`Session`]. Once the handler future completes:
    /// - `Ok(())` and a response was sent: success.
    /// - `Ok(())` but no response was sent: an error ("handler did not send a
    ///   response") is returned. No reply is sent to the requester in this case.
    /// - `Err(error)`: the error is converted to a proto error and sent back with
    ///   `respond_with_error`, then returned. `Error::Internal` values use their
    ///   own `to_proto()`; every other variant becomes `ErrorCode::Internal`
    ///   carrying the error's `Display` text.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // The registered closure is called once per message, so the handler is
        // shared through an `Arc` and a clone is moved into each future.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with `Response`; presumably set when the handler calls
                // `Response::send` (defined elsewhere) — confirm.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // `?` converts the anyhow error into this module's `Error`.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        // NOTE(review): if sending the error reply itself fails, that
                        // failure is returned instead of the handler's `error`.
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 676
    /// Drives a single client connection from registration until it ends.
    ///
    /// The returned future captures no borrow of `self` (`use<>`); it owns a
    /// clone of the `Arc<Server>`. When polled it:
    /// 1. Returns immediately if the server is already tearing down.
    /// 2. Registers `connection` with the peer and records the new
    ///    `connection_id` on the connection span.
    /// 3. Builds an HTTP client and, if a Supermaven admin API key is
    ///    configured, a Supermaven admin client for the session.
    /// 4. Sends the initial client update via `send_initial_client_update`. A
    ///    failure ends the connection.
    /// 5. Dispatches incoming messages to the registered handlers until the I/O
    ///    future completes, the incoming stream ends, or teardown is signaled.
    /// 6. Unless teardown was signaled, calls `connection_lost` to clean up.
    ///
    /// `send_connection_id`, when provided, is passed to
    /// `send_initial_client_update`, which sends the connection id on it right
    /// after the hello message.
    ///
    /// NOTE(review): the early returns in steps 3 and 4 happen after the
    /// connection was registered with the peer but skip `connection_lost`. If
    /// `send_initial_client_update` fails after adding the connection to the
    /// pool, confirm the pool entry is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` via
        // `record` once it is known, the user fields presumably by
        // `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Watch channel used to stop this connection when the server tears down.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` is the connection's I/O future (polled in the loop below)
            // and `incoming_rx` yields the messages received on it. The closure
            // gives the peer a way to sleep on this executor (presumably for its
            // timers — confirm in `Peer::add_connection`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only created when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background together) at 256. A
            // permit is acquired before the next message is read and is held until
            // that message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Priority order: teardown, then I/O completion, then progress on
                // foreground handlers, then reading the next message.
                futures::select_biased! {
                    // On teardown, return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is
                            // created; afterwards it is attached with `instrument`.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Keep the concurrency permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 822        }.instrument(span)
 823    }
 824
 825    async fn send_initial_client_update(
 826        &self,
 827        connection_id: ConnectionId,
 828        principal: &Principal,
 829        zed_version: ZedVersion,
 830        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 831        session: &Session,
 832    ) -> Result<()> {
 833        self.peer.send(
 834            connection_id,
 835            proto::Hello {
 836                peer_id: Some(connection_id.into()),
 837            },
 838        )?;
 839        tracing::info!("sent hello message");
 840        if let Some(send_connection_id) = send_connection_id.take() {
 841            let _ = send_connection_id.send(connection_id);
 842        }
 843
 844        match principal {
 845            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 846                if !user.connected_once {
 847                    self.peer.send(connection_id, proto::ShowContacts {})?;
 848                    self.app_state
 849                        .db
 850                        .set_user_connected_once(user.id, true)
 851                        .await?;
 852                }
 853
 854                update_user_plan(user.id, session).await?;
 855
 856                let contacts = self.app_state.db.get_contacts(user.id).await?;
 857
 858                {
 859                    let mut pool = self.connection_pool.lock();
 860                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 861                    self.peer.send(
 862                        connection_id,
 863                        build_initial_contacts_update(contacts, &pool),
 864                    )?;
 865                }
 866
 867                if should_auto_subscribe_to_channels(zed_version) {
 868                    subscribe_user_to_channels(user.id, session).await?;
 869                }
 870
 871                if let Some(incoming_call) =
 872                    self.app_state.db.incoming_call_for_user(user.id).await?
 873                {
 874                    self.peer.send(connection_id, incoming_call)?;
 875                }
 876
 877                update_user_contacts(user.id, session).await?;
 878            }
 879        }
 880
 881        Ok(())
 882    }
 883
 884    pub async fn invite_code_redeemed(
 885        self: &Arc<Self>,
 886        inviter_id: UserId,
 887        invitee_id: UserId,
 888    ) -> Result<()> {
 889        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 890            if let Some(code) = &user.invite_code {
 891                let pool = self.connection_pool.lock();
 892                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 893                for connection_id in pool.user_connection_ids(inviter_id) {
 894                    self.peer.send(
 895                        connection_id,
 896                        proto::UpdateContacts {
 897                            contacts: vec![invitee_contact.clone()],
 898                            ..Default::default()
 899                        },
 900                    )?;
 901                    self.peer.send(
 902                        connection_id,
 903                        proto::UpdateInviteInfo {
 904                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 905                            count: user.invite_count as u32,
 906                        },
 907                    )?;
 908                }
 909            }
 910        }
 911        Ok(())
 912    }
 913
 914    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 915        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 916            if let Some(invite_code) = &user.invite_code {
 917                let pool = self.connection_pool.lock();
 918                for connection_id in pool.user_connection_ids(user_id) {
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateInviteInfo {
 922                            url: format!(
 923                                "{}{}",
 924                                self.app_state.config.invite_link_prefix, invite_code
 925                            ),
 926                            count: user.invite_count as u32,
 927                        },
 928                    )?;
 929                }
 930            }
 931        }
 932        Ok(())
 933    }
 934
 935    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 936        let user = self
 937            .app_state
 938            .db
 939            .get_user_by_id(user_id)
 940            .await?
 941            .context("user not found")?;
 942
 943        let update_user_plan = make_update_user_plan_message(
 944            &self.app_state.db,
 945            self.app_state.llm_db.clone(),
 946            user_id,
 947            user.admin,
 948        )
 949        .await?;
 950
 951        let pool = self.connection_pool.lock();
 952        for connection_id in pool.user_connection_ids(user_id) {
 953            self.peer
 954                .send(connection_id, update_user_plan.clone())
 955                .trace_err();
 956        }
 957
 958        Ok(())
 959    }
 960
 961    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 962        let pool = self.connection_pool.lock();
 963        for connection_id in pool.user_connection_ids(user_id) {
 964            self.peer
 965                .send(connection_id, proto::RefreshLlmToken {})
 966                .trace_err();
 967        }
 968    }
 969
 970    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 971        ServerSnapshot {
 972            connection_pool: ConnectionPoolGuard {
 973                guard: self.connection_pool.lock(),
 974                _not_send: PhantomData,
 975            },
 976            peer: &self.peer,
 977        }
 978    }
 979}
 980
 981impl Deref for ConnectionPoolGuard<'_> {
 982    type Target = ConnectionPool;
 983
 984    fn deref(&self) -> &Self::Target {
 985        &self.guard
 986    }
 987}
 988
 989impl DerefMut for ConnectionPoolGuard<'_> {
 990    fn deref_mut(&mut self) -> &mut Self::Target {
 991        &mut self.guard
 992    }
 993}
 994
 995impl Drop for ConnectionPoolGuard<'_> {
 996    fn drop(&mut self) {
 997        #[cfg(test)]
 998        self.check_invariants();
 999    }
1000}
1001
1002fn broadcast<F>(
1003    sender_id: Option<ConnectionId>,
1004    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005    mut f: F,
1006) where
1007    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009    for receiver_id in receiver_ids {
1010        if Some(receiver_id) != sender_id {
1011            if let Err(error) = f(receiver_id) {
1012                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013            }
1014        }
1015    }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021    fn name() -> &'static HeaderName {
1022        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024    }
1025
1026    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027    where
1028        Self: Sized,
1029        I: Iterator<Item = &'i axum::http::HeaderValue>,
1030    {
1031        let version = values
1032            .next()
1033            .ok_or_else(axum::headers::Error::invalid)?
1034            .to_str()
1035            .map_err(|_| axum::headers::Error::invalid())?
1036            .parse()
1037            .map_err(|_| axum::headers::Error::invalid())?;
1038        Ok(Self(version))
1039    }
1040
1041    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042        values.extend([self.0.to_string().parse().unwrap()]);
1043    }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048    fn name() -> &'static HeaderName {
1049        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051    }
1052
1053    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054    where
1055        Self: Sized,
1056        I: Iterator<Item = &'i axum::http::HeaderValue>,
1057    {
1058        let version = values
1059            .next()
1060            .ok_or_else(axum::headers::Error::invalid)?
1061            .to_str()
1062            .map_err(|_| axum::headers::Error::invalid())?
1063            .parse()
1064            .map_err(|_| axum::headers::Error::invalid())?;
1065        Ok(Self(version))
1066    }
1067
1068    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069        values.extend([self.0.to_string().parse().unwrap()]);
1070    }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074    Router::new()
1075        .route("/rpc", get(handle_websocket_request))
1076        .layer(
1077            ServiceBuilder::new()
1078                .layer(Extension(server.app_state.clone()))
1079                .layer(middleware::from_fn(auth::validate_header)),
1080        )
1081        .route("/metrics", get(handle_metrics))
1082        .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089    Extension(server): Extension<Arc<Server>>,
1090    Extension(principal): Extension<Principal>,
1091    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093    ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095    if protocol_version != rpc::PROTOCOL_VERSION {
1096        return (
1097            StatusCode::UPGRADE_REQUIRED,
1098            "client must be upgraded".to_string(),
1099        )
1100            .into_response();
1101    }
1102
1103    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1104        return (
1105            StatusCode::UPGRADE_REQUIRED,
1106            "no version header found".to_string(),
1107        )
1108            .into_response();
1109    };
1110
1111    if !version.can_collaborate() {
1112        return (
1113            StatusCode::UPGRADE_REQUIRED,
1114            "client must be upgraded".to_string(),
1115        )
1116            .into_response();
1117    }
1118
1119    let socket_address = socket_address.to_string();
1120    ws.on_upgrade(move |socket| {
1121        let socket = socket
1122            .map_ok(to_tungstenite_message)
1123            .err_into()
1124            .with(|message| async move { to_axum_message(message) });
1125        let connection = Connection::new(Box::pin(socket));
1126        async move {
1127            server
1128                .handle_connection(
1129                    connection,
1130                    socket_address,
1131                    principal,
1132                    version,
1133                    country_code_header.map(|header| header.to_string()),
1134                    system_id_header.map(|header| header.to_string()),
1135                    None,
1136                    Executor::Production,
1137                )
1138                .await;
1139        }
1140    })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145    let connections_metric = CONNECTIONS_METRIC
1146        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148    let connections = server
1149        .connection_pool
1150        .lock()
1151        .connections()
1152        .filter(|connection| !connection.admin)
1153        .count();
1154    connections_metric.set(connections as _);
1155
1156    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158        register_int_gauge!(
1159            "shared_projects",
1160            "number of open projects with one or more guests"
1161        )
1162        .unwrap()
1163    });
1164
1165    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166    shared_projects_metric.set(shared_projects as _);
1167
1168    let encoder = prometheus::TextEncoder::new();
1169    let metric_families = prometheus::gather();
1170    let encoded_metrics = encoder
1171        .encode_to_string(&metric_families)
1172        .map_err(|err| anyhow!("{err}"))?;
1173    Ok(encoded_metrics)
1174}
1175
1176#[instrument(err, skip(executor))]
1177async fn connection_lost(
1178    session: Session,
1179    mut teardown: watch::Receiver<bool>,
1180    executor: Executor,
1181) -> Result<()> {
1182    session.peer.disconnect(session.connection_id);
1183    session
1184        .connection_pool()
1185        .await
1186        .remove_connection(session.connection_id)?;
1187
1188    session
1189        .db()
1190        .await
1191        .connection_lost(session.connection_id)
1192        .await
1193        .trace_err();
1194
1195    futures::select_biased! {
1196        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1197
1198            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1199            leave_room_for_session(&session, session.connection_id).await.trace_err();
1200            leave_channel_buffers_for_session(&session)
1201                .await
1202                .trace_err();
1203
1204            if !session
1205                .connection_pool()
1206                .await
1207                .is_user_online(session.user_id())
1208            {
1209                let db = session.db().await;
1210                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1211                    room_updated(&room, &session.peer);
1212                }
1213            }
1214
1215            update_user_contacts(session.user_id(), &session).await?;
1216        },
1217        _ = teardown.changed().fuse() => {}
1218    }
1219
1220    Ok(())
1221}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225    response.send(proto::Ack {})?;
1226    Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231    _request: proto::CreateRoom,
1232    response: Response<proto::CreateRoom>,
1233    session: Session,
1234) -> Result<()> {
1235    let livekit_room = nanoid::nanoid!(30);
1236
1237    let live_kit_connection_info = util::maybe!(async {
1238        let live_kit = session.app_state.livekit_client.as_ref();
1239        let live_kit = live_kit?;
1240        let user_id = session.user_id().to_string();
1241
1242        let token = live_kit
1243            .room_token(&livekit_room, &user_id.to_string())
1244            .trace_err()?;
1245
1246        Some(proto::LiveKitConnectionInfo {
1247            server_url: live_kit.url().into(),
1248            token,
1249            can_publish: true,
1250        })
1251    })
1252    .await;
1253
1254    let room = session
1255        .db()
1256        .await
1257        .create_room(session.user_id(), session.connection_id, &livekit_room)
1258        .await?;
1259
1260    response.send(proto::CreateRoomResponse {
1261        room: Some(room.clone()),
1262        live_kit_connection_info,
1263    })?;
1264
1265    update_user_contacts(session.user_id(), &session).await?;
1266    Ok(())
1267}
1268
1269/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1270async fn join_room(
1271    request: proto::JoinRoom,
1272    response: Response<proto::JoinRoom>,
1273    session: Session,
1274) -> Result<()> {
1275    let room_id = RoomId::from_proto(request.id);
1276
1277    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1278
1279    if let Some(channel_id) = channel_id {
1280        return join_channel_internal(channel_id, Box::new(response), session).await;
1281    }
1282
1283    let joined_room = {
1284        let room = session
1285            .db()
1286            .await
1287            .join_room(room_id, session.user_id(), session.connection_id)
1288            .await?;
1289        room_updated(&room.room, &session.peer);
1290        room.into_inner()
1291    };
1292
1293    for connection_id in session
1294        .connection_pool()
1295        .await
1296        .user_connection_ids(session.user_id())
1297    {
1298        session
1299            .peer
1300            .send(
1301                connection_id,
1302                proto::CallCanceled {
1303                    room_id: room_id.to_proto(),
1304                },
1305            )
1306            .trace_err();
1307    }
1308
1309    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1310    {
1311        live_kit
1312            .room_token(
1313                &joined_room.room.livekit_room,
1314                &session.user_id().to_string(),
1315            )
1316            .trace_err()
1317            .map(|token| proto::LiveKitConnectionInfo {
1318                server_url: live_kit.url().into(),
1319                token,
1320                can_publish: true,
1321            })
1322    } else {
1323        None
1324    };
1325
1326    response.send(proto::JoinRoomResponse {
1327        room: Some(joined_room.room),
1328        channel_id: None,
1329        live_kit_connection_info,
1330    })?;
1331
1332    update_user_contacts(session.user_id(), &session).await?;
1333    Ok(())
1334}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Rejoins the room in the database for the caller's current connection, then:
/// - replies with the room and the reshared and rejoined projects reported by
///   the database,
/// - broadcasts the updated room,
/// - for each reshared project, tells every listed collaborator that the
///   caller's peer id changed from the project's old connection to this one,
/// - for each reshared project, forwards the project's worktrees
///   (`UpdateProject`) to every collaborator other than this connection,
/// - streams rejoined-project state to the caller via `notify_rejoined_projects`.
///
/// Afterwards, if the room belongs to a channel, `channel_updated` is
/// broadcast. Finally the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // `room` and `channel` are assigned inside the scope below so that
    // `rejoined_room` does not outlive it. NOTE(review): presumably
    // `rejoined_room` holds a room lock until `into_inner()` — confirm.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at the caller's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Resend the project's worktrees, skipping the caller's own connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1427
1428fn notify_rejoined_projects(
1429    rejoined_projects: &mut Vec<RejoinedProject>,
1430    session: &Session,
1431) -> Result<()> {
1432    for project in rejoined_projects.iter() {
1433        for collaborator in &project.collaborators {
1434            session
1435                .peer
1436                .send(
1437                    collaborator.connection_id,
1438                    proto::UpdateProjectCollaborator {
1439                        project_id: project.id.to_proto(),
1440                        old_peer_id: Some(project.old_connection_id.into()),
1441                        new_peer_id: Some(session.connection_id.into()),
1442                    },
1443                )
1444                .trace_err();
1445        }
1446    }
1447
1448    for project in rejoined_projects {
1449        for worktree in mem::take(&mut project.worktrees) {
1450            // Stream this worktree's entries.
1451            let message = proto::UpdateWorktree {
1452                project_id: project.id.to_proto(),
1453                worktree_id: worktree.id,
1454                abs_path: worktree.abs_path.clone(),
1455                root_name: worktree.root_name,
1456                updated_entries: worktree.updated_entries,
1457                removed_entries: worktree.removed_entries,
1458                scan_id: worktree.scan_id,
1459                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1460                updated_repositories: worktree.updated_repositories,
1461                removed_repositories: worktree.removed_repositories,
1462            };
1463            for update in proto::split_worktree_update(message) {
1464                session.peer.send(session.connection_id, update)?;
1465            }
1466
1467            // Stream this worktree's diagnostics.
1468            for summary in worktree.diagnostic_summaries {
1469                session.peer.send(
1470                    session.connection_id,
1471                    proto::UpdateDiagnosticSummary {
1472                        project_id: project.id.to_proto(),
1473                        worktree_id: worktree.id,
1474                        summary: Some(summary),
1475                    },
1476                )?;
1477            }
1478
1479            for settings_file in worktree.settings_files {
1480                session.peer.send(
1481                    session.connection_id,
1482                    proto::UpdateWorktreeSettings {
1483                        project_id: project.id.to_proto(),
1484                        worktree_id: worktree.id,
1485                        path: settings_file.path,
1486                        content: Some(settings_file.content),
1487                        kind: Some(settings_file.kind.to_proto().into()),
1488                    },
1489                )?;
1490            }
1491        }
1492
1493        for repository in mem::take(&mut project.updated_repositories) {
1494            for update in split_repository_update(repository) {
1495                session.peer.send(session.connection_id, update)?;
1496            }
1497        }
1498
1499        for id in mem::take(&mut project.removed_repositories) {
1500            session.peer.send(
1501                session.connection_id,
1502                proto::RemoveRepository {
1503                    project_id: project.id.to_proto(),
1504                    id,
1505                },
1506            )?;
1507        }
1508    }
1509
1510    Ok(())
1511}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515    _: proto::LeaveRoom,
1516    response: Response<proto::LeaveRoom>,
1517    session: Session,
1518) -> Result<()> {
1519    leave_room_for_session(&session, session.connection_id).await?;
1520    response.send(proto::Ack {})?;
1521    Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526    request: proto::SetRoomParticipantRole,
1527    response: Response<proto::SetRoomParticipantRole>,
1528    session: Session,
1529) -> Result<()> {
1530    let user_id = UserId::from_proto(request.user_id);
1531    let role = ChannelRole::from(request.role());
1532
1533    let (livekit_room, can_publish) = {
1534        let room = session
1535            .db()
1536            .await
1537            .set_room_participant_role(
1538                session.user_id(),
1539                RoomId::from_proto(request.room_id),
1540                user_id,
1541                role,
1542            )
1543            .await?;
1544
1545        let livekit_room = room.livekit_room.clone();
1546        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547        room_updated(&room, &session.peer);
1548        (livekit_room, can_publish)
1549    };
1550
1551    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552        live_kit
1553            .update_participant(
1554                livekit_room.clone(),
1555                request.user_id.to_string(),
1556                livekit_api::proto::ParticipantPermission {
1557                    can_subscribe: true,
1558                    can_publish,
1559                    can_publish_data: can_publish,
1560                    hidden: false,
1561                    recorder: false,
1562                },
1563            )
1564            .await
1565            .trace_err();
1566    }
1567
1568    response.send(proto::Ack {})?;
1569    Ok(())
1570}
1571
1572/// Call someone else into the current room
1573async fn call(
1574    request: proto::Call,
1575    response: Response<proto::Call>,
1576    session: Session,
1577) -> Result<()> {
1578    let room_id = RoomId::from_proto(request.room_id);
1579    let calling_user_id = session.user_id();
1580    let calling_connection_id = session.connection_id;
1581    let called_user_id = UserId::from_proto(request.called_user_id);
1582    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1583    if !session
1584        .db()
1585        .await
1586        .has_contact(calling_user_id, called_user_id)
1587        .await?
1588    {
1589        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1590    }
1591
1592    let incoming_call = {
1593        let (room, incoming_call) = &mut *session
1594            .db()
1595            .await
1596            .call(
1597                room_id,
1598                calling_user_id,
1599                calling_connection_id,
1600                called_user_id,
1601                initial_project_id,
1602            )
1603            .await?;
1604        room_updated(room, &session.peer);
1605        mem::take(incoming_call)
1606    };
1607    update_user_contacts(called_user_id, &session).await?;
1608
1609    let mut calls = session
1610        .connection_pool()
1611        .await
1612        .user_connection_ids(called_user_id)
1613        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1614        .collect::<FuturesUnordered<_>>();
1615
1616    while let Some(call_response) = calls.next().await {
1617        match call_response.as_ref() {
1618            Ok(_) => {
1619                response.send(proto::Ack {})?;
1620                return Ok(());
1621            }
1622            Err(_) => {
1623                call_response.trace_err();
1624            }
1625        }
1626    }
1627
1628    {
1629        let room = session
1630            .db()
1631            .await
1632            .call_failed(room_id, called_user_id)
1633            .await?;
1634        room_updated(&room, &session.peer);
1635    }
1636    update_user_contacts(called_user_id, &session).await?;
1637
1638    Err(anyhow!("failed to ring user"))?
1639}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643    request: proto::CancelCall,
1644    response: Response<proto::CancelCall>,
1645    session: Session,
1646) -> Result<()> {
1647    let called_user_id = UserId::from_proto(request.called_user_id);
1648    let room_id = RoomId::from_proto(request.room_id);
1649    {
1650        let room = session
1651            .db()
1652            .await
1653            .cancel_call(room_id, session.connection_id, called_user_id)
1654            .await?;
1655        room_updated(&room, &session.peer);
1656    }
1657
1658    for connection_id in session
1659        .connection_pool()
1660        .await
1661        .user_connection_ids(called_user_id)
1662    {
1663        session
1664            .peer
1665            .send(
1666                connection_id,
1667                proto::CallCanceled {
1668                    room_id: room_id.to_proto(),
1669                },
1670            )
1671            .trace_err();
1672    }
1673    response.send(proto::Ack {})?;
1674
1675    update_user_contacts(called_user_id, &session).await?;
1676    Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681    let room_id = RoomId::from_proto(message.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .decline_call(Some(room_id), session.user_id())
1687            .await?
1688            .context("declining call")?;
1689        room_updated(&room, &session.peer);
1690    }
1691
1692    for connection_id in session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(session.user_id())
1696    {
1697        session
1698            .peer
1699            .send(
1700                connection_id,
1701                proto::CallCanceled {
1702                    room_id: room_id.to_proto(),
1703                },
1704            )
1705            .trace_err();
1706    }
1707    update_user_contacts(session.user_id(), &session).await?;
1708    Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713    request: proto::UpdateParticipantLocation,
1714    response: Response<proto::UpdateParticipantLocation>,
1715    session: Session,
1716) -> Result<()> {
1717    let room_id = RoomId::from_proto(request.room_id);
1718    let location = request.location.context("invalid location")?;
1719
1720    let db = session.db().await;
1721    let room = db
1722        .update_room_participant_location(room_id, session.connection_id, location)
1723        .await?;
1724
1725    room_updated(&room, &session.peer);
1726    response.send(proto::Ack {})?;
1727    Ok(())
1728}
1729
1730/// Share a project into the room.
1731async fn share_project(
1732    request: proto::ShareProject,
1733    response: Response<proto::ShareProject>,
1734    session: Session,
1735) -> Result<()> {
1736    let (project_id, room) = &*session
1737        .db()
1738        .await
1739        .share_project(
1740            RoomId::from_proto(request.room_id),
1741            session.connection_id,
1742            &request.worktrees,
1743            request.is_ssh_project,
1744        )
1745        .await?;
1746    response.send(proto::ShareProjectResponse {
1747        project_id: project_id.to_proto(),
1748    })?;
1749    room_updated(room, &session.peer);
1750
1751    Ok(())
1752}
1753
1754/// Unshare a project from the room.
1755async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1756    let project_id = ProjectId::from_proto(message.project_id);
1757    unshare_project_internal(project_id, session.connection_id, &session).await
1758}
1759
1760async fn unshare_project_internal(
1761    project_id: ProjectId,
1762    connection_id: ConnectionId,
1763    session: &Session,
1764) -> Result<()> {
1765    let delete = {
1766        let room_guard = session
1767            .db()
1768            .await
1769            .unshare_project(project_id, connection_id)
1770            .await?;
1771
1772        let (delete, room, guest_connection_ids) = &*room_guard;
1773
1774        let message = proto::UnshareProject {
1775            project_id: project_id.to_proto(),
1776        };
1777
1778        broadcast(
1779            Some(connection_id),
1780            guest_connection_ids.iter().copied(),
1781            |conn_id| session.peer.send(conn_id, message.clone()),
1782        );
1783        if let Some(room) = room {
1784            room_updated(room, &session.peer);
1785        }
1786
1787        *delete
1788    };
1789
1790    if delete {
1791        let db = session.db().await;
1792        db.delete_project(project_id).await?;
1793    }
1794
1795    Ok(())
1796}
1797
1798/// Join someone elses shared project.
1799async fn join_project(
1800    request: proto::JoinProject,
1801    response: Response<proto::JoinProject>,
1802    session: Session,
1803) -> Result<()> {
1804    let project_id = ProjectId::from_proto(request.project_id);
1805
1806    tracing::info!(%project_id, "join project");
1807
1808    let db = session.db().await;
1809    let (project, replica_id) = &mut *db
1810        .join_project(project_id, session.connection_id, session.user_id())
1811        .await?;
1812    drop(db);
1813    tracing::info!(%project_id, "join remote project");
1814    join_project_internal(response, session, project, replica_id)
1815}
1816
1817trait JoinProjectInternalResponse {
1818    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1819}
1820impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1821    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1822        Response::<proto::JoinProject>::send(self, result)
1823    }
1824}
1825
1826fn join_project_internal(
1827    response: impl JoinProjectInternalResponse,
1828    session: Session,
1829    project: &mut Project,
1830    replica_id: &ReplicaId,
1831) -> Result<()> {
1832    let collaborators = project
1833        .collaborators
1834        .iter()
1835        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1836        .map(|collaborator| collaborator.to_proto())
1837        .collect::<Vec<_>>();
1838    let project_id = project.id;
1839    let guest_user_id = session.user_id();
1840
1841    let worktrees = project
1842        .worktrees
1843        .iter()
1844        .map(|(id, worktree)| proto::WorktreeMetadata {
1845            id: *id,
1846            root_name: worktree.root_name.clone(),
1847            visible: worktree.visible,
1848            abs_path: worktree.abs_path.clone(),
1849        })
1850        .collect::<Vec<_>>();
1851
1852    let add_project_collaborator = proto::AddProjectCollaborator {
1853        project_id: project_id.to_proto(),
1854        collaborator: Some(proto::Collaborator {
1855            peer_id: Some(session.connection_id.into()),
1856            replica_id: replica_id.0 as u32,
1857            user_id: guest_user_id.to_proto(),
1858            is_host: false,
1859        }),
1860    };
1861
1862    for collaborator in &collaborators {
1863        session
1864            .peer
1865            .send(
1866                collaborator.peer_id.unwrap().into(),
1867                add_project_collaborator.clone(),
1868            )
1869            .trace_err();
1870    }
1871
1872    // First, we send the metadata associated with each worktree.
1873    response.send(proto::JoinProjectResponse {
1874        project_id: project.id.0 as u64,
1875        worktrees: worktrees.clone(),
1876        replica_id: replica_id.0 as u32,
1877        collaborators: collaborators.clone(),
1878        language_servers: project.language_servers.clone(),
1879        role: project.role.into(),
1880    })?;
1881
1882    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1883        // Stream this worktree's entries.
1884        let message = proto::UpdateWorktree {
1885            project_id: project_id.to_proto(),
1886            worktree_id,
1887            abs_path: worktree.abs_path.clone(),
1888            root_name: worktree.root_name,
1889            updated_entries: worktree.entries,
1890            removed_entries: Default::default(),
1891            scan_id: worktree.scan_id,
1892            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1893            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1894            removed_repositories: Default::default(),
1895        };
1896        for update in proto::split_worktree_update(message) {
1897            session.peer.send(session.connection_id, update.clone())?;
1898        }
1899
1900        // Stream this worktree's diagnostics.
1901        for summary in worktree.diagnostic_summaries {
1902            session.peer.send(
1903                session.connection_id,
1904                proto::UpdateDiagnosticSummary {
1905                    project_id: project_id.to_proto(),
1906                    worktree_id: worktree.id,
1907                    summary: Some(summary),
1908                },
1909            )?;
1910        }
1911
1912        for settings_file in worktree.settings_files {
1913            session.peer.send(
1914                session.connection_id,
1915                proto::UpdateWorktreeSettings {
1916                    project_id: project_id.to_proto(),
1917                    worktree_id: worktree.id,
1918                    path: settings_file.path,
1919                    content: Some(settings_file.content),
1920                    kind: Some(settings_file.kind.to_proto() as i32),
1921                },
1922            )?;
1923        }
1924    }
1925
1926    for repository in mem::take(&mut project.repositories) {
1927        for update in split_repository_update(repository) {
1928            session.peer.send(session.connection_id, update)?;
1929        }
1930    }
1931
1932    for language_server in &project.language_servers {
1933        session.peer.send(
1934            session.connection_id,
1935            proto::UpdateLanguageServer {
1936                project_id: project_id.to_proto(),
1937                language_server_id: language_server.id,
1938                variant: Some(
1939                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1940                        proto::LspDiskBasedDiagnosticsUpdated {},
1941                    ),
1942                ),
1943            },
1944        )?;
1945    }
1946
1947    Ok(())
1948}
1949
1950/// Leave someone elses shared project.
1951async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1952    let sender_id = session.connection_id;
1953    let project_id = ProjectId::from_proto(request.project_id);
1954    let db = session.db().await;
1955
1956    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1957    tracing::info!(
1958        %project_id,
1959        "leave project"
1960    );
1961
1962    project_left(project, &session);
1963    if let Some(room) = room {
1964        room_updated(room, &session.peer);
1965    }
1966
1967    Ok(())
1968}
1969
1970/// Updates other participants with changes to the project
1971async fn update_project(
1972    request: proto::UpdateProject,
1973    response: Response<proto::UpdateProject>,
1974    session: Session,
1975) -> Result<()> {
1976    let project_id = ProjectId::from_proto(request.project_id);
1977    let (room, guest_connection_ids) = &*session
1978        .db()
1979        .await
1980        .update_project(project_id, session.connection_id, &request.worktrees)
1981        .await?;
1982    broadcast(
1983        Some(session.connection_id),
1984        guest_connection_ids.iter().copied(),
1985        |connection_id| {
1986            session
1987                .peer
1988                .forward_send(session.connection_id, connection_id, request.clone())
1989        },
1990    );
1991    if let Some(room) = room {
1992        room_updated(room, &session.peer);
1993    }
1994    response.send(proto::Ack {})?;
1995
1996    Ok(())
1997}
1998
1999/// Updates other participants with changes to the worktree
2000async fn update_worktree(
2001    request: proto::UpdateWorktree,
2002    response: Response<proto::UpdateWorktree>,
2003    session: Session,
2004) -> Result<()> {
2005    let guest_connection_ids = session
2006        .db()
2007        .await
2008        .update_worktree(&request, session.connection_id)
2009        .await?;
2010
2011    broadcast(
2012        Some(session.connection_id),
2013        guest_connection_ids.iter().copied(),
2014        |connection_id| {
2015            session
2016                .peer
2017                .forward_send(session.connection_id, connection_id, request.clone())
2018        },
2019    );
2020    response.send(proto::Ack {})?;
2021    Ok(())
2022}
2023
2024async fn update_repository(
2025    request: proto::UpdateRepository,
2026    response: Response<proto::UpdateRepository>,
2027    session: Session,
2028) -> Result<()> {
2029    let guest_connection_ids = session
2030        .db()
2031        .await
2032        .update_repository(&request, session.connection_id)
2033        .await?;
2034
2035    broadcast(
2036        Some(session.connection_id),
2037        guest_connection_ids.iter().copied(),
2038        |connection_id| {
2039            session
2040                .peer
2041                .forward_send(session.connection_id, connection_id, request.clone())
2042        },
2043    );
2044    response.send(proto::Ack {})?;
2045    Ok(())
2046}
2047
2048async fn remove_repository(
2049    request: proto::RemoveRepository,
2050    response: Response<proto::RemoveRepository>,
2051    session: Session,
2052) -> Result<()> {
2053    let guest_connection_ids = session
2054        .db()
2055        .await
2056        .remove_repository(&request, session.connection_id)
2057        .await?;
2058
2059    broadcast(
2060        Some(session.connection_id),
2061        guest_connection_ids.iter().copied(),
2062        |connection_id| {
2063            session
2064                .peer
2065                .forward_send(session.connection_id, connection_id, request.clone())
2066        },
2067    );
2068    response.send(proto::Ack {})?;
2069    Ok(())
2070}
2071
2072/// Updates other participants with changes to the diagnostics
2073async fn update_diagnostic_summary(
2074    message: proto::UpdateDiagnosticSummary,
2075    session: Session,
2076) -> Result<()> {
2077    let guest_connection_ids = session
2078        .db()
2079        .await
2080        .update_diagnostic_summary(&message, session.connection_id)
2081        .await?;
2082
2083    broadcast(
2084        Some(session.connection_id),
2085        guest_connection_ids.iter().copied(),
2086        |connection_id| {
2087            session
2088                .peer
2089                .forward_send(session.connection_id, connection_id, message.clone())
2090        },
2091    );
2092
2093    Ok(())
2094}
2095
2096/// Updates other participants with changes to the worktree settings
2097async fn update_worktree_settings(
2098    message: proto::UpdateWorktreeSettings,
2099    session: Session,
2100) -> Result<()> {
2101    let guest_connection_ids = session
2102        .db()
2103        .await
2104        .update_worktree_settings(&message, session.connection_id)
2105        .await?;
2106
2107    broadcast(
2108        Some(session.connection_id),
2109        guest_connection_ids.iter().copied(),
2110        |connection_id| {
2111            session
2112                .peer
2113                .forward_send(session.connection_id, connection_id, message.clone())
2114        },
2115    );
2116
2117    Ok(())
2118}
2119
2120/// Notify other participants that a language server has started.
2121async fn start_language_server(
2122    request: proto::StartLanguageServer,
2123    session: Session,
2124) -> Result<()> {
2125    let guest_connection_ids = session
2126        .db()
2127        .await
2128        .start_language_server(&request, session.connection_id)
2129        .await?;
2130
2131    broadcast(
2132        Some(session.connection_id),
2133        guest_connection_ids.iter().copied(),
2134        |connection_id| {
2135            session
2136                .peer
2137                .forward_send(session.connection_id, connection_id, request.clone())
2138        },
2139    );
2140    Ok(())
2141}
2142
2143/// Notify other participants that a language server has changed.
2144async fn update_language_server(
2145    request: proto::UpdateLanguageServer,
2146    session: Session,
2147) -> Result<()> {
2148    let project_id = ProjectId::from_proto(request.project_id);
2149    let project_connection_ids = session
2150        .db()
2151        .await
2152        .project_connection_ids(project_id, session.connection_id, true)
2153        .await?;
2154    broadcast(
2155        Some(session.connection_id),
2156        project_connection_ids.iter().copied(),
2157        |connection_id| {
2158            session
2159                .peer
2160                .forward_send(session.connection_id, connection_id, request.clone())
2161        },
2162    );
2163    Ok(())
2164}
2165
2166/// forward a project request to the host. These requests should be read only
2167/// as guests are allowed to send them.
2168async fn forward_read_only_project_request<T>(
2169    request: T,
2170    response: Response<T>,
2171    session: Session,
2172) -> Result<()>
2173where
2174    T: EntityMessage + RequestMessage,
2175{
2176    let project_id = ProjectId::from_proto(request.remote_entity_id());
2177    let host_connection_id = session
2178        .db()
2179        .await
2180        .host_for_read_only_project_request(project_id, session.connection_id)
2181        .await?;
2182    let payload = session
2183        .peer
2184        .forward_request(session.connection_id, host_connection_id, request)
2185        .await?;
2186    response.send(payload)?;
2187    Ok(())
2188}
2189
2190async fn forward_find_search_candidates_request(
2191    request: proto::FindSearchCandidates,
2192    response: Response<proto::FindSearchCandidates>,
2193    session: Session,
2194) -> Result<()> {
2195    let project_id = ProjectId::from_proto(request.remote_entity_id());
2196    let host_connection_id = session
2197        .db()
2198        .await
2199        .host_for_read_only_project_request(project_id, session.connection_id)
2200        .await?;
2201    let payload = session
2202        .peer
2203        .forward_request(session.connection_id, host_connection_id, request)
2204        .await?;
2205    response.send(payload)?;
2206    Ok(())
2207}
2208
2209/// forward a project request to the host. These requests are disallowed
2210/// for guests.
2211async fn forward_mutating_project_request<T>(
2212    request: T,
2213    response: Response<T>,
2214    session: Session,
2215) -> Result<()>
2216where
2217    T: EntityMessage + RequestMessage,
2218{
2219    let project_id = ProjectId::from_proto(request.remote_entity_id());
2220
2221    let host_connection_id = session
2222        .db()
2223        .await
2224        .host_for_mutating_project_request(project_id, session.connection_id)
2225        .await?;
2226    let payload = session
2227        .peer
2228        .forward_request(session.connection_id, host_connection_id, request)
2229        .await?;
2230    response.send(payload)?;
2231    Ok(())
2232}
2233
2234/// Notify other participants that a new buffer has been created
2235async fn create_buffer_for_peer(
2236    request: proto::CreateBufferForPeer,
2237    session: Session,
2238) -> Result<()> {
2239    session
2240        .db()
2241        .await
2242        .check_user_is_project_host(
2243            ProjectId::from_proto(request.project_id),
2244            session.connection_id,
2245        )
2246        .await?;
2247    let peer_id = request.peer_id.context("invalid peer id")?;
2248    session
2249        .peer
2250        .forward_send(session.connection_id, peer_id.into(), request)?;
2251    Ok(())
2252}
2253
2254/// Notify other participants that a buffer has been updated. This is
2255/// allowed for guests as long as the update is limited to selections.
2256async fn update_buffer(
2257    request: proto::UpdateBuffer,
2258    response: Response<proto::UpdateBuffer>,
2259    session: Session,
2260) -> Result<()> {
2261    let project_id = ProjectId::from_proto(request.project_id);
2262    let mut capability = Capability::ReadOnly;
2263
2264    for op in request.operations.iter() {
2265        match op.variant {
2266            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2267            Some(_) => capability = Capability::ReadWrite,
2268        }
2269    }
2270
2271    let host = {
2272        let guard = session
2273            .db()
2274            .await
2275            .connections_for_buffer_update(project_id, session.connection_id, capability)
2276            .await?;
2277
2278        let (host, guests) = &*guard;
2279
2280        broadcast(
2281            Some(session.connection_id),
2282            guests.clone(),
2283            |connection_id| {
2284                session
2285                    .peer
2286                    .forward_send(session.connection_id, connection_id, request.clone())
2287            },
2288        );
2289
2290        *host
2291    };
2292
2293    if host != session.connection_id {
2294        session
2295            .peer
2296            .forward_request(session.connection_id, host, request.clone())
2297            .await?;
2298    }
2299
2300    response.send(proto::Ack {})?;
2301    Ok(())
2302}
2303
2304async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2305    let project_id = ProjectId::from_proto(message.project_id);
2306
2307    let operation = message.operation.as_ref().context("invalid operation")?;
2308    let capability = match operation.variant.as_ref() {
2309        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2310            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2311                match buffer_op.variant {
2312                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2313                        Capability::ReadOnly
2314                    }
2315                    _ => Capability::ReadWrite,
2316                }
2317            } else {
2318                Capability::ReadWrite
2319            }
2320        }
2321        Some(_) => Capability::ReadWrite,
2322        None => Capability::ReadOnly,
2323    };
2324
2325    let guard = session
2326        .db()
2327        .await
2328        .connections_for_buffer_update(project_id, session.connection_id, capability)
2329        .await?;
2330
2331    let (host, guests) = &*guard;
2332
2333    broadcast(
2334        Some(session.connection_id),
2335        guests.iter().chain([host]).copied(),
2336        |connection_id| {
2337            session
2338                .peer
2339                .forward_send(session.connection_id, connection_id, message.clone())
2340        },
2341    );
2342
2343    Ok(())
2344}
2345
2346/// Notify other participants that a project has been updated.
2347async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2348    request: T,
2349    session: Session,
2350) -> Result<()> {
2351    let project_id = ProjectId::from_proto(request.remote_entity_id());
2352    let project_connection_ids = session
2353        .db()
2354        .await
2355        .project_connection_ids(project_id, session.connection_id, false)
2356        .await?;
2357
2358    broadcast(
2359        Some(session.connection_id),
2360        project_connection_ids.iter().copied(),
2361        |connection_id| {
2362            session
2363                .peer
2364                .forward_send(session.connection_id, connection_id, request.clone())
2365        },
2366    );
2367    Ok(())
2368}
2369
2370/// Start following another user in a call.
2371async fn follow(
2372    request: proto::Follow,
2373    response: Response<proto::Follow>,
2374    session: Session,
2375) -> Result<()> {
2376    let room_id = RoomId::from_proto(request.room_id);
2377    let project_id = request.project_id.map(ProjectId::from_proto);
2378    let leader_id = request.leader_id.context("invalid leader id")?.into();
2379    let follower_id = session.connection_id;
2380
2381    session
2382        .db()
2383        .await
2384        .check_room_participants(room_id, leader_id, session.connection_id)
2385        .await?;
2386
2387    let response_payload = session
2388        .peer
2389        .forward_request(session.connection_id, leader_id, request)
2390        .await?;
2391    response.send(response_payload)?;
2392
2393    if let Some(project_id) = project_id {
2394        let room = session
2395            .db()
2396            .await
2397            .follow(room_id, project_id, leader_id, follower_id)
2398            .await?;
2399        room_updated(&room, &session.peer);
2400    }
2401
2402    Ok(())
2403}
2404
2405/// Stop following another user in a call.
2406async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2407    let room_id = RoomId::from_proto(request.room_id);
2408    let project_id = request.project_id.map(ProjectId::from_proto);
2409    let leader_id = request.leader_id.context("invalid leader id")?.into();
2410    let follower_id = session.connection_id;
2411
2412    session
2413        .db()
2414        .await
2415        .check_room_participants(room_id, leader_id, session.connection_id)
2416        .await?;
2417
2418    session
2419        .peer
2420        .forward_send(session.connection_id, leader_id, request)?;
2421
2422    if let Some(project_id) = project_id {
2423        let room = session
2424            .db()
2425            .await
2426            .unfollow(room_id, project_id, leader_id, follower_id)
2427            .await?;
2428        room_updated(&room, &session.peer);
2429    }
2430
2431    Ok(())
2432}
2433
2434/// Notify everyone following you of your current location.
2435async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2436    let room_id = RoomId::from_proto(request.room_id);
2437    let database = session.db.lock().await;
2438
2439    let connection_ids = if let Some(project_id) = request.project_id {
2440        let project_id = ProjectId::from_proto(project_id);
2441        database
2442            .project_connection_ids(project_id, session.connection_id, true)
2443            .await?
2444    } else {
2445        database
2446            .room_connection_ids(room_id, session.connection_id)
2447            .await?
2448    };
2449
2450    // For now, don't send view update messages back to that view's current leader.
2451    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2452        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2453        _ => None,
2454    });
2455
2456    for connection_id in connection_ids.iter().cloned() {
2457        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2458            session
2459                .peer
2460                .forward_send(session.connection_id, connection_id, request.clone())?;
2461        }
2462    }
2463    Ok(())
2464}
2465
2466/// Get public data about users.
2467async fn get_users(
2468    request: proto::GetUsers,
2469    response: Response<proto::GetUsers>,
2470    session: Session,
2471) -> Result<()> {
2472    let user_ids = request
2473        .user_ids
2474        .into_iter()
2475        .map(UserId::from_proto)
2476        .collect();
2477    let users = session
2478        .db()
2479        .await
2480        .get_users_by_ids(user_ids)
2481        .await?
2482        .into_iter()
2483        .map(|user| proto::User {
2484            id: user.id.to_proto(),
2485            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2486            github_login: user.github_login,
2487            email: user.email_address,
2488            name: user.name,
2489        })
2490        .collect();
2491    response.send(proto::UsersResponse { users })?;
2492    Ok(())
2493}
2494
2495/// Search for users (to invite) buy Github login
2496async fn fuzzy_search_users(
2497    request: proto::FuzzySearchUsers,
2498    response: Response<proto::FuzzySearchUsers>,
2499    session: Session,
2500) -> Result<()> {
2501    let query = request.query;
2502    let users = match query.len() {
2503        0 => vec![],
2504        1 | 2 => session
2505            .db()
2506            .await
2507            .get_user_by_github_login(&query)
2508            .await?
2509            .into_iter()
2510            .collect(),
2511        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2512    };
2513    let users = users
2514        .into_iter()
2515        .filter(|user| user.id != session.user_id())
2516        .map(|user| proto::User {
2517            id: user.id.to_proto(),
2518            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2519            github_login: user.github_login,
2520            name: user.name,
2521            email: user.email_address,
2522        })
2523        .collect();
2524    response.send(proto::UsersResponse { users })?;
2525    Ok(())
2526}
2527
2528/// Send a contact request to another user.
2529async fn request_contact(
2530    request: proto::RequestContact,
2531    response: Response<proto::RequestContact>,
2532    session: Session,
2533) -> Result<()> {
2534    let requester_id = session.user_id();
2535    let responder_id = UserId::from_proto(request.responder_id);
2536    if requester_id == responder_id {
2537        return Err(anyhow!("cannot add yourself as a contact"))?;
2538    }
2539
2540    let notifications = session
2541        .db()
2542        .await
2543        .send_contact_request(requester_id, responder_id)
2544        .await?;
2545
2546    // Update outgoing contact requests of requester
2547    let mut update = proto::UpdateContacts::default();
2548    update.outgoing_requests.push(responder_id.to_proto());
2549    for connection_id in session
2550        .connection_pool()
2551        .await
2552        .user_connection_ids(requester_id)
2553    {
2554        session.peer.send(connection_id, update.clone())?;
2555    }
2556
2557    // Update incoming contact requests of responder
2558    let mut update = proto::UpdateContacts::default();
2559    update
2560        .incoming_requests
2561        .push(proto::IncomingContactRequest {
2562            requester_id: requester_id.to_proto(),
2563        });
2564    let connection_pool = session.connection_pool().await;
2565    for connection_id in connection_pool.user_connection_ids(responder_id) {
2566        session.peer.send(connection_id, update.clone())?;
2567    }
2568
2569    send_notifications(&connection_pool, &session.peer, notifications);
2570
2571    response.send(proto::Ack {})?;
2572    Ok(())
2573}
2574
2575/// Accept or decline a contact request
2576async fn respond_to_contact_request(
2577    request: proto::RespondToContactRequest,
2578    response: Response<proto::RespondToContactRequest>,
2579    session: Session,
2580) -> Result<()> {
2581    let responder_id = session.user_id();
2582    let requester_id = UserId::from_proto(request.requester_id);
2583    let db = session.db().await;
2584    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2585        db.dismiss_contact_notification(responder_id, requester_id)
2586            .await?;
2587    } else {
2588        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2589
2590        let notifications = db
2591            .respond_to_contact_request(responder_id, requester_id, accept)
2592            .await?;
2593        let requester_busy = db.is_user_busy(requester_id).await?;
2594        let responder_busy = db.is_user_busy(responder_id).await?;
2595
2596        let pool = session.connection_pool().await;
2597        // Update responder with new contact
2598        let mut update = proto::UpdateContacts::default();
2599        if accept {
2600            update
2601                .contacts
2602                .push(contact_for_user(requester_id, requester_busy, &pool));
2603        }
2604        update
2605            .remove_incoming_requests
2606            .push(requester_id.to_proto());
2607        for connection_id in pool.user_connection_ids(responder_id) {
2608            session.peer.send(connection_id, update.clone())?;
2609        }
2610
2611        // Update requester with new contact
2612        let mut update = proto::UpdateContacts::default();
2613        if accept {
2614            update
2615                .contacts
2616                .push(contact_for_user(responder_id, responder_busy, &pool));
2617        }
2618        update
2619            .remove_outgoing_requests
2620            .push(responder_id.to_proto());
2621
2622        for connection_id in pool.user_connection_ids(requester_id) {
2623            session.peer.send(connection_id, update.clone())?;
2624        }
2625
2626        send_notifications(&pool, &session.peer, notifications);
2627    }
2628
2629    response.send(proto::Ack {})?;
2630    Ok(())
2631}
2632
2633/// Remove a contact.
2634async fn remove_contact(
2635    request: proto::RemoveContact,
2636    response: Response<proto::RemoveContact>,
2637    session: Session,
2638) -> Result<()> {
2639    let requester_id = session.user_id();
2640    let responder_id = UserId::from_proto(request.user_id);
2641    let db = session.db().await;
2642    let (contact_accepted, deleted_notification_id) =
2643        db.remove_contact(requester_id, responder_id).await?;
2644
2645    let pool = session.connection_pool().await;
2646    // Update outgoing contact requests of requester
2647    let mut update = proto::UpdateContacts::default();
2648    if contact_accepted {
2649        update.remove_contacts.push(responder_id.to_proto());
2650    } else {
2651        update
2652            .remove_outgoing_requests
2653            .push(responder_id.to_proto());
2654    }
2655    for connection_id in pool.user_connection_ids(requester_id) {
2656        session.peer.send(connection_id, update.clone())?;
2657    }
2658
2659    // Update incoming contact requests of responder
2660    let mut update = proto::UpdateContacts::default();
2661    if contact_accepted {
2662        update.remove_contacts.push(requester_id.to_proto());
2663    } else {
2664        update
2665            .remove_incoming_requests
2666            .push(requester_id.to_proto());
2667    }
2668    for connection_id in pool.user_connection_ids(responder_id) {
2669        session.peer.send(connection_id, update.clone())?;
2670        if let Some(notification_id) = deleted_notification_id {
2671            session.peer.send(
2672                connection_id,
2673                proto::DeleteNotification {
2674                    notification_id: notification_id.to_proto(),
2675                },
2676            )?;
2677        }
2678    }
2679
2680    response.send(proto::Ack {})?;
2681    Ok(())
2682}
2683
2684fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2685    version.0.minor() < 139
2686}
2687
2688async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2689    if is_staff {
2690        return Ok(proto::Plan::ZedPro);
2691    }
2692
2693    let subscription = db.get_active_billing_subscription(user_id).await?;
2694    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2695
2696    let plan = if let Some(subscription_kind) = subscription_kind {
2697        match subscription_kind {
2698            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2699            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2700            SubscriptionKind::ZedFree => proto::Plan::Free,
2701        }
2702    } else {
2703        proto::Plan::Free
2704    };
2705
2706    Ok(plan)
2707}
2708
2709async fn make_update_user_plan_message(
2710    db: &Arc<Database>,
2711    llm_db: Option<Arc<LlmDatabase>>,
2712    user_id: UserId,
2713    is_staff: bool,
2714) -> Result<proto::UpdateUserPlan> {
2715    let feature_flags = db.get_user_flags(user_id).await?;
2716    let plan = current_plan(db, user_id, is_staff).await?;
2717    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2718    let billing_preferences = db.get_billing_preferences(user_id).await?;
2719    let user = db.get_user_by_id(user_id).await?;
2720
2721    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2722        let subscription = db.get_active_billing_subscription(user_id).await?;
2723
2724        let subscription_period =
2725            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2726
2727        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2728            llm_db
2729                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2730                .await?
2731        } else {
2732            None
2733        };
2734
2735        (subscription_period, usage)
2736    } else {
2737        (None, None)
2738    };
2739
2740    // Calculate account_too_young
2741    let account_too_young = if matches!(plan, proto::Plan::ZedPro) {
2742        // If they have paid, then we allow them to use all of the features
2743        false
2744    } else if let Some(user) = user {
2745        // If we have access to the profile age, we use that
2746        chrono::Utc::now().naive_utc() - user.account_created_at() < MIN_ACCOUNT_AGE_FOR_LLM_USE
2747    } else {
2748        // Default to false otherwise
2749        false
2750    };
2751
2752    Ok(proto::UpdateUserPlan {
2753        plan: plan.into(),
2754        trial_started_at: billing_customer
2755            .and_then(|billing_customer| billing_customer.trial_started_at)
2756            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2757        is_usage_based_billing_enabled: if is_staff {
2758            Some(true)
2759        } else {
2760            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2761        },
2762        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2763            proto::SubscriptionPeriod {
2764                started_at: started_at.timestamp() as u64,
2765                ended_at: ended_at.timestamp() as u64,
2766            }
2767        }),
2768        account_too_young: Some(account_too_young),
2769        usage: usage.map(|usage| {
2770            let plan = match plan {
2771                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2772                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2773                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2774            };
2775
2776            let model_requests_limit = match plan.model_requests_limit() {
2777                zed_llm_client::UsageLimit::Limited(limit) => {
2778                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2779                        && feature_flags
2780                            .iter()
2781                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2782                    {
2783                        1_000
2784                    } else {
2785                        limit
2786                    };
2787
2788                    zed_llm_client::UsageLimit::Limited(limit)
2789                }
2790                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2791            };
2792
2793            proto::SubscriptionUsage {
2794                model_requests_usage_amount: usage.model_requests as u32,
2795                model_requests_usage_limit: Some(proto::UsageLimit {
2796                    variant: Some(match model_requests_limit {
2797                        zed_llm_client::UsageLimit::Limited(limit) => {
2798                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2799                                limit: limit as u32,
2800                            })
2801                        }
2802                        zed_llm_client::UsageLimit::Unlimited => {
2803                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2804                        }
2805                    }),
2806                }),
2807                edit_predictions_usage_amount: usage.edit_predictions as u32,
2808                edit_predictions_usage_limit: Some(proto::UsageLimit {
2809                    variant: Some(match plan.edit_predictions_limit() {
2810                        zed_llm_client::UsageLimit::Limited(limit) => {
2811                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2812                                limit: limit as u32,
2813                            })
2814                        }
2815                        zed_llm_client::UsageLimit::Unlimited => {
2816                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2817                        }
2818                    }),
2819                }),
2820            }
2821        }),
2822    })
2823}
2824
2825async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2826    let db = session.db().await;
2827
2828    let update_user_plan = make_update_user_plan_message(
2829        &db.0,
2830        session.app_state.llm_db.clone(),
2831        user_id,
2832        session.is_staff(),
2833    )
2834    .await?;
2835
2836    session
2837        .peer
2838        .send(session.connection_id, update_user_plan)
2839        .trace_err();
2840
2841    Ok(())
2842}
2843
2844async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2845    subscribe_user_to_channels(session.user_id(), &session).await?;
2846    Ok(())
2847}
2848
2849async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2850    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2851    let mut pool = session.connection_pool().await;
2852    for membership in &channels_for_user.channel_memberships {
2853        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2854    }
2855    session.peer.send(
2856        session.connection_id,
2857        build_update_user_channels(&channels_for_user),
2858    )?;
2859    session.peer.send(
2860        session.connection_id,
2861        build_channels_update(channels_for_user),
2862    )?;
2863    Ok(())
2864}
2865
2866/// Creates a new channel.
2867async fn create_channel(
2868    request: proto::CreateChannel,
2869    response: Response<proto::CreateChannel>,
2870    session: Session,
2871) -> Result<()> {
2872    let db = session.db().await;
2873
2874    let parent_id = request.parent_id.map(ChannelId::from_proto);
2875    let (channel, membership) = db
2876        .create_channel(&request.name, parent_id, session.user_id())
2877        .await?;
2878
2879    let root_id = channel.root_id();
2880    let channel = Channel::from_model(channel);
2881
2882    response.send(proto::CreateChannelResponse {
2883        channel: Some(channel.to_proto()),
2884        parent_id: request.parent_id,
2885    })?;
2886
2887    let mut connection_pool = session.connection_pool().await;
2888    if let Some(membership) = membership {
2889        connection_pool.subscribe_to_channel(
2890            membership.user_id,
2891            membership.channel_id,
2892            membership.role,
2893        );
2894        let update = proto::UpdateUserChannels {
2895            channel_memberships: vec![proto::ChannelMembership {
2896                channel_id: membership.channel_id.to_proto(),
2897                role: membership.role.into(),
2898            }],
2899            ..Default::default()
2900        };
2901        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2902            session.peer.send(connection_id, update.clone())?;
2903        }
2904    }
2905
2906    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2907        if !role.can_see_channel(channel.visibility) {
2908            continue;
2909        }
2910
2911        let update = proto::UpdateChannels {
2912            channels: vec![channel.to_proto()],
2913            ..Default::default()
2914        };
2915        session.peer.send(connection_id, update.clone())?;
2916    }
2917
2918    Ok(())
2919}
2920
2921/// Delete a channel
2922async fn delete_channel(
2923    request: proto::DeleteChannel,
2924    response: Response<proto::DeleteChannel>,
2925    session: Session,
2926) -> Result<()> {
2927    let db = session.db().await;
2928
2929    let channel_id = request.channel_id;
2930    let (root_channel, removed_channels) = db
2931        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2932        .await?;
2933    response.send(proto::Ack {})?;
2934
2935    // Notify members of removed channels
2936    let mut update = proto::UpdateChannels::default();
2937    update
2938        .delete_channels
2939        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2940
2941    let connection_pool = session.connection_pool().await;
2942    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2943        session.peer.send(connection_id, update.clone())?;
2944    }
2945
2946    Ok(())
2947}
2948
2949/// Invite someone to join a channel.
2950async fn invite_channel_member(
2951    request: proto::InviteChannelMember,
2952    response: Response<proto::InviteChannelMember>,
2953    session: Session,
2954) -> Result<()> {
2955    let db = session.db().await;
2956    let channel_id = ChannelId::from_proto(request.channel_id);
2957    let invitee_id = UserId::from_proto(request.user_id);
2958    let InviteMemberResult {
2959        channel,
2960        notifications,
2961    } = db
2962        .invite_channel_member(
2963            channel_id,
2964            invitee_id,
2965            session.user_id(),
2966            request.role().into(),
2967        )
2968        .await?;
2969
2970    let update = proto::UpdateChannels {
2971        channel_invitations: vec![channel.to_proto()],
2972        ..Default::default()
2973    };
2974
2975    let connection_pool = session.connection_pool().await;
2976    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2977        session.peer.send(connection_id, update.clone())?;
2978    }
2979
2980    send_notifications(&connection_pool, &session.peer, notifications);
2981
2982    response.send(proto::Ack {})?;
2983    Ok(())
2984}
2985
2986/// remove someone from a channel
2987async fn remove_channel_member(
2988    request: proto::RemoveChannelMember,
2989    response: Response<proto::RemoveChannelMember>,
2990    session: Session,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let member_id = UserId::from_proto(request.user_id);
2995
2996    let RemoveChannelMemberResult {
2997        membership_update,
2998        notification_id,
2999    } = db
3000        .remove_channel_member(channel_id, member_id, session.user_id())
3001        .await?;
3002
3003    let mut connection_pool = session.connection_pool().await;
3004    notify_membership_updated(
3005        &mut connection_pool,
3006        membership_update,
3007        member_id,
3008        &session.peer,
3009    );
3010    for connection_id in connection_pool.user_connection_ids(member_id) {
3011        if let Some(notification_id) = notification_id {
3012            session
3013                .peer
3014                .send(
3015                    connection_id,
3016                    proto::DeleteNotification {
3017                        notification_id: notification_id.to_proto(),
3018                    },
3019                )
3020                .trace_err();
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Toggle the channel between public and private.
3029/// Care is taken to maintain the invariant that public channels only descend from public channels,
3030/// (though members-only channels can appear at any point in the hierarchy).
3031async fn set_channel_visibility(
3032    request: proto::SetChannelVisibility,
3033    response: Response<proto::SetChannelVisibility>,
3034    session: Session,
3035) -> Result<()> {
3036    let db = session.db().await;
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038    let visibility = request.visibility().into();
3039
3040    let channel_model = db
3041        .set_channel_visibility(channel_id, visibility, session.user_id())
3042        .await?;
3043    let root_id = channel_model.root_id();
3044    let channel = Channel::from_model(channel_model);
3045
3046    let mut connection_pool = session.connection_pool().await;
3047    for (user_id, role) in connection_pool
3048        .channel_user_ids(root_id)
3049        .collect::<Vec<_>>()
3050        .into_iter()
3051    {
3052        let update = if role.can_see_channel(channel.visibility) {
3053            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3054            proto::UpdateChannels {
3055                channels: vec![channel.to_proto()],
3056                ..Default::default()
3057            }
3058        } else {
3059            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3060            proto::UpdateChannels {
3061                delete_channels: vec![channel.id.to_proto()],
3062                ..Default::default()
3063            }
3064        };
3065
3066        for connection_id in connection_pool.user_connection_ids(user_id) {
3067            session.peer.send(connection_id, update.clone())?;
3068        }
3069    }
3070
3071    response.send(proto::Ack {})?;
3072    Ok(())
3073}
3074
3075/// Alter the role for a user in the channel.
3076async fn set_channel_member_role(
3077    request: proto::SetChannelMemberRole,
3078    response: Response<proto::SetChannelMemberRole>,
3079    session: Session,
3080) -> Result<()> {
3081    let db = session.db().await;
3082    let channel_id = ChannelId::from_proto(request.channel_id);
3083    let member_id = UserId::from_proto(request.user_id);
3084    let result = db
3085        .set_channel_member_role(
3086            channel_id,
3087            session.user_id(),
3088            member_id,
3089            request.role().into(),
3090        )
3091        .await?;
3092
3093    match result {
3094        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3095            let mut connection_pool = session.connection_pool().await;
3096            notify_membership_updated(
3097                &mut connection_pool,
3098                membership_update,
3099                member_id,
3100                &session.peer,
3101            )
3102        }
3103        db::SetMemberRoleResult::InviteUpdated(channel) => {
3104            let update = proto::UpdateChannels {
3105                channel_invitations: vec![channel.to_proto()],
3106                ..Default::default()
3107            };
3108
3109            for connection_id in session
3110                .connection_pool()
3111                .await
3112                .user_connection_ids(member_id)
3113            {
3114                session.peer.send(connection_id, update.clone())?;
3115            }
3116        }
3117    }
3118
3119    response.send(proto::Ack {})?;
3120    Ok(())
3121}
3122
3123/// Change the name of a channel
3124async fn rename_channel(
3125    request: proto::RenameChannel,
3126    response: Response<proto::RenameChannel>,
3127    session: Session,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let channel_model = db
3132        .rename_channel(channel_id, session.user_id(), &request.name)
3133        .await?;
3134    let root_id = channel_model.root_id();
3135    let channel = Channel::from_model(channel_model);
3136
3137    response.send(proto::RenameChannelResponse {
3138        channel: Some(channel.to_proto()),
3139    })?;
3140
3141    let connection_pool = session.connection_pool().await;
3142    let update = proto::UpdateChannels {
3143        channels: vec![channel.to_proto()],
3144        ..Default::default()
3145    };
3146    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147        if role.can_see_channel(channel.visibility) {
3148            session.peer.send(connection_id, update.clone())?;
3149        }
3150    }
3151
3152    Ok(())
3153}
3154
3155/// Move a channel to a new parent.
3156async fn move_channel(
3157    request: proto::MoveChannel,
3158    response: Response<proto::MoveChannel>,
3159    session: Session,
3160) -> Result<()> {
3161    let channel_id = ChannelId::from_proto(request.channel_id);
3162    let to = ChannelId::from_proto(request.to);
3163
3164    let (root_id, channels) = session
3165        .db()
3166        .await
3167        .move_channel(channel_id, to, session.user_id())
3168        .await?;
3169
3170    let connection_pool = session.connection_pool().await;
3171    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172        let channels = channels
3173            .iter()
3174            .filter_map(|channel| {
3175                if role.can_see_channel(channel.visibility) {
3176                    Some(channel.to_proto())
3177                } else {
3178                    None
3179                }
3180            })
3181            .collect::<Vec<_>>();
3182        if channels.is_empty() {
3183            continue;
3184        }
3185
3186        let update = proto::UpdateChannels {
3187            channels,
3188            ..Default::default()
3189        };
3190
3191        session.peer.send(connection_id, update.clone())?;
3192    }
3193
3194    response.send(Ack {})?;
3195    Ok(())
3196}
3197
3198/// Get the list of channel members
3199async fn get_channel_members(
3200    request: proto::GetChannelMembers,
3201    response: Response<proto::GetChannelMembers>,
3202    session: Session,
3203) -> Result<()> {
3204    let db = session.db().await;
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206    let limit = if request.limit == 0 {
3207        u16::MAX as u64
3208    } else {
3209        request.limit
3210    };
3211    let (members, users) = db
3212        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3213        .await?;
3214    response.send(proto::GetChannelMembersResponse { members, users })?;
3215    Ok(())
3216}
3217
3218/// Accept or decline a channel invitation.
3219async fn respond_to_channel_invite(
3220    request: proto::RespondToChannelInvite,
3221    response: Response<proto::RespondToChannelInvite>,
3222    session: Session,
3223) -> Result<()> {
3224    let db = session.db().await;
3225    let channel_id = ChannelId::from_proto(request.channel_id);
3226    let RespondToChannelInvite {
3227        membership_update,
3228        notifications,
3229    } = db
3230        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3231        .await?;
3232
3233    let mut connection_pool = session.connection_pool().await;
3234    if let Some(membership_update) = membership_update {
3235        notify_membership_updated(
3236            &mut connection_pool,
3237            membership_update,
3238            session.user_id(),
3239            &session.peer,
3240        );
3241    } else {
3242        let update = proto::UpdateChannels {
3243            remove_channel_invitations: vec![channel_id.to_proto()],
3244            ..Default::default()
3245        };
3246
3247        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3248            session.peer.send(connection_id, update.clone())?;
3249        }
3250    };
3251
3252    send_notifications(&connection_pool, &session.peer, notifications);
3253
3254    response.send(proto::Ack {})?;
3255
3256    Ok(())
3257}
3258
3259/// Join the channels' room
3260async fn join_channel(
3261    request: proto::JoinChannel,
3262    response: Response<proto::JoinChannel>,
3263    session: Session,
3264) -> Result<()> {
3265    let channel_id = ChannelId::from_proto(request.channel_id);
3266    join_channel_internal(channel_id, Box::new(response), session).await
3267}
3268
/// Abstracts over the response handles that can answer a channel join. This
/// lets `join_channel_internal` serve both `JoinChannel` and `JoinRoom`
/// requests, which reply with the same `JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the pending request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path call to `Response`'s own `send` (the one other handlers call);
        // Rust resolves it ahead of this trait method, so this does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3282
/// Shared implementation behind channel joins. It puts the current connection
/// into the channel's call room and replies through `response`.
///
/// In order, it:
/// 1. evicts a stale room connection belonging to the same user (see the
///    comment below), via `leave_room_for_session`;
/// 2. joins the channel's room in the database, which also returns any
///    membership change and the user's `ChannelRole`;
/// 3. mints a LiveKit token when a LiveKit client is configured. A
///    `ChannelRole::Guest` gets a guest token with `can_publish = false`;
///    everyone else gets a room token with `can_publish = true`. A token error
///    becomes `None` via `trace_err()`, so the join still succeeds, just
///    without connection info;
/// 4. sends the `JoinRoomResponse`, applies and pushes any membership change,
///    and broadcasts the room, channel, and contact updates.
///
/// `response` is generic over `JoinChannelInternalResponse`, so it may be
/// either a `JoinChannel` or a `JoinRoom` response handle.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the stale room (the helper
            // presumably locks it itself through `session`), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may listen but not publish; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joiner before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Keep the pool's channel subscriptions in sync with any membership change
        // and push it to the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The db handle and pool guard from the block above are dropped by now; the
    // pool is re-acquired just for this broadcast. A channel join is expected
    // to report its channel, so a missing one is an error.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3376
3377/// Start editing the channel notes
3378async fn join_channel_buffer(
3379    request: proto::JoinChannelBuffer,
3380    response: Response<proto::JoinChannelBuffer>,
3381    session: Session,
3382) -> Result<()> {
3383    let db = session.db().await;
3384    let channel_id = ChannelId::from_proto(request.channel_id);
3385
3386    let open_response = db
3387        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3388        .await?;
3389
3390    let collaborators = open_response.collaborators.clone();
3391    response.send(open_response)?;
3392
3393    let update = UpdateChannelBufferCollaborators {
3394        channel_id: channel_id.to_proto(),
3395        collaborators: collaborators.clone(),
3396    };
3397    channel_buffer_updated(
3398        session.connection_id,
3399        collaborators
3400            .iter()
3401            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3402        &update,
3403        &session.peer,
3404    );
3405
3406    Ok(())
3407}
3408
3409/// Edit the channel notes
3410async fn update_channel_buffer(
3411    request: proto::UpdateChannelBuffer,
3412    session: Session,
3413) -> Result<()> {
3414    let db = session.db().await;
3415    let channel_id = ChannelId::from_proto(request.channel_id);
3416
3417    let (collaborators, epoch, version) = db
3418        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3419        .await?;
3420
3421    channel_buffer_updated(
3422        session.connection_id,
3423        collaborators.clone(),
3424        &proto::UpdateChannelBuffer {
3425            channel_id: channel_id.to_proto(),
3426            operations: request.operations,
3427        },
3428        &session.peer,
3429    );
3430
3431    let pool = &*session.connection_pool().await;
3432
3433    let non_collaborators =
3434        pool.channel_connection_ids(channel_id)
3435            .filter_map(|(connection_id, _)| {
3436                if collaborators.contains(&connection_id) {
3437                    None
3438                } else {
3439                    Some(connection_id)
3440                }
3441            });
3442
3443    broadcast(None, non_collaborators, |peer_id| {
3444        session.peer.send(
3445            peer_id,
3446            proto::UpdateChannels {
3447                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3448                    channel_id: channel_id.to_proto(),
3449                    epoch: epoch as u64,
3450                    version: version.clone(),
3451                }],
3452                ..Default::default()
3453            },
3454        )
3455    });
3456
3457    Ok(())
3458}
3459
3460/// Rejoin the channel notes after a connection blip
3461async fn rejoin_channel_buffers(
3462    request: proto::RejoinChannelBuffers,
3463    response: Response<proto::RejoinChannelBuffers>,
3464    session: Session,
3465) -> Result<()> {
3466    let db = session.db().await;
3467    let buffers = db
3468        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3469        .await?;
3470
3471    for rejoined_buffer in &buffers {
3472        let collaborators_to_notify = rejoined_buffer
3473            .buffer
3474            .collaborators
3475            .iter()
3476            .filter_map(|c| Some(c.peer_id?.into()));
3477        channel_buffer_updated(
3478            session.connection_id,
3479            collaborators_to_notify,
3480            &proto::UpdateChannelBufferCollaborators {
3481                channel_id: rejoined_buffer.buffer.channel_id,
3482                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3483            },
3484            &session.peer,
3485        );
3486    }
3487
3488    response.send(proto::RejoinChannelBuffersResponse {
3489        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3490    })?;
3491
3492    Ok(())
3493}
3494
3495/// Stop editing the channel notes
3496async fn leave_channel_buffer(
3497    request: proto::LeaveChannelBuffer,
3498    response: Response<proto::LeaveChannelBuffer>,
3499    session: Session,
3500) -> Result<()> {
3501    let db = session.db().await;
3502    let channel_id = ChannelId::from_proto(request.channel_id);
3503
3504    let left_buffer = db
3505        .leave_channel_buffer(channel_id, session.connection_id)
3506        .await?;
3507
3508    response.send(Ack {})?;
3509
3510    channel_buffer_updated(
3511        session.connection_id,
3512        left_buffer.connections,
3513        &proto::UpdateChannelBufferCollaborators {
3514            channel_id: channel_id.to_proto(),
3515            collaborators: left_buffer.collaborators,
3516        },
3517        &session.peer,
3518    );
3519
3520    Ok(())
3521}
3522
3523fn channel_buffer_updated<T: EnvelopedMessage>(
3524    sender_id: ConnectionId,
3525    collaborators: impl IntoIterator<Item = ConnectionId>,
3526    message: &T,
3527    peer: &Peer,
3528) {
3529    broadcast(Some(sender_id), collaborators, |peer_id| {
3530        peer.send(peer_id, message.clone())
3531    });
3532}
3533
3534fn send_notifications(
3535    connection_pool: &ConnectionPool,
3536    peer: &Peer,
3537    notifications: db::NotificationBatch,
3538) {
3539    for (user_id, notification) in notifications {
3540        for connection_id in connection_pool.user_connection_ids(user_id) {
3541            if let Err(error) = peer.send(
3542                connection_id,
3543                proto::AddNotification {
3544                    notification: Some(notification.clone()),
3545                },
3546            ) {
3547                tracing::error!(
3548                    "failed to send notification to {:?} {}",
3549                    connection_id,
3550                    error
3551                );
3552            }
3553        }
3554    }
3555}
3556
3557/// Send a message to the channel
3558async fn send_channel_message(
3559    request: proto::SendChannelMessage,
3560    response: Response<proto::SendChannelMessage>,
3561    session: Session,
3562) -> Result<()> {
3563    // Validate the message body.
3564    let body = request.body.trim().to_string();
3565    if body.len() > MAX_MESSAGE_LEN {
3566        return Err(anyhow!("message is too long"))?;
3567    }
3568    if body.is_empty() {
3569        return Err(anyhow!("message can't be blank"))?;
3570    }
3571
3572    // TODO: adjust mentions if body is trimmed
3573
3574    let timestamp = OffsetDateTime::now_utc();
3575    let nonce = request.nonce.context("nonce can't be blank")?;
3576
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578    let CreatedChannelMessage {
3579        message_id,
3580        participant_connection_ids,
3581        notifications,
3582    } = session
3583        .db()
3584        .await
3585        .create_channel_message(
3586            channel_id,
3587            session.user_id(),
3588            &body,
3589            &request.mentions,
3590            timestamp,
3591            nonce.clone().into(),
3592            request.reply_to_message_id.map(MessageId::from_proto),
3593        )
3594        .await?;
3595
3596    let message = proto::ChannelMessage {
3597        sender_id: session.user_id().to_proto(),
3598        id: message_id.to_proto(),
3599        body,
3600        mentions: request.mentions,
3601        timestamp: timestamp.unix_timestamp() as u64,
3602        nonce: Some(nonce),
3603        reply_to_message_id: request.reply_to_message_id,
3604        edited_at: None,
3605    };
3606    broadcast(
3607        Some(session.connection_id),
3608        participant_connection_ids.clone(),
3609        |connection| {
3610            session.peer.send(
3611                connection,
3612                proto::ChannelMessageSent {
3613                    channel_id: channel_id.to_proto(),
3614                    message: Some(message.clone()),
3615                },
3616            )
3617        },
3618    );
3619    response.send(proto::SendChannelMessageResponse {
3620        message: Some(message),
3621    })?;
3622
3623    let pool = &*session.connection_pool().await;
3624    let non_participants =
3625        pool.channel_connection_ids(channel_id)
3626            .filter_map(|(connection_id, _)| {
3627                if participant_connection_ids.contains(&connection_id) {
3628                    None
3629                } else {
3630                    Some(connection_id)
3631                }
3632            });
3633    broadcast(None, non_participants, |peer_id| {
3634        session.peer.send(
3635            peer_id,
3636            proto::UpdateChannels {
3637                latest_channel_message_ids: vec![proto::ChannelMessageId {
3638                    channel_id: channel_id.to_proto(),
3639                    message_id: message_id.to_proto(),
3640                }],
3641                ..Default::default()
3642            },
3643        )
3644    });
3645    send_notifications(pool, &session.peer, notifications);
3646
3647    Ok(())
3648}
3649
3650/// Delete a channel message
3651async fn remove_channel_message(
3652    request: proto::RemoveChannelMessage,
3653    response: Response<proto::RemoveChannelMessage>,
3654    session: Session,
3655) -> Result<()> {
3656    let channel_id = ChannelId::from_proto(request.channel_id);
3657    let message_id = MessageId::from_proto(request.message_id);
3658    let (connection_ids, existing_notification_ids) = session
3659        .db()
3660        .await
3661        .remove_channel_message(channel_id, message_id, session.user_id())
3662        .await?;
3663
3664    broadcast(
3665        Some(session.connection_id),
3666        connection_ids,
3667        move |connection| {
3668            session.peer.send(connection, request.clone())?;
3669
3670            for notification_id in &existing_notification_ids {
3671                session.peer.send(
3672                    connection,
3673                    proto::DeleteNotification {
3674                        notification_id: (*notification_id).to_proto(),
3675                    },
3676                )?;
3677            }
3678
3679            Ok(())
3680        },
3681    );
3682    response.send(proto::Ack {})?;
3683    Ok(())
3684}
3685
3686async fn update_channel_message(
3687    request: proto::UpdateChannelMessage,
3688    response: Response<proto::UpdateChannelMessage>,
3689    session: Session,
3690) -> Result<()> {
3691    let channel_id = ChannelId::from_proto(request.channel_id);
3692    let message_id = MessageId::from_proto(request.message_id);
3693    let updated_at = OffsetDateTime::now_utc();
3694    let UpdatedChannelMessage {
3695        message_id,
3696        participant_connection_ids,
3697        notifications,
3698        reply_to_message_id,
3699        timestamp,
3700        deleted_mention_notification_ids,
3701        updated_mention_notifications,
3702    } = session
3703        .db()
3704        .await
3705        .update_channel_message(
3706            channel_id,
3707            message_id,
3708            session.user_id(),
3709            request.body.as_str(),
3710            &request.mentions,
3711            updated_at,
3712        )
3713        .await?;
3714
3715    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3716
3717    let message = proto::ChannelMessage {
3718        sender_id: session.user_id().to_proto(),
3719        id: message_id.to_proto(),
3720        body: request.body.clone(),
3721        mentions: request.mentions.clone(),
3722        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3723        nonce: Some(nonce),
3724        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3725        edited_at: Some(updated_at.unix_timestamp() as u64),
3726    };
3727
3728    response.send(proto::Ack {})?;
3729
3730    let pool = &*session.connection_pool().await;
3731    broadcast(
3732        Some(session.connection_id),
3733        participant_connection_ids,
3734        |connection| {
3735            session.peer.send(
3736                connection,
3737                proto::ChannelMessageUpdate {
3738                    channel_id: channel_id.to_proto(),
3739                    message: Some(message.clone()),
3740                },
3741            )?;
3742
3743            for notification_id in &deleted_mention_notification_ids {
3744                session.peer.send(
3745                    connection,
3746                    proto::DeleteNotification {
3747                        notification_id: (*notification_id).to_proto(),
3748                    },
3749                )?;
3750            }
3751
3752            for notification in &updated_mention_notifications {
3753                session.peer.send(
3754                    connection,
3755                    proto::UpdateNotification {
3756                        notification: Some(notification.clone()),
3757                    },
3758                )?;
3759            }
3760
3761            Ok(())
3762        },
3763    );
3764
3765    send_notifications(pool, &session.peer, notifications);
3766
3767    Ok(())
3768}
3769
3770/// Mark a channel message as read
3771async fn acknowledge_channel_message(
3772    request: proto::AckChannelMessage,
3773    session: Session,
3774) -> Result<()> {
3775    let channel_id = ChannelId::from_proto(request.channel_id);
3776    let message_id = MessageId::from_proto(request.message_id);
3777    let notifications = session
3778        .db()
3779        .await
3780        .observe_channel_message(channel_id, session.user_id(), message_id)
3781        .await?;
3782    send_notifications(
3783        &*session.connection_pool().await,
3784        &session.peer,
3785        notifications,
3786    );
3787    Ok(())
3788}
3789
3790/// Mark a buffer version as synced
3791async fn acknowledge_buffer_version(
3792    request: proto::AckBufferOperation,
3793    session: Session,
3794) -> Result<()> {
3795    let buffer_id = BufferId::from_proto(request.buffer_id);
3796    session
3797        .db()
3798        .await
3799        .observe_buffer_version(
3800            buffer_id,
3801            session.user_id(),
3802            request.epoch as i32,
3803            &request.version,
3804        )
3805        .await?;
3806    Ok(())
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811    _request: proto::GetSupermavenApiKey,
3812    response: Response<proto::GetSupermavenApiKey>,
3813    session: Session,
3814) -> Result<()> {
3815    let user_id: String = session.user_id().to_string();
3816    if !session.is_staff() {
3817        return Err(anyhow!("supermaven not enabled for this account"))?;
3818    }
3819
3820    let email = session.email().context("user must have an email")?;
3821
3822    let supermaven_admin_api = session
3823        .supermaven_client
3824        .as_ref()
3825        .context("supermaven not configured")?;
3826
3827    let result = supermaven_admin_api
3828        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3829        .await?;
3830
3831    response.send(proto::GetSupermavenApiKeyResponse {
3832        api_key: result.api_key,
3833    })?;
3834
3835    Ok(())
3836}
3837
3838/// Start receiving chat updates for a channel
3839async fn join_channel_chat(
3840    request: proto::JoinChannelChat,
3841    response: Response<proto::JoinChannelChat>,
3842    session: Session,
3843) -> Result<()> {
3844    let channel_id = ChannelId::from_proto(request.channel_id);
3845
3846    let db = session.db().await;
3847    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3848        .await?;
3849    let messages = db
3850        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3851        .await?;
3852    response.send(proto::JoinChannelChatResponse {
3853        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3854        messages,
3855    })?;
3856    Ok(())
3857}
3858
3859/// Stop receiving chat updates for a channel
3860async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3861    let channel_id = ChannelId::from_proto(request.channel_id);
3862    session
3863        .db()
3864        .await
3865        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3866        .await?;
3867    Ok(())
3868}
3869
3870/// Retrieve the chat history for a channel
3871async fn get_channel_messages(
3872    request: proto::GetChannelMessages,
3873    response: Response<proto::GetChannelMessages>,
3874    session: Session,
3875) -> Result<()> {
3876    let channel_id = ChannelId::from_proto(request.channel_id);
3877    let messages = session
3878        .db()
3879        .await
3880        .get_channel_messages(
3881            channel_id,
3882            session.user_id(),
3883            MESSAGE_COUNT_PER_PAGE,
3884            Some(MessageId::from_proto(request.before_message_id)),
3885        )
3886        .await?;
3887    response.send(proto::GetChannelMessagesResponse {
3888        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3889        messages,
3890    })?;
3891    Ok(())
3892}
3893
3894/// Retrieve specific chat messages
3895async fn get_channel_messages_by_id(
3896    request: proto::GetChannelMessagesById,
3897    response: Response<proto::GetChannelMessagesById>,
3898    session: Session,
3899) -> Result<()> {
3900    let message_ids = request
3901        .message_ids
3902        .iter()
3903        .map(|id| MessageId::from_proto(*id))
3904        .collect::<Vec<_>>();
3905    let messages = session
3906        .db()
3907        .await
3908        .get_channel_messages_by_id(session.user_id(), &message_ids)
3909        .await?;
3910    response.send(proto::GetChannelMessagesResponse {
3911        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3912        messages,
3913    })?;
3914    Ok(())
3915}
3916
3917/// Retrieve the current users notifications
3918async fn get_notifications(
3919    request: proto::GetNotifications,
3920    response: Response<proto::GetNotifications>,
3921    session: Session,
3922) -> Result<()> {
3923    let notifications = session
3924        .db()
3925        .await
3926        .get_notifications(
3927            session.user_id(),
3928            NOTIFICATION_COUNT_PER_PAGE,
3929            request.before_id.map(db::NotificationId::from_proto),
3930        )
3931        .await?;
3932    response.send(proto::GetNotificationsResponse {
3933        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3934        notifications,
3935    })?;
3936    Ok(())
3937}
3938
3939/// Mark notifications as read
3940async fn mark_notification_as_read(
3941    request: proto::MarkNotificationRead,
3942    response: Response<proto::MarkNotificationRead>,
3943    session: Session,
3944) -> Result<()> {
3945    let database = &session.db().await;
3946    let notifications = database
3947        .mark_notification_as_read_by_id(
3948            session.user_id(),
3949            NotificationId::from_proto(request.notification_id),
3950        )
3951        .await?;
3952    send_notifications(
3953        &*session.connection_pool().await,
3954        &session.peer,
3955        notifications,
3956    );
3957    response.send(proto::Ack {})?;
3958    Ok(())
3959}
3960
3961/// Get the current users information
3962async fn get_private_user_info(
3963    _request: proto::GetPrivateUserInfo,
3964    response: Response<proto::GetPrivateUserInfo>,
3965    session: Session,
3966) -> Result<()> {
3967    let db = session.db().await;
3968
3969    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3970    let user = db
3971        .get_user_by_id(session.user_id())
3972        .await?
3973        .context("user not found")?;
3974    let flags = db.get_user_flags(session.user_id()).await?;
3975
3976    response.send(proto::GetPrivateUserInfoResponse {
3977        metrics_id,
3978        staff: user.admin,
3979        flags,
3980        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3981    })?;
3982    Ok(())
3983}
3984
3985/// Accept the terms of service (tos) on behalf of the current user
3986async fn accept_terms_of_service(
3987    _request: proto::AcceptTermsOfService,
3988    response: Response<proto::AcceptTermsOfService>,
3989    session: Session,
3990) -> Result<()> {
3991    let db = session.db().await;
3992
3993    let accepted_tos_at = Utc::now();
3994    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3995        .await?;
3996
3997    response.send(proto::AcceptTermsOfServiceResponse {
3998        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3999    })?;
4000    Ok(())
4001}
4002
/// The minimum account age an account must have in order to use the LLM service.
///
/// NOTE(review): `get_llm_api_token` does not reference this constant; confirm
/// where (or whether) this limit is enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4005
/// Issue an LLM API token for the current user.
///
/// Requires:
/// - the user to exist and to have accepted the terms of service;
/// - both the Stripe client and Stripe billing to be configured.
///
/// Side effects when data is missing:
/// - with no billing customer, a Stripe customer is found or created by email
///   and a billing customer is found or created for it;
/// - with no active billing subscription, the customer is subscribed to Zed
///   Free in Stripe and that subscription is recorded in the database.
///
/// The token is built by `LlmTokenClaims::create` from:
/// - the user and staff status;
/// - billing preferences and feature flags;
/// - the subscription and the session's system id;
/// - the server config.
///
/// NOTE(review): two concurrent calls for a user without an active subscription
/// could both reach `subscribe_to_zed_free`; confirm this is guarded elsewhere.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Lazily provision a billing customer the first time a token is requested.
    let billing_customer =
        if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
            billing_customer
        } else {
            let customer_id = stripe_billing
                .find_or_create_customer_by_email(user.email_address.as_deref())
                .await?;

            find_or_create_billing_customer(
                &session.app_state,
                &stripe_client,
                stripe::Expandable::Id(customer_id),
            )
            .await?
            .context("billing customer not found")?
        };

    // Users without an active subscription are put on Zed Free, both in Stripe
    // and in the local database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id = billing_customer
                .stripe_customer_id
                .parse::<stripe::CustomerId>()
                .context("failed to parse Stripe customer ID from database")?;

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4093
4094fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4095    let message = match message {
4096        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4097        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4098        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4099        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4100        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4101            code: frame.code.into(),
4102            reason: frame.reason.as_str().to_owned().into(),
4103        })),
4104        // We should never receive a frame while reading the message, according
4105        // to the `tungstenite` maintainers:
4106        //
4107        // > It cannot occur when you read messages from the WebSocket, but it
4108        // > can be used when you want to send the raw frames (e.g. you want to
4109        // > send the frames to the WebSocket without composing the full message first).
4110        // >
4111        // > — https://github.com/snapview/tungstenite-rs/issues/268
4112        TungsteniteMessage::Frame(_) => {
4113            bail!("received an unexpected frame while reading the message")
4114        }
4115    };
4116
4117    Ok(message)
4118}
4119
4120fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4121    match message {
4122        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4123        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4124        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4125        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4126        AxumMessage::Close(frame) => {
4127            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4128                code: frame.code.into(),
4129                reason: frame.reason.as_ref().into(),
4130            }))
4131        }
4132    }
4133}
4134
4135fn notify_membership_updated(
4136    connection_pool: &mut ConnectionPool,
4137    result: MembershipUpdated,
4138    user_id: UserId,
4139    peer: &Peer,
4140) {
4141    for membership in &result.new_channels.channel_memberships {
4142        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4143    }
4144    for channel_id in &result.removed_channels {
4145        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4146    }
4147
4148    let user_channels_update = proto::UpdateUserChannels {
4149        channel_memberships: result
4150            .new_channels
4151            .channel_memberships
4152            .iter()
4153            .map(|cm| proto::ChannelMembership {
4154                channel_id: cm.channel_id.to_proto(),
4155                role: cm.role.into(),
4156            })
4157            .collect(),
4158        ..Default::default()
4159    };
4160
4161    let mut update = build_channels_update(result.new_channels);
4162    update.delete_channels = result
4163        .removed_channels
4164        .into_iter()
4165        .map(|id| id.to_proto())
4166        .collect();
4167    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4168
4169    for connection_id in connection_pool.user_connection_ids(user_id) {
4170        peer.send(connection_id, user_channels_update.clone())
4171            .trace_err();
4172        peer.send(connection_id, update.clone()).trace_err();
4173    }
4174}
4175
4176fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4177    proto::UpdateUserChannels {
4178        channel_memberships: channels
4179            .channel_memberships
4180            .iter()
4181            .map(|m| proto::ChannelMembership {
4182                channel_id: m.channel_id.to_proto(),
4183                role: m.role.into(),
4184            })
4185            .collect(),
4186        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4187        observed_channel_message_id: channels.observed_channel_messages.clone(),
4188    }
4189}
4190
4191fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4192    let mut update = proto::UpdateChannels::default();
4193
4194    for channel in channels.channels {
4195        update.channels.push(channel.to_proto());
4196    }
4197
4198    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4199    update.latest_channel_message_ids = channels.latest_channel_messages;
4200
4201    for (channel_id, participants) in channels.channel_participants {
4202        update
4203            .channel_participants
4204            .push(proto::ChannelParticipants {
4205                channel_id: channel_id.to_proto(),
4206                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4207            });
4208    }
4209
4210    for channel in channels.invited_channels {
4211        update.channel_invitations.push(channel.to_proto());
4212    }
4213
4214    update
4215}
4216
4217fn build_initial_contacts_update(
4218    contacts: Vec<db::Contact>,
4219    pool: &ConnectionPool,
4220) -> proto::UpdateContacts {
4221    let mut update = proto::UpdateContacts::default();
4222
4223    for contact in contacts {
4224        match contact {
4225            db::Contact::Accepted { user_id, busy } => {
4226                update.contacts.push(contact_for_user(user_id, busy, pool));
4227            }
4228            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4229            db::Contact::Incoming { user_id } => {
4230                update
4231                    .incoming_requests
4232                    .push(proto::IncomingContactRequest {
4233                        requester_id: user_id.to_proto(),
4234                    })
4235            }
4236        }
4237    }
4238
4239    update
4240}
4241
4242fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4243    proto::Contact {
4244        user_id: user_id.to_proto(),
4245        online: pool.is_user_online(user_id),
4246        busy,
4247    }
4248}
4249
4250fn room_updated(room: &proto::Room, peer: &Peer) {
4251    broadcast(
4252        None,
4253        room.participants
4254            .iter()
4255            .filter_map(|participant| Some(participant.peer_id?.into())),
4256        |peer_id| {
4257            peer.send(
4258                peer_id,
4259                proto::RoomUpdated {
4260                    room: Some(room.clone()),
4261                },
4262            )
4263        },
4264    );
4265}
4266
4267fn channel_updated(
4268    channel: &db::channel::Model,
4269    room: &proto::Room,
4270    peer: &Peer,
4271    pool: &ConnectionPool,
4272) {
4273    let participants = room
4274        .participants
4275        .iter()
4276        .map(|p| p.user_id)
4277        .collect::<Vec<_>>();
4278
4279    broadcast(
4280        None,
4281        pool.channel_connection_ids(channel.root_id())
4282            .filter_map(|(channel_id, role)| {
4283                role.can_see_channel(channel.visibility)
4284                    .then_some(channel_id)
4285            }),
4286        |peer_id| {
4287            peer.send(
4288                peer_id,
4289                proto::UpdateChannels {
4290                    channel_participants: vec![proto::ChannelParticipants {
4291                        channel_id: channel.id.to_proto(),
4292                        participant_user_ids: participants.clone(),
4293                    }],
4294                    ..Default::default()
4295                },
4296            )
4297        },
4298    );
4299}
4300
4301async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4302    let db = session.db().await;
4303
4304    let contacts = db.get_contacts(user_id).await?;
4305    let busy = db.is_user_busy(user_id).await?;
4306
4307    let pool = session.connection_pool().await;
4308    let updated_contact = contact_for_user(user_id, busy, &pool);
4309    for contact in contacts {
4310        if let db::Contact::Accepted {
4311            user_id: contact_user_id,
4312            ..
4313        } = contact
4314        {
4315            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4316                session
4317                    .peer
4318                    .send(
4319                        contact_conn_id,
4320                        proto::UpdateContacts {
4321                            contacts: vec![updated_contact.clone()],
4322                            remove_contacts: Default::default(),
4323                            incoming_requests: Default::default(),
4324                            remove_incoming_requests: Default::default(),
4325                            outgoing_requests: Default::default(),
4326                            remove_outgoing_requests: Default::default(),
4327                        },
4328                    )
4329                    .trace_err();
4330            }
4331        }
4332    }
4333    Ok(())
4334}
4335
/// Removes `connection_id` from the room it is in and propagates the departure.
///
/// If the database reports that the connection left a room, this function:
/// - notifies collaborators of every project the connection left;
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's new participant list;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the leaving user and those callees;
/// - removes the user from the LiveKit room, deleting that room when the
///   database reported the room as deleted.
///
/// Returns `Ok(())` without doing anything else when `leave_room` yields `None`.
///
/// # Errors
///
/// Returns errors from `leave_room` or `update_user_contacts`. LiveKit and
/// peer-send failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below; the `else`
    // branch returns, so they are always initialized past that point.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room itself
        // is taken, so the `room` broadcast below carries that field at its
        // default value. NOTE(review): confirm clients do not rely on it here.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Only rooms that belong to a channel need a channel participant update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool guard is scoped to this block, so it is released before
    // `update_user_contacts` below acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` returns on the first failure, which also skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged via `trace_err` and not
    // propagated. The participant identity used is the user id as a string.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4409
4410async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4411    let left_channel_buffers = session
4412        .db()
4413        .await
4414        .leave_channel_buffers(session.connection_id)
4415        .await?;
4416
4417    for left_buffer in left_channel_buffers {
4418        channel_buffer_updated(
4419            session.connection_id,
4420            left_buffer.connections,
4421            &proto::UpdateChannelBufferCollaborators {
4422                channel_id: left_buffer.channel_id.to_proto(),
4423                collaborators: left_buffer.collaborators,
4424            },
4425            &session.peer,
4426        );
4427    }
4428
4429    Ok(())
4430}
4431
4432fn project_left(project: &db::LeftProject, session: &Session) {
4433    for connection_id in &project.connection_ids {
4434        if project.should_unshare {
4435            session
4436                .peer
4437                .send(
4438                    *connection_id,
4439                    proto::UnshareProject {
4440                        project_id: project.id.to_proto(),
4441                    },
4442                )
4443                .trace_err();
4444        } else {
4445            session
4446                .peer
4447                .send(
4448                    *connection_id,
4449                    proto::RemoveProjectCollaborator {
4450                        project_id: project.id.to_proto(),
4451                        peer_id: Some(session.connection_id.into()),
4452                    },
4453                )
4454                .trace_err();
4455        }
4456    }
4457}
4458
4459pub trait ResultExt {
4460    type Ok;
4461
4462    fn trace_err(self) -> Option<Self::Ok>;
4463}
4464
4465impl<T, E> ResultExt for Result<T, E>
4466where
4467    E: std::fmt::Debug,
4468{
4469    type Ok = T;
4470
4471    #[track_caller]
4472    fn trace_err(self) -> Option<T> {
4473        match self {
4474            Ok(value) => Some(value),
4475            Err(error) => {
4476                tracing::error!("{:?}", error);
4477                None
4478            }
4479        }
4480    }
4481}