rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, Database,
   7        MembershipUpdated, MessageId, MoveChannelResult, ProjectId, RenameChannelResult, RoomId,
   8        ServerId, SetChannelVisibilityResult, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{info_span, instrument, Instrument};
  67use util::channel::RELEASE_CHANNEL_NAME;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74
// Process-wide Prometheus gauges, registered lazily on first access. The `.unwrap()`s mean
// the first access panics if registration fails (e.g. a metric with the same name is
// already registered in the default registry).
lazy_static! {
    /// Gauge registered as `connections` ("number of connections").
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge registered as `shared_projects` ("number of open projects with one or more guests").
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
  88struct Response<R> {
  89    peer: Arc<Peer>,
  90    receipt: Receipt<R>,
  91    responded: Arc<AtomicBool>,
  92}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
 102#[derive(Clone)]
 103struct Session {
 104    user_id: UserId,
 105    connection_id: ConnectionId,
 106    db: Arc<tokio::sync::Mutex<DbHandle>>,
 107    peer: Arc<Peer>,
 108    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 109    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 110    executor: Executor,
 111}
 112
 113impl Session {
 114    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.db.lock().await;
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        guard
 121    }
 122
 123    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 124        #[cfg(test)]
 125        tokio::task::yield_now().await;
 126        let guard = self.connection_pool.lock();
 127        ConnectionPoolGuard {
 128            guard,
 129            _not_send: PhantomData,
 130        }
 131    }
 132}
 133
 134impl fmt::Debug for Session {
 135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 136        f.debug_struct("Session")
 137            .field("user_id", &self.user_id)
 138            .field("connection_id", &self.connection_id)
 139            .finish()
 140    }
 141}
 142
 143struct DbHandle(Arc<Database>);
 144
 145impl Deref for DbHandle {
 146    type Target = Database;
 147
 148    fn deref(&self) -> &Self::Target {
 149        self.0.as_ref()
 150    }
 151}
 152
/// The collaboration RPC server.
///
/// Owns the [`Peer`], the pool of live connections, the table of registered message handlers,
/// and the teardown notifier that connections subscribe to.
pub struct Server {
    /// This server's id. Mutable behind a lock; the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections. Shared with every `Session` and with the cleanup task in `start`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Fired by `teardown()`. Each connection loop subscribes and exits when it changes.
    teardown: watch::Sender<()>,
}
 162
/// Lock guard over the shared [`ConnectionPool`].
///
/// `_not_send` (a `PhantomData<Rc<()>>`) makes this type `!Send` and `!Sync`. Presumably this
/// keeps the guard from being held across an `.await` inside a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 167
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock (through `ConnectionPoolGuard`) for as long as
/// it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 174
 175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 176where
 177    S: Serializer,
 178    T: Deref<Target = U>,
 179    U: Serialize,
 180{
 181    Serialize::serialize(value.deref(), serializer)
 182}
 183
 184impl Server {
 185    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 186        let mut server = Self {
 187            id: parking_lot::Mutex::new(id),
 188            peer: Peer::new(id.0 as u32),
 189            app_state,
 190            executor,
 191            connection_pool: Default::default(),
 192            handlers: Default::default(),
 193            teardown: watch::channel(()).0,
 194        };
 195
 196        server
 197            .add_request_handler(ping)
 198            .add_request_handler(create_room)
 199            .add_request_handler(join_room)
 200            .add_request_handler(rejoin_room)
 201            .add_request_handler(leave_room)
 202            .add_request_handler(call)
 203            .add_request_handler(cancel_call)
 204            .add_message_handler(decline_call)
 205            .add_request_handler(update_participant_location)
 206            .add_request_handler(share_project)
 207            .add_message_handler(unshare_project)
 208            .add_request_handler(join_project)
 209            .add_message_handler(leave_project)
 210            .add_request_handler(update_project)
 211            .add_request_handler(update_worktree)
 212            .add_message_handler(start_language_server)
 213            .add_message_handler(update_language_server)
 214            .add_message_handler(update_diagnostic_summary)
 215            .add_message_handler(update_worktree_settings)
 216            .add_message_handler(refresh_inlay_hints)
 217            .add_request_handler(forward_project_request::<proto::GetHover>)
 218            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 220            .add_request_handler(forward_project_request::<proto::GetReferences>)
 221            .add_request_handler(forward_project_request::<proto::SearchProject>)
 222            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 223            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 226            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 227            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 228            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 229            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 230            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 231            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 232            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 233            .add_request_handler(forward_project_request::<proto::PerformRename>)
 234            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 235            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 236            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 237            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 242            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 243            .add_request_handler(forward_project_request::<proto::InlayHints>)
 244            .add_message_handler(create_buffer_for_peer)
 245            .add_request_handler(update_buffer)
 246            .add_message_handler(update_buffer_file)
 247            .add_message_handler(buffer_reloaded)
 248            .add_message_handler(buffer_saved)
 249            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 250            .add_request_handler(get_users)
 251            .add_request_handler(fuzzy_search_users)
 252            .add_request_handler(request_contact)
 253            .add_request_handler(remove_contact)
 254            .add_request_handler(respond_to_contact_request)
 255            .add_request_handler(create_channel)
 256            .add_request_handler(delete_channel)
 257            .add_request_handler(invite_channel_member)
 258            .add_request_handler(remove_channel_member)
 259            .add_request_handler(set_channel_member_role)
 260            .add_request_handler(set_channel_visibility)
 261            .add_request_handler(rename_channel)
 262            .add_request_handler(join_channel_buffer)
 263            .add_request_handler(leave_channel_buffer)
 264            .add_message_handler(update_channel_buffer)
 265            .add_request_handler(rejoin_channel_buffers)
 266            .add_request_handler(get_channel_members)
 267            .add_request_handler(respond_to_channel_invite)
 268            .add_request_handler(join_channel)
 269            .add_request_handler(join_channel_chat)
 270            .add_message_handler(leave_channel_chat)
 271            .add_request_handler(send_channel_message)
 272            .add_request_handler(remove_channel_message)
 273            .add_request_handler(get_channel_messages)
 274            .add_request_handler(link_channel)
 275            .add_request_handler(unlink_channel)
 276            .add_request_handler(move_channel)
 277            .add_request_handler(follow)
 278            .add_message_handler(unfollow)
 279            .add_message_handler(update_followers)
 280            .add_message_handler(update_diff_base)
 281            .add_request_handler(get_private_user_info)
 282            .add_message_handler(acknowledge_channel_message)
 283            .add_message_handler(acknowledge_buffer_version);
 284
 285        Arc::new(server)
 286    }
 287
    /// Spawns a detached background task that cleans up stale server resources.
    ///
    /// The resources are identified by `stale_server_resource_ids` for this server's
    /// environment and id. NOTE(review): presumably these are leftovers of other, defunct
    /// servers; confirm against the db layer.
    ///
    /// The task runs inside a `"start server"` span and:
    /// 1. waits for the `CLEANUP_TIMEOUT` sleep;
    /// 2. fetches the stale room ids and channel ids;
    /// 3. for each stale channel buffer, clears stale collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the buffer's remaining connections;
    /// 4. for each stale room, clears stale participants and then:
    ///    - broadcasts the refreshed room, plus its channel if it has one;
    ///    - sends `CallCanceled` to every connection of users whose calls were canceled;
    ///    - pushes an updated contact entry for each affected user to that user's accepted
    ///      contacts;
    ///    - deletes the LiveKit room when no participants remain and a LiveKit client is
    ///      configured;
    /// 5. always calls `delete_stale_servers` at the end.
    ///
    /// Database and send failures inside the task go through `trace_err()`. They are discarded
    /// rather than aborting the task. This method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Remove stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room when the db call succeeds, and acted
                        // on below. They stay empty/false when it fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were canceled
                            // need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block, so it is released before the
                        // `.await`s in the contact-update loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 439
 440    pub fn teardown(&self) {
 441        self.peer.teardown();
 442        self.connection_pool.lock().reset();
 443        let _ = self.teardown.send(());
 444    }
 445
 446    #[cfg(test)]
 447    pub fn reset(&self, id: ServerId) {
 448        self.teardown();
 449        *self.id.lock() = id;
 450        self.peer.reset(id.0 as u32);
 451    }
 452
 453    #[cfg(test)]
 454    pub fn id(&self) -> ServerId {
 455        *self.id.lock()
 456    }
 457
 458    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 459    where
 460        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 461        Fut: 'static + Send + Future<Output = Result<()>>,
 462        M: EnvelopedMessage,
 463    {
 464        let prev_handler = self.handlers.insert(
 465            TypeId::of::<M>(),
 466            Box::new(move |envelope, session| {
 467                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 468                let span = info_span!(
 469                    "handle message",
 470                    payload_type = envelope.payload_type_name()
 471                );
 472                span.in_scope(|| {
 473                    tracing::info!(
 474                        payload_type = envelope.payload_type_name(),
 475                        "message received"
 476                    );
 477                });
 478                let start_time = Instant::now();
 479                let future = (handler)(*envelope, session);
 480                async move {
 481                    let result = future.await;
 482                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 483                    match result {
 484                        Err(error) => {
 485                            tracing::error!(%error, ?duration_ms, "error handling message")
 486                        }
 487                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 488                    }
 489                }
 490                .instrument(span)
 491                .boxed()
 492            }),
 493        );
 494        if prev_handler.is_some() {
 495            panic!("registered a handler for the same message twice");
 496        }
 497        self
 498    }
 499
 500    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 501    where
 502        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 503        Fut: 'static + Send + Future<Output = Result<()>>,
 504        M: EnvelopedMessage,
 505    {
 506        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 507        self
 508    }
 509
 510    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 511    where
 512        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 513        Fut: Send + Future<Output = Result<()>>,
 514        M: RequestMessage,
 515    {
 516        let handler = Arc::new(handler);
 517        self.add_handler(move |envelope, session| {
 518            let receipt = envelope.receipt();
 519            let handler = handler.clone();
 520            async move {
 521                let peer = session.peer.clone();
 522                let responded = Arc::new(AtomicBool::default());
 523                let response = Response {
 524                    peer: peer.clone(),
 525                    responded: responded.clone(),
 526                    receipt,
 527                };
 528                match (handler)(envelope.payload, response, session).await {
 529                    Ok(()) => {
 530                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 531                            Ok(())
 532                        } else {
 533                            Err(anyhow!("handler did not send a response"))?
 534                        }
 535                    }
 536                    Err(error) => {
 537                        peer.respond_with_error(
 538                            receipt,
 539                            proto::Error {
 540                                message: error.to_string(),
 541                            },
 542                        )?;
 543                        Err(error)
 544                    }
 545                }
 546            }
 547        })
 548    }
 549
    /// Returns a future that drives one client connection from handshake until it ends.
    ///
    /// When polled, the future:
    /// 1. registers `connection` with the peer and sends `Hello` carrying the new connection
    ///    id, reporting that id through `send_connection_id` if one was provided;
    /// 2. on the user's first-ever connection, sends `ShowContacts` and records
    ///    `connected_once`;
    /// 3. loads contacts, channels and channel invites concurrently;
    /// 4. adds the connection to the pool and sends the initial contacts and channels updates,
    ///    then any pending incoming call;
    /// 5. refreshes this user's contact status (`update_user_contacts`);
    /// 6. dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes, the incoming stream ends, or the server is torn down.
    ///
    /// On teardown it returns `Ok(())` immediately, without running `connection_lost`.
    /// Otherwise it runs `connection_lost` before returning `Ok(())`, logging any error from it.
    /// At most 256 messages are being handled at once per connection (see the semaphore below).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // NOTE(review): from here until the message loop, every `?` returns early without
            // removing this connection from the pool or running `connection_lost`. Confirm
            // whether cleanup happens elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Foreground message handlers are pushed into the following `FuturesUnordered`
            // instead of being awaited one at a time. This prevents deadlocks such as: client A
            // sends a request to client B while client B sends a request to client A. If each
            // side stopped processing messages until its own request completed, neither would
            // get to answer the other's request.
            //
            // Earlier messages are still attempted first, but the loop falls back to processing
            // later messages so that it keeps making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps how many messages may be in flight at once on this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken *before* reading the next message. It is released when that
                // message's handler finishes, or immediately if no handler exists.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Polled in order: teardown, I/O completion, progress on in-flight foreground
                // handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor. Foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 684
 685    pub async fn invite_code_redeemed(
 686        self: &Arc<Self>,
 687        inviter_id: UserId,
 688        invitee_id: UserId,
 689    ) -> Result<()> {
 690        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 691            if let Some(code) = &user.invite_code {
 692                let pool = self.connection_pool.lock();
 693                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 694                for connection_id in pool.user_connection_ids(inviter_id) {
 695                    self.peer.send(
 696                        connection_id,
 697                        proto::UpdateContacts {
 698                            contacts: vec![invitee_contact.clone()],
 699                            ..Default::default()
 700                        },
 701                    )?;
 702                    self.peer.send(
 703                        connection_id,
 704                        proto::UpdateInviteInfo {
 705                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 706                            count: user.invite_count as u32,
 707                        },
 708                    )?;
 709                }
 710            }
 711        }
 712        Ok(())
 713    }
 714
 715    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 716        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 717            if let Some(invite_code) = &user.invite_code {
 718                let pool = self.connection_pool.lock();
 719                for connection_id in pool.user_connection_ids(user_id) {
 720                    self.peer.send(
 721                        connection_id,
 722                        proto::UpdateInviteInfo {
 723                            url: format!(
 724                                "{}{}",
 725                                self.app_state.config.invite_link_prefix, invite_code
 726                            ),
 727                            count: user.invite_count as u32,
 728                        },
 729                    )?;
 730                }
 731            }
 732        }
 733        Ok(())
 734    }
 735
 736    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 737        ServerSnapshot {
 738            connection_pool: ConnectionPoolGuard {
 739                guard: self.connection_pool.lock(),
 740                _not_send: PhantomData,
 741            },
 742            peer: &self.peer,
 743        }
 744    }
 745}
 746
 747impl<'a> Deref for ConnectionPoolGuard<'a> {
 748    type Target = ConnectionPool;
 749
 750    fn deref(&self) -> &Self::Target {
 751        &*self.guard
 752    }
 753}
 754
 755impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 756    fn deref_mut(&mut self) -> &mut Self::Target {
 757        &mut *self.guard
 758    }
 759}
 760
 761impl<'a> Drop for ConnectionPoolGuard<'a> {
 762    fn drop(&mut self) {
 763        #[cfg(test)]
 764        self.check_invariants();
 765    }
 766}
 767
 768fn broadcast<F>(
 769    sender_id: Option<ConnectionId>,
 770    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 771    mut f: F,
 772) where
 773    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 774{
 775    for receiver_id in receiver_ids {
 776        if Some(receiver_id) != sender_id {
 777            if let Err(error) = f(receiver_id) {
 778                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 779            }
 780        }
 781    }
 782}
 783
 784lazy_static! {
 785    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 786}
 787
 788pub struct ProtocolVersion(u32);
 789
 790impl Header for ProtocolVersion {
 791    fn name() -> &'static HeaderName {
 792        &ZED_PROTOCOL_VERSION
 793    }
 794
 795    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 796    where
 797        Self: Sized,
 798        I: Iterator<Item = &'i axum::http::HeaderValue>,
 799    {
 800        let version = values
 801            .next()
 802            .ok_or_else(axum::headers::Error::invalid)?
 803            .to_str()
 804            .map_err(|_| axum::headers::Error::invalid())?
 805            .parse()
 806            .map_err(|_| axum::headers::Error::invalid())?;
 807        Ok(Self(version))
 808    }
 809
 810    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 811        values.extend([self.0.to_string().parse().unwrap()]);
 812    }
 813}
 814
/// Builds the collab server's HTTP router.
///
/// - `GET /rpc` upgrades to the RPC websocket (`handle_websocket_request`).
///   It is wrapped by a `ServiceBuilder` that inserts the `AppState` as an
///   extension and runs `auth::validate_header`.
/// - `GET /metrics` serves Prometheus metrics (`handle_metrics`). It is added
///   after the auth layer.
///
/// The final `Extension(server)` layer makes the `Server` available to both
/// handlers.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // NOTE(review): relies on axum's `Router::layer` only wrapping routes
        // added before it, which keeps `/metrics` outside the auth middleware.
        // Keep this call between the two `.route` calls.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 826
/// Upgrades an `/rpc` request to a websocket and hands the resulting
/// connection to `Server::handle_connection`.
///
/// Clients whose `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION` get `426 Upgrade Required` ("client must be
/// upgraded") instead of an upgrade. The `User` extension is presumably
/// inserted by the auth middleware in `routes`; verify against
/// `auth::validate_header`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // `rpc::Connection` works with tungstenite messages, while axum's
        // socket uses its own message type. Convert incoming messages to
        // tungstenite and outgoing messages back to axum.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Errors from the connection handler are logged, not returned.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 857
 858pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 859    let connections = server
 860        .connection_pool
 861        .lock()
 862        .connections()
 863        .filter(|connection| !connection.admin)
 864        .count();
 865
 866    METRIC_CONNECTIONS.set(connections as _);
 867
 868    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 869    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 870
 871    let encoder = prometheus::TextEncoder::new();
 872    let metric_families = prometheus::gather();
 873    let encoded_metrics = encoder
 874        .encode_to_string(&metric_families)
 875        .map_err(|err| anyhow!("{}", err))?;
 876    Ok(encoded_metrics)
 877}
 878
 879#[instrument(err, skip(executor))]
 880async fn connection_lost(
 881    session: Session,
 882    mut teardown: watch::Receiver<()>,
 883    executor: Executor,
 884) -> Result<()> {
 885    session.peer.disconnect(session.connection_id);
 886    session
 887        .connection_pool()
 888        .await
 889        .remove_connection(session.connection_id)?;
 890
 891    session
 892        .db()
 893        .await
 894        .connection_lost(session.connection_id)
 895        .await
 896        .trace_err();
 897
 898    futures::select_biased! {
 899        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 900            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 901            leave_room_for_session(&session).await.trace_err();
 902            leave_channel_buffers_for_session(&session)
 903                .await
 904                .trace_err();
 905
 906            if !session
 907                .connection_pool()
 908                .await
 909                .is_user_online(session.user_id)
 910            {
 911                let db = session.db().await;
 912                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 913                    room_updated(&room, &session.peer);
 914                }
 915            }
 916
 917            update_user_contacts(session.user_id, &session).await?;
 918        }
 919        _ = teardown.changed().fuse() => {}
 920    }
 921
 922    Ok(())
 923}
 924
 925async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 926    response.send(proto::Ack {})?;
 927    Ok(())
 928}
 929
 930async fn create_room(
 931    _request: proto::CreateRoom,
 932    response: Response<proto::CreateRoom>,
 933    session: Session,
 934) -> Result<()> {
 935    let live_kit_room = nanoid::nanoid!(30);
 936
 937    let live_kit_connection_info = {
 938        let live_kit_room = live_kit_room.clone();
 939        let live_kit = session.live_kit_client.as_ref();
 940
 941        util::async_iife!({
 942            let live_kit = live_kit?;
 943
 944            let token = live_kit
 945                .room_token(&live_kit_room, &session.user_id.to_string())
 946                .trace_err()?;
 947
 948            Some(proto::LiveKitConnectionInfo {
 949                server_url: live_kit.url().into(),
 950                token,
 951                can_publish: true,
 952            })
 953        })
 954    }
 955    .await;
 956
 957    let room = session
 958        .db()
 959        .await
 960        .create_room(
 961            session.user_id,
 962            session.connection_id,
 963            &live_kit_room,
 964            RELEASE_CHANNEL_NAME.as_str(),
 965        )
 966        .await?;
 967
 968    response.send(proto::CreateRoomResponse {
 969        room: Some(room.clone()),
 970        live_kit_connection_info,
 971    })?;
 972
 973    update_user_contacts(session.user_id, &session).await?;
 974    Ok(())
 975}
 976
 977async fn join_room(
 978    request: proto::JoinRoom,
 979    response: Response<proto::JoinRoom>,
 980    session: Session,
 981) -> Result<()> {
 982    let room_id = RoomId::from_proto(request.id);
 983
 984    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 985
 986    if let Some(channel_id) = channel_id {
 987        return join_channel_internal(channel_id, Box::new(response), session).await;
 988    }
 989
 990    let joined_room = {
 991        let room = session
 992            .db()
 993            .await
 994            .join_room(
 995                room_id,
 996                session.user_id,
 997                session.connection_id,
 998                RELEASE_CHANNEL_NAME.as_str(),
 999            )
1000            .await?;
1001        room_updated(&room.room, &session.peer);
1002        room.into_inner()
1003    };
1004
1005    for connection_id in session
1006        .connection_pool()
1007        .await
1008        .user_connection_ids(session.user_id)
1009    {
1010        session
1011            .peer
1012            .send(
1013                connection_id,
1014                proto::CallCanceled {
1015                    room_id: room_id.to_proto(),
1016                },
1017            )
1018            .trace_err();
1019    }
1020
1021    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1022        if let Some(token) = live_kit
1023            .room_token(
1024                &joined_room.room.live_kit_room,
1025                &session.user_id.to_string(),
1026            )
1027            .trace_err()
1028        {
1029            Some(proto::LiveKitConnectionInfo {
1030                server_url: live_kit.url().into(),
1031                token,
1032                can_publish: true,
1033            })
1034        } else {
1035            None
1036        }
1037    } else {
1038        None
1039    };
1040
1041    response.send(proto::JoinRoomResponse {
1042        room: Some(joined_room.room),
1043        channel_id: None,
1044        live_kit_connection_info,
1045    })?;
1046
1047    update_user_contacts(session.user_id, &session).await?;
1048    Ok(())
1049}
1050
1051async fn rejoin_room(
1052    request: proto::RejoinRoom,
1053    response: Response<proto::RejoinRoom>,
1054    session: Session,
1055) -> Result<()> {
1056    let room;
1057    let channel_id;
1058    let channel_members;
1059    {
1060        let mut rejoined_room = session
1061            .db()
1062            .await
1063            .rejoin_room(request, session.user_id, session.connection_id)
1064            .await?;
1065
1066        response.send(proto::RejoinRoomResponse {
1067            room: Some(rejoined_room.room.clone()),
1068            reshared_projects: rejoined_room
1069                .reshared_projects
1070                .iter()
1071                .map(|project| proto::ResharedProject {
1072                    id: project.id.to_proto(),
1073                    collaborators: project
1074                        .collaborators
1075                        .iter()
1076                        .map(|collaborator| collaborator.to_proto())
1077                        .collect(),
1078                })
1079                .collect(),
1080            rejoined_projects: rejoined_room
1081                .rejoined_projects
1082                .iter()
1083                .map(|rejoined_project| proto::RejoinedProject {
1084                    id: rejoined_project.id.to_proto(),
1085                    worktrees: rejoined_project
1086                        .worktrees
1087                        .iter()
1088                        .map(|worktree| proto::WorktreeMetadata {
1089                            id: worktree.id,
1090                            root_name: worktree.root_name.clone(),
1091                            visible: worktree.visible,
1092                            abs_path: worktree.abs_path.clone(),
1093                        })
1094                        .collect(),
1095                    collaborators: rejoined_project
1096                        .collaborators
1097                        .iter()
1098                        .map(|collaborator| collaborator.to_proto())
1099                        .collect(),
1100                    language_servers: rejoined_project.language_servers.clone(),
1101                })
1102                .collect(),
1103        })?;
1104        room_updated(&rejoined_room.room, &session.peer);
1105
1106        for project in &rejoined_room.reshared_projects {
1107            for collaborator in &project.collaborators {
1108                session
1109                    .peer
1110                    .send(
1111                        collaborator.connection_id,
1112                        proto::UpdateProjectCollaborator {
1113                            project_id: project.id.to_proto(),
1114                            old_peer_id: Some(project.old_connection_id.into()),
1115                            new_peer_id: Some(session.connection_id.into()),
1116                        },
1117                    )
1118                    .trace_err();
1119            }
1120
1121            broadcast(
1122                Some(session.connection_id),
1123                project
1124                    .collaborators
1125                    .iter()
1126                    .map(|collaborator| collaborator.connection_id),
1127                |connection_id| {
1128                    session.peer.forward_send(
1129                        session.connection_id,
1130                        connection_id,
1131                        proto::UpdateProject {
1132                            project_id: project.id.to_proto(),
1133                            worktrees: project.worktrees.clone(),
1134                        },
1135                    )
1136                },
1137            );
1138        }
1139
1140        for project in &rejoined_room.rejoined_projects {
1141            for collaborator in &project.collaborators {
1142                session
1143                    .peer
1144                    .send(
1145                        collaborator.connection_id,
1146                        proto::UpdateProjectCollaborator {
1147                            project_id: project.id.to_proto(),
1148                            old_peer_id: Some(project.old_connection_id.into()),
1149                            new_peer_id: Some(session.connection_id.into()),
1150                        },
1151                    )
1152                    .trace_err();
1153            }
1154        }
1155
1156        for project in &mut rejoined_room.rejoined_projects {
1157            for worktree in mem::take(&mut project.worktrees) {
1158                #[cfg(any(test, feature = "test-support"))]
1159                const MAX_CHUNK_SIZE: usize = 2;
1160                #[cfg(not(any(test, feature = "test-support")))]
1161                const MAX_CHUNK_SIZE: usize = 256;
1162
1163                // Stream this worktree's entries.
1164                let message = proto::UpdateWorktree {
1165                    project_id: project.id.to_proto(),
1166                    worktree_id: worktree.id,
1167                    abs_path: worktree.abs_path.clone(),
1168                    root_name: worktree.root_name,
1169                    updated_entries: worktree.updated_entries,
1170                    removed_entries: worktree.removed_entries,
1171                    scan_id: worktree.scan_id,
1172                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1173                    updated_repositories: worktree.updated_repositories,
1174                    removed_repositories: worktree.removed_repositories,
1175                };
1176                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1177                    session.peer.send(session.connection_id, update.clone())?;
1178                }
1179
1180                // Stream this worktree's diagnostics.
1181                for summary in worktree.diagnostic_summaries {
1182                    session.peer.send(
1183                        session.connection_id,
1184                        proto::UpdateDiagnosticSummary {
1185                            project_id: project.id.to_proto(),
1186                            worktree_id: worktree.id,
1187                            summary: Some(summary),
1188                        },
1189                    )?;
1190                }
1191
1192                for settings_file in worktree.settings_files {
1193                    session.peer.send(
1194                        session.connection_id,
1195                        proto::UpdateWorktreeSettings {
1196                            project_id: project.id.to_proto(),
1197                            worktree_id: worktree.id,
1198                            path: settings_file.path,
1199                            content: Some(settings_file.content),
1200                        },
1201                    )?;
1202                }
1203            }
1204
1205            for language_server in &project.language_servers {
1206                session.peer.send(
1207                    session.connection_id,
1208                    proto::UpdateLanguageServer {
1209                        project_id: project.id.to_proto(),
1210                        language_server_id: language_server.id,
1211                        variant: Some(
1212                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1213                                proto::LspDiskBasedDiagnosticsUpdated {},
1214                            ),
1215                        ),
1216                    },
1217                )?;
1218            }
1219        }
1220
1221        let rejoined_room = rejoined_room.into_inner();
1222
1223        room = rejoined_room.room;
1224        channel_id = rejoined_room.channel_id;
1225        channel_members = rejoined_room.channel_members;
1226    }
1227
1228    if let Some(channel_id) = channel_id {
1229        channel_updated(
1230            channel_id,
1231            &room,
1232            &channel_members,
1233            &session.peer,
1234            &*session.connection_pool().await,
1235        );
1236    }
1237
1238    update_user_contacts(session.user_id, &session).await?;
1239    Ok(())
1240}
1241
1242async fn leave_room(
1243    _: proto::LeaveRoom,
1244    response: Response<proto::LeaveRoom>,
1245    session: Session,
1246) -> Result<()> {
1247    leave_room_for_session(&session).await?;
1248    response.send(proto::Ack {})?;
1249    Ok(())
1250}
1251
/// Rings `request.called_user_id` on behalf of the caller.
///
/// The callee must be one of the caller's contacts. The call is recorded in
/// the database, `room_updated` is called, and the callee's contacts are
/// refreshed. The incoming-call request from the database is then sent
/// concurrently to every connection the callee has open. The first successful
/// reply makes the caller receive an `Ack`. If no connection replies
/// successfully, the call is marked as failed, `room_updated` and the contact
/// refresh run again, and the caller gets an error.
///
/// # Errors
///
/// Fails with "cannot call a user who isn't a contact" or "failed to ring
/// user", or with any database/send error along the way.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result borrowed below is dropped before
    // `update_user_contacts` touches the database again.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the call out so it outlives the scope above.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful reply completes the call; failed replies are
    // traced and we keep waiting on the others.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: record the failure and refresh the room and contacts.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1319
1320async fn cancel_call(
1321    request: proto::CancelCall,
1322    response: Response<proto::CancelCall>,
1323    session: Session,
1324) -> Result<()> {
1325    let called_user_id = UserId::from_proto(request.called_user_id);
1326    let room_id = RoomId::from_proto(request.room_id);
1327    {
1328        let room = session
1329            .db()
1330            .await
1331            .cancel_call(room_id, session.connection_id, called_user_id)
1332            .await?;
1333        room_updated(&room, &session.peer);
1334    }
1335
1336    for connection_id in session
1337        .connection_pool()
1338        .await
1339        .user_connection_ids(called_user_id)
1340    {
1341        session
1342            .peer
1343            .send(
1344                connection_id,
1345                proto::CallCanceled {
1346                    room_id: room_id.to_proto(),
1347                },
1348            )
1349            .trace_err();
1350    }
1351    response.send(proto::Ack {})?;
1352
1353    update_user_contacts(called_user_id, &session).await?;
1354    Ok(())
1355}
1356
1357async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1358    let room_id = RoomId::from_proto(message.room_id);
1359    {
1360        let room = session
1361            .db()
1362            .await
1363            .decline_call(Some(room_id), session.user_id)
1364            .await?
1365            .ok_or_else(|| anyhow!("failed to decline call"))?;
1366        room_updated(&room, &session.peer);
1367    }
1368
1369    for connection_id in session
1370        .connection_pool()
1371        .await
1372        .user_connection_ids(session.user_id)
1373    {
1374        session
1375            .peer
1376            .send(
1377                connection_id,
1378                proto::CallCanceled {
1379                    room_id: room_id.to_proto(),
1380                },
1381            )
1382            .trace_err();
1383    }
1384    update_user_contacts(session.user_id, &session).await?;
1385    Ok(())
1386}
1387
1388async fn update_participant_location(
1389    request: proto::UpdateParticipantLocation,
1390    response: Response<proto::UpdateParticipantLocation>,
1391    session: Session,
1392) -> Result<()> {
1393    let room_id = RoomId::from_proto(request.room_id);
1394    let location = request
1395        .location
1396        .ok_or_else(|| anyhow!("invalid location"))?;
1397
1398    let db = session.db().await;
1399    let room = db
1400        .update_room_participant_location(room_id, session.connection_id, location)
1401        .await?;
1402
1403    room_updated(&room, &session.peer);
1404    response.send(proto::Ack {})?;
1405    Ok(())
1406}
1407
1408async fn share_project(
1409    request: proto::ShareProject,
1410    response: Response<proto::ShareProject>,
1411    session: Session,
1412) -> Result<()> {
1413    let (project_id, room) = &*session
1414        .db()
1415        .await
1416        .share_project(
1417            RoomId::from_proto(request.room_id),
1418            session.connection_id,
1419            &request.worktrees,
1420        )
1421        .await?;
1422    response.send(proto::ShareProjectResponse {
1423        project_id: project_id.to_proto(),
1424    })?;
1425    room_updated(&room, &session.peer);
1426
1427    Ok(())
1428}
1429
1430async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1431    let project_id = ProjectId::from_proto(message.project_id);
1432
1433    let (room, guest_connection_ids) = &*session
1434        .db()
1435        .await
1436        .unshare_project(project_id, session.connection_id)
1437        .await?;
1438
1439    broadcast(
1440        Some(session.connection_id),
1441        guest_connection_ids.iter().copied(),
1442        |conn_id| session.peer.send(conn_id, message.clone()),
1443    );
1444    room_updated(&room, &session.peer);
1445
1446    Ok(())
1447}
1448
1449async fn join_project(
1450    request: proto::JoinProject,
1451    response: Response<proto::JoinProject>,
1452    session: Session,
1453) -> Result<()> {
1454    let project_id = ProjectId::from_proto(request.project_id);
1455    let guest_user_id = session.user_id;
1456
1457    tracing::info!(%project_id, "join project");
1458
1459    let (project, replica_id) = &mut *session
1460        .db()
1461        .await
1462        .join_project(project_id, session.connection_id)
1463        .await?;
1464
1465    let collaborators = project
1466        .collaborators
1467        .iter()
1468        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1469        .map(|collaborator| collaborator.to_proto())
1470        .collect::<Vec<_>>();
1471
1472    let worktrees = project
1473        .worktrees
1474        .iter()
1475        .map(|(id, worktree)| proto::WorktreeMetadata {
1476            id: *id,
1477            root_name: worktree.root_name.clone(),
1478            visible: worktree.visible,
1479            abs_path: worktree.abs_path.clone(),
1480        })
1481        .collect::<Vec<_>>();
1482
1483    for collaborator in &collaborators {
1484        session
1485            .peer
1486            .send(
1487                collaborator.peer_id.unwrap().into(),
1488                proto::AddProjectCollaborator {
1489                    project_id: project_id.to_proto(),
1490                    collaborator: Some(proto::Collaborator {
1491                        peer_id: Some(session.connection_id.into()),
1492                        replica_id: replica_id.0 as u32,
1493                        user_id: guest_user_id.to_proto(),
1494                    }),
1495                },
1496            )
1497            .trace_err();
1498    }
1499
1500    // First, we send the metadata associated with each worktree.
1501    response.send(proto::JoinProjectResponse {
1502        worktrees: worktrees.clone(),
1503        replica_id: replica_id.0 as u32,
1504        collaborators: collaborators.clone(),
1505        language_servers: project.language_servers.clone(),
1506    })?;
1507
1508    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1509        #[cfg(any(test, feature = "test-support"))]
1510        const MAX_CHUNK_SIZE: usize = 2;
1511        #[cfg(not(any(test, feature = "test-support")))]
1512        const MAX_CHUNK_SIZE: usize = 256;
1513
1514        // Stream this worktree's entries.
1515        let message = proto::UpdateWorktree {
1516            project_id: project_id.to_proto(),
1517            worktree_id,
1518            abs_path: worktree.abs_path.clone(),
1519            root_name: worktree.root_name,
1520            updated_entries: worktree.entries,
1521            removed_entries: Default::default(),
1522            scan_id: worktree.scan_id,
1523            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1524            updated_repositories: worktree.repository_entries.into_values().collect(),
1525            removed_repositories: Default::default(),
1526        };
1527        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1528            session.peer.send(session.connection_id, update.clone())?;
1529        }
1530
1531        // Stream this worktree's diagnostics.
1532        for summary in worktree.diagnostic_summaries {
1533            session.peer.send(
1534                session.connection_id,
1535                proto::UpdateDiagnosticSummary {
1536                    project_id: project_id.to_proto(),
1537                    worktree_id: worktree.id,
1538                    summary: Some(summary),
1539                },
1540            )?;
1541        }
1542
1543        for settings_file in worktree.settings_files {
1544            session.peer.send(
1545                session.connection_id,
1546                proto::UpdateWorktreeSettings {
1547                    project_id: project_id.to_proto(),
1548                    worktree_id: worktree.id,
1549                    path: settings_file.path,
1550                    content: Some(settings_file.content),
1551                },
1552            )?;
1553        }
1554    }
1555
1556    for language_server in &project.language_servers {
1557        session.peer.send(
1558            session.connection_id,
1559            proto::UpdateLanguageServer {
1560                project_id: project_id.to_proto(),
1561                language_server_id: language_server.id,
1562                variant: Some(
1563                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1564                        proto::LspDiskBasedDiagnosticsUpdated {},
1565                    ),
1566                ),
1567            },
1568        )?;
1569    }
1570
1571    Ok(())
1572}
1573
1574async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1575    let sender_id = session.connection_id;
1576    let project_id = ProjectId::from_proto(request.project_id);
1577
1578    let (room, project) = &*session
1579        .db()
1580        .await
1581        .leave_project(project_id, sender_id)
1582        .await?;
1583    tracing::info!(
1584        %project_id,
1585        host_user_id = %project.host_user_id,
1586        host_connection_id = %project.host_connection_id,
1587        "leave project"
1588    );
1589
1590    project_left(&project, &session);
1591    room_updated(&room, &session.peer);
1592
1593    Ok(())
1594}
1595
1596async fn update_project(
1597    request: proto::UpdateProject,
1598    response: Response<proto::UpdateProject>,
1599    session: Session,
1600) -> Result<()> {
1601    let project_id = ProjectId::from_proto(request.project_id);
1602    let (room, guest_connection_ids) = &*session
1603        .db()
1604        .await
1605        .update_project(project_id, session.connection_id, &request.worktrees)
1606        .await?;
1607    broadcast(
1608        Some(session.connection_id),
1609        guest_connection_ids.iter().copied(),
1610        |connection_id| {
1611            session
1612                .peer
1613                .forward_send(session.connection_id, connection_id, request.clone())
1614        },
1615    );
1616    room_updated(&room, &session.peer);
1617    response.send(proto::Ack {})?;
1618
1619    Ok(())
1620}
1621
1622async fn update_worktree(
1623    request: proto::UpdateWorktree,
1624    response: Response<proto::UpdateWorktree>,
1625    session: Session,
1626) -> Result<()> {
1627    let guest_connection_ids = session
1628        .db()
1629        .await
1630        .update_worktree(&request, session.connection_id)
1631        .await?;
1632
1633    broadcast(
1634        Some(session.connection_id),
1635        guest_connection_ids.iter().copied(),
1636        |connection_id| {
1637            session
1638                .peer
1639                .forward_send(session.connection_id, connection_id, request.clone())
1640        },
1641    );
1642    response.send(proto::Ack {})?;
1643    Ok(())
1644}
1645
1646async fn update_diagnostic_summary(
1647    message: proto::UpdateDiagnosticSummary,
1648    session: Session,
1649) -> Result<()> {
1650    let guest_connection_ids = session
1651        .db()
1652        .await
1653        .update_diagnostic_summary(&message, session.connection_id)
1654        .await?;
1655
1656    broadcast(
1657        Some(session.connection_id),
1658        guest_connection_ids.iter().copied(),
1659        |connection_id| {
1660            session
1661                .peer
1662                .forward_send(session.connection_id, connection_id, message.clone())
1663        },
1664    );
1665
1666    Ok(())
1667}
1668
1669async fn update_worktree_settings(
1670    message: proto::UpdateWorktreeSettings,
1671    session: Session,
1672) -> Result<()> {
1673    let guest_connection_ids = session
1674        .db()
1675        .await
1676        .update_worktree_settings(&message, session.connection_id)
1677        .await?;
1678
1679    broadcast(
1680        Some(session.connection_id),
1681        guest_connection_ids.iter().copied(),
1682        |connection_id| {
1683            session
1684                .peer
1685                .forward_send(session.connection_id, connection_id, message.clone())
1686        },
1687    );
1688
1689    Ok(())
1690}
1691
1692async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1693    broadcast_project_message(request.project_id, request, session).await
1694}
1695
1696async fn start_language_server(
1697    request: proto::StartLanguageServer,
1698    session: Session,
1699) -> Result<()> {
1700    let guest_connection_ids = session
1701        .db()
1702        .await
1703        .start_language_server(&request, session.connection_id)
1704        .await?;
1705
1706    broadcast(
1707        Some(session.connection_id),
1708        guest_connection_ids.iter().copied(),
1709        |connection_id| {
1710            session
1711                .peer
1712                .forward_send(session.connection_id, connection_id, request.clone())
1713        },
1714    );
1715    Ok(())
1716}
1717
1718async fn update_language_server(
1719    request: proto::UpdateLanguageServer,
1720    session: Session,
1721) -> Result<()> {
1722    session.executor.record_backtrace();
1723    let project_id = ProjectId::from_proto(request.project_id);
1724    let project_connection_ids = session
1725        .db()
1726        .await
1727        .project_connection_ids(project_id, session.connection_id)
1728        .await?;
1729    broadcast(
1730        Some(session.connection_id),
1731        project_connection_ids.iter().copied(),
1732        |connection_id| {
1733            session
1734                .peer
1735                .forward_send(session.connection_id, connection_id, request.clone())
1736        },
1737    );
1738    Ok(())
1739}
1740
1741async fn forward_project_request<T>(
1742    request: T,
1743    response: Response<T>,
1744    session: Session,
1745) -> Result<()>
1746where
1747    T: EntityMessage + RequestMessage,
1748{
1749    session.executor.record_backtrace();
1750    let project_id = ProjectId::from_proto(request.remote_entity_id());
1751    let host_connection_id = {
1752        let collaborators = session
1753            .db()
1754            .await
1755            .project_collaborators(project_id, session.connection_id)
1756            .await?;
1757        collaborators
1758            .iter()
1759            .find(|collaborator| collaborator.is_host)
1760            .ok_or_else(|| anyhow!("host not found"))?
1761            .connection_id
1762    };
1763
1764    let payload = session
1765        .peer
1766        .forward_request(session.connection_id, host_connection_id, request)
1767        .await?;
1768
1769    response.send(payload)?;
1770    Ok(())
1771}
1772
1773async fn create_buffer_for_peer(
1774    request: proto::CreateBufferForPeer,
1775    session: Session,
1776) -> Result<()> {
1777    session.executor.record_backtrace();
1778    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1779    session
1780        .peer
1781        .forward_send(session.connection_id, peer_id.into(), request)?;
1782    Ok(())
1783}
1784
1785async fn update_buffer(
1786    request: proto::UpdateBuffer,
1787    response: Response<proto::UpdateBuffer>,
1788    session: Session,
1789) -> Result<()> {
1790    session.executor.record_backtrace();
1791    let project_id = ProjectId::from_proto(request.project_id);
1792    let mut guest_connection_ids;
1793    let mut host_connection_id = None;
1794    {
1795        let collaborators = session
1796            .db()
1797            .await
1798            .project_collaborators(project_id, session.connection_id)
1799            .await?;
1800        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1801        for collaborator in collaborators.iter() {
1802            if collaborator.is_host {
1803                host_connection_id = Some(collaborator.connection_id);
1804            } else {
1805                guest_connection_ids.push(collaborator.connection_id);
1806            }
1807        }
1808    }
1809    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1810
1811    session.executor.record_backtrace();
1812    broadcast(
1813        Some(session.connection_id),
1814        guest_connection_ids,
1815        |connection_id| {
1816            session
1817                .peer
1818                .forward_send(session.connection_id, connection_id, request.clone())
1819        },
1820    );
1821    if host_connection_id != session.connection_id {
1822        session
1823            .peer
1824            .forward_request(session.connection_id, host_connection_id, request.clone())
1825            .await?;
1826    }
1827
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839
1840    broadcast(
1841        Some(session.connection_id),
1842        project_connection_ids.iter().copied(),
1843        |connection_id| {
1844            session
1845                .peer
1846                .forward_send(session.connection_id, connection_id, request.clone())
1847        },
1848    );
1849    Ok(())
1850}
1851
1852async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1853    let project_id = ProjectId::from_proto(request.project_id);
1854    let project_connection_ids = session
1855        .db()
1856        .await
1857        .project_connection_ids(project_id, session.connection_id)
1858        .await?;
1859    broadcast(
1860        Some(session.connection_id),
1861        project_connection_ids.iter().copied(),
1862        |connection_id| {
1863            session
1864                .peer
1865                .forward_send(session.connection_id, connection_id, request.clone())
1866        },
1867    );
1868    Ok(())
1869}
1870
1871async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1872    broadcast_project_message(request.project_id, request, session).await
1873}
1874
1875async fn broadcast_project_message<T: EnvelopedMessage>(
1876    project_id: u64,
1877    request: T,
1878    session: Session,
1879) -> Result<()> {
1880    let project_id = ProjectId::from_proto(project_id);
1881    let project_connection_ids = session
1882        .db()
1883        .await
1884        .project_connection_ids(project_id, session.connection_id)
1885        .await?;
1886    broadcast(
1887        Some(session.connection_id),
1888        project_connection_ids.iter().copied(),
1889        |connection_id| {
1890            session
1891                .peer
1892                .forward_send(session.connection_id, connection_id, request.clone())
1893        },
1894    );
1895    Ok(())
1896}
1897
1898async fn follow(
1899    request: proto::Follow,
1900    response: Response<proto::Follow>,
1901    session: Session,
1902) -> Result<()> {
1903    let room_id = RoomId::from_proto(request.room_id);
1904    let project_id = request.project_id.map(ProjectId::from_proto);
1905    let leader_id = request
1906        .leader_id
1907        .ok_or_else(|| anyhow!("invalid leader id"))?
1908        .into();
1909    let follower_id = session.connection_id;
1910
1911    session
1912        .db()
1913        .await
1914        .check_room_participants(room_id, leader_id, session.connection_id)
1915        .await?;
1916
1917    let response_payload = session
1918        .peer
1919        .forward_request(session.connection_id, leader_id, request)
1920        .await?;
1921    response.send(response_payload)?;
1922
1923    if let Some(project_id) = project_id {
1924        let room = session
1925            .db()
1926            .await
1927            .follow(room_id, project_id, leader_id, follower_id)
1928            .await?;
1929        room_updated(&room, &session.peer);
1930    }
1931
1932    Ok(())
1933}
1934
1935async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1936    let room_id = RoomId::from_proto(request.room_id);
1937    let project_id = request.project_id.map(ProjectId::from_proto);
1938    let leader_id = request
1939        .leader_id
1940        .ok_or_else(|| anyhow!("invalid leader id"))?
1941        .into();
1942    let follower_id = session.connection_id;
1943
1944    session
1945        .db()
1946        .await
1947        .check_room_participants(room_id, leader_id, session.connection_id)
1948        .await?;
1949
1950    session
1951        .peer
1952        .forward_send(session.connection_id, leader_id, request)?;
1953
1954    if let Some(project_id) = project_id {
1955        let room = session
1956            .db()
1957            .await
1958            .unfollow(room_id, project_id, leader_id, follower_id)
1959            .await?;
1960        room_updated(&room, &session.peer);
1961    }
1962
1963    Ok(())
1964}
1965
1966async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1967    let room_id = RoomId::from_proto(request.room_id);
1968    let database = session.db.lock().await;
1969
1970    let connection_ids = if let Some(project_id) = request.project_id {
1971        let project_id = ProjectId::from_proto(project_id);
1972        database
1973            .project_connection_ids(project_id, session.connection_id)
1974            .await?
1975    } else {
1976        database
1977            .room_connection_ids(room_id, session.connection_id)
1978            .await?
1979    };
1980
1981    // For now, don't send view update messages back to that view's current leader.
1982    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1983        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1984        _ => None,
1985    });
1986
1987    for follower_peer_id in request.follower_ids.iter().copied() {
1988        let follower_connection_id = follower_peer_id.into();
1989        if Some(follower_peer_id) != connection_id_to_omit
1990            && connection_ids.contains(&follower_connection_id)
1991        {
1992            session.peer.forward_send(
1993                session.connection_id,
1994                follower_connection_id,
1995                request.clone(),
1996            )?;
1997        }
1998    }
1999    Ok(())
2000}
2001
2002async fn get_users(
2003    request: proto::GetUsers,
2004    response: Response<proto::GetUsers>,
2005    session: Session,
2006) -> Result<()> {
2007    let user_ids = request
2008        .user_ids
2009        .into_iter()
2010        .map(UserId::from_proto)
2011        .collect();
2012    let users = session
2013        .db()
2014        .await
2015        .get_users_by_ids(user_ids)
2016        .await?
2017        .into_iter()
2018        .map(|user| proto::User {
2019            id: user.id.to_proto(),
2020            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2021            github_login: user.github_login,
2022        })
2023        .collect();
2024    response.send(proto::UsersResponse { users })?;
2025    Ok(())
2026}
2027
2028async fn fuzzy_search_users(
2029    request: proto::FuzzySearchUsers,
2030    response: Response<proto::FuzzySearchUsers>,
2031    session: Session,
2032) -> Result<()> {
2033    let query = request.query;
2034    let users = match query.len() {
2035        0 => vec![],
2036        1 | 2 => session
2037            .db()
2038            .await
2039            .get_user_by_github_login(&query)
2040            .await?
2041            .into_iter()
2042            .collect(),
2043        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2044    };
2045    let users = users
2046        .into_iter()
2047        .filter(|user| user.id != session.user_id)
2048        .map(|user| proto::User {
2049            id: user.id.to_proto(),
2050            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2051            github_login: user.github_login,
2052        })
2053        .collect();
2054    response.send(proto::UsersResponse { users })?;
2055    Ok(())
2056}
2057
2058async fn request_contact(
2059    request: proto::RequestContact,
2060    response: Response<proto::RequestContact>,
2061    session: Session,
2062) -> Result<()> {
2063    let requester_id = session.user_id;
2064    let responder_id = UserId::from_proto(request.responder_id);
2065    if requester_id == responder_id {
2066        return Err(anyhow!("cannot add yourself as a contact"))?;
2067    }
2068
2069    session
2070        .db()
2071        .await
2072        .send_contact_request(requester_id, responder_id)
2073        .await?;
2074
2075    // Update outgoing contact requests of requester
2076    let mut update = proto::UpdateContacts::default();
2077    update.outgoing_requests.push(responder_id.to_proto());
2078    for connection_id in session
2079        .connection_pool()
2080        .await
2081        .user_connection_ids(requester_id)
2082    {
2083        session.peer.send(connection_id, update.clone())?;
2084    }
2085
2086    // Update incoming contact requests of responder
2087    let mut update = proto::UpdateContacts::default();
2088    update
2089        .incoming_requests
2090        .push(proto::IncomingContactRequest {
2091            requester_id: requester_id.to_proto(),
2092            should_notify: true,
2093        });
2094    for connection_id in session
2095        .connection_pool()
2096        .await
2097        .user_connection_ids(responder_id)
2098    {
2099        session.peer.send(connection_id, update.clone())?;
2100    }
2101
2102    response.send(proto::Ack {})?;
2103    Ok(())
2104}
2105
2106async fn respond_to_contact_request(
2107    request: proto::RespondToContactRequest,
2108    response: Response<proto::RespondToContactRequest>,
2109    session: Session,
2110) -> Result<()> {
2111    let responder_id = session.user_id;
2112    let requester_id = UserId::from_proto(request.requester_id);
2113    let db = session.db().await;
2114    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2115        db.dismiss_contact_notification(responder_id, requester_id)
2116            .await?;
2117    } else {
2118        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2119
2120        db.respond_to_contact_request(responder_id, requester_id, accept)
2121            .await?;
2122        let requester_busy = db.is_user_busy(requester_id).await?;
2123        let responder_busy = db.is_user_busy(responder_id).await?;
2124
2125        let pool = session.connection_pool().await;
2126        // Update responder with new contact
2127        let mut update = proto::UpdateContacts::default();
2128        if accept {
2129            update
2130                .contacts
2131                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2132        }
2133        update
2134            .remove_incoming_requests
2135            .push(requester_id.to_proto());
2136        for connection_id in pool.user_connection_ids(responder_id) {
2137            session.peer.send(connection_id, update.clone())?;
2138        }
2139
2140        // Update requester with new contact
2141        let mut update = proto::UpdateContacts::default();
2142        if accept {
2143            update
2144                .contacts
2145                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2146        }
2147        update
2148            .remove_outgoing_requests
2149            .push(responder_id.to_proto());
2150        for connection_id in pool.user_connection_ids(requester_id) {
2151            session.peer.send(connection_id, update.clone())?;
2152        }
2153    }
2154
2155    response.send(proto::Ack {})?;
2156    Ok(())
2157}
2158
2159async fn remove_contact(
2160    request: proto::RemoveContact,
2161    response: Response<proto::RemoveContact>,
2162    session: Session,
2163) -> Result<()> {
2164    let requester_id = session.user_id;
2165    let responder_id = UserId::from_proto(request.user_id);
2166    let db = session.db().await;
2167    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2168
2169    let pool = session.connection_pool().await;
2170    // Update outgoing contact requests of requester
2171    let mut update = proto::UpdateContacts::default();
2172    if contact_accepted {
2173        update.remove_contacts.push(responder_id.to_proto());
2174    } else {
2175        update
2176            .remove_outgoing_requests
2177            .push(responder_id.to_proto());
2178    }
2179    for connection_id in pool.user_connection_ids(requester_id) {
2180        session.peer.send(connection_id, update.clone())?;
2181    }
2182
2183    // Update incoming contact requests of responder
2184    let mut update = proto::UpdateContacts::default();
2185    if contact_accepted {
2186        update.remove_contacts.push(requester_id.to_proto());
2187    } else {
2188        update
2189            .remove_incoming_requests
2190            .push(requester_id.to_proto());
2191    }
2192    for connection_id in pool.user_connection_ids(responder_id) {
2193        session.peer.send(connection_id, update.clone())?;
2194    }
2195
2196    response.send(proto::Ack {})?;
2197    Ok(())
2198}
2199
2200async fn create_channel(
2201    request: proto::CreateChannel,
2202    response: Response<proto::CreateChannel>,
2203    session: Session,
2204) -> Result<()> {
2205    let db = session.db().await;
2206
2207    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2208    let CreateChannelResult {
2209        channel,
2210        participants_to_update,
2211    } = db
2212        .create_channel(&request.name, parent_id, session.user_id)
2213        .await?;
2214
2215    response.send(proto::CreateChannelResponse {
2216        channel: Some(channel.to_proto()),
2217        parent_id: request.parent_id,
2218    })?;
2219
2220    let connection_pool = session.connection_pool().await;
2221    for (user_id, channels) in participants_to_update {
2222        let update = build_channels_update(channels, vec![]);
2223        for connection_id in connection_pool.user_connection_ids(user_id) {
2224            if user_id == session.user_id {
2225                continue;
2226            }
2227            session.peer.send(connection_id, update.clone())?;
2228        }
2229    }
2230
2231    Ok(())
2232}
2233
2234async fn delete_channel(
2235    request: proto::DeleteChannel,
2236    response: Response<proto::DeleteChannel>,
2237    session: Session,
2238) -> Result<()> {
2239    let db = session.db().await;
2240
2241    let channel_id = request.channel_id;
2242    let (removed_channels, member_ids) = db
2243        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2244        .await?;
2245    response.send(proto::Ack {})?;
2246
2247    // Notify members of removed channels
2248    let mut update = proto::UpdateChannels::default();
2249    update
2250        .delete_channels
2251        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2252
2253    let connection_pool = session.connection_pool().await;
2254    for member_id in member_ids {
2255        for connection_id in connection_pool.user_connection_ids(member_id) {
2256            session.peer.send(connection_id, update.clone())?;
2257        }
2258    }
2259
2260    Ok(())
2261}
2262
2263async fn invite_channel_member(
2264    request: proto::InviteChannelMember,
2265    response: Response<proto::InviteChannelMember>,
2266    session: Session,
2267) -> Result<()> {
2268    let db = session.db().await;
2269    let channel_id = ChannelId::from_proto(request.channel_id);
2270    let invitee_id = UserId::from_proto(request.user_id);
2271    let channel = db
2272        .invite_channel_member(
2273            channel_id,
2274            invitee_id,
2275            session.user_id,
2276            request.role().into(),
2277        )
2278        .await?;
2279
2280    let update = proto::UpdateChannels {
2281        channel_invitations: vec![channel.to_proto()],
2282        ..Default::default()
2283    };
2284
2285    for connection_id in session
2286        .connection_pool()
2287        .await
2288        .user_connection_ids(invitee_id)
2289    {
2290        session.peer.send(connection_id, update.clone())?;
2291    }
2292
2293    response.send(proto::Ack {})?;
2294    Ok(())
2295}
2296
2297async fn remove_channel_member(
2298    request: proto::RemoveChannelMember,
2299    response: Response<proto::RemoveChannelMember>,
2300    session: Session,
2301) -> Result<()> {
2302    let db = session.db().await;
2303    let channel_id = ChannelId::from_proto(request.channel_id);
2304    let member_id = UserId::from_proto(request.user_id);
2305
2306    let membership_updated = db
2307        .remove_channel_member(channel_id, member_id, session.user_id)
2308        .await?;
2309
2310    dbg!(&membership_updated);
2311
2312    notify_membership_updated(membership_updated, member_id, &session).await?;
2313
2314    response.send(proto::Ack {})?;
2315    Ok(())
2316}
2317
2318async fn set_channel_visibility(
2319    request: proto::SetChannelVisibility,
2320    response: Response<proto::SetChannelVisibility>,
2321    session: Session,
2322) -> Result<()> {
2323    let db = session.db().await;
2324    let channel_id = ChannelId::from_proto(request.channel_id);
2325    let visibility = request.visibility().into();
2326
2327    let SetChannelVisibilityResult {
2328        participants_to_update,
2329        participants_to_remove,
2330    } = db
2331        .set_channel_visibility(channel_id, visibility, session.user_id)
2332        .await?;
2333
2334    let connection_pool = session.connection_pool().await;
2335    for (user_id, channels) in participants_to_update {
2336        let update = build_channels_update(channels, vec![]);
2337        for connection_id in connection_pool.user_connection_ids(user_id) {
2338            session.peer.send(connection_id, update.clone())?;
2339        }
2340    }
2341    for user_id in participants_to_remove {
2342        let update = proto::UpdateChannels {
2343            // for public participants  we only need to remove the current channel
2344            // (not descendants)
2345            // because they can still see any public descendants
2346            delete_channels: vec![channel_id.to_proto()],
2347            ..Default::default()
2348        };
2349        for connection_id in connection_pool.user_connection_ids(user_id) {
2350            session.peer.send(connection_id, update.clone())?;
2351        }
2352    }
2353
2354    response.send(proto::Ack {})?;
2355    Ok(())
2356}
2357
2358async fn set_channel_member_role(
2359    request: proto::SetChannelMemberRole,
2360    response: Response<proto::SetChannelMemberRole>,
2361    session: Session,
2362) -> Result<()> {
2363    let db = session.db().await;
2364    let channel_id = ChannelId::from_proto(request.channel_id);
2365    let member_id = UserId::from_proto(request.user_id);
2366    let result = db
2367        .set_channel_member_role(
2368            channel_id,
2369            session.user_id,
2370            member_id,
2371            request.role().into(),
2372        )
2373        .await?;
2374
2375    match result {
2376        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2377            notify_membership_updated(membership_update, member_id, &session).await?;
2378        }
2379        db::SetMemberRoleResult::InviteUpdated(channel) => {
2380            let update = proto::UpdateChannels {
2381                channel_invitations: vec![channel.to_proto()],
2382                ..Default::default()
2383            };
2384
2385            for connection_id in session
2386                .connection_pool()
2387                .await
2388                .user_connection_ids(member_id)
2389            {
2390                session.peer.send(connection_id, update.clone())?;
2391            }
2392        }
2393    }
2394
2395    response.send(proto::Ack {})?;
2396    Ok(())
2397}
2398
2399async fn rename_channel(
2400    request: proto::RenameChannel,
2401    response: Response<proto::RenameChannel>,
2402    session: Session,
2403) -> Result<()> {
2404    let db = session.db().await;
2405    let channel_id = ChannelId::from_proto(request.channel_id);
2406    let RenameChannelResult {
2407        channel,
2408        participants_to_update,
2409    } = db
2410        .rename_channel(channel_id, session.user_id, &request.name)
2411        .await?;
2412
2413    response.send(proto::RenameChannelResponse {
2414        channel: Some(channel.to_proto()),
2415    })?;
2416
2417    let connection_pool = session.connection_pool().await;
2418    for (user_id, channel) in participants_to_update {
2419        for connection_id in connection_pool.user_connection_ids(user_id) {
2420            let update = proto::UpdateChannels {
2421                channels: vec![channel.to_proto()],
2422                ..Default::default()
2423            };
2424
2425            session.peer.send(connection_id, update.clone())?;
2426        }
2427    }
2428
2429    Ok(())
2430}
2431
2432// TODO: Implement in terms of symlinks
2433// Current behavior of this is more like 'Move root channel'
2434async fn link_channel(
2435    request: proto::LinkChannel,
2436    response: Response<proto::LinkChannel>,
2437    session: Session,
2438) -> Result<()> {
2439    let db = session.db().await;
2440    let channel_id = ChannelId::from_proto(request.channel_id);
2441    let to = ChannelId::from_proto(request.to);
2442
2443    let result = db
2444        .move_channel(channel_id, None, to, session.user_id)
2445        .await?;
2446    drop(db);
2447
2448    notify_channel_moved(result, session).await?;
2449
2450    response.send(Ack {})?;
2451
2452    Ok(())
2453}
2454
2455// TODO: Implement in terms of symlinks
2456async fn unlink_channel(
2457    _request: proto::UnlinkChannel,
2458    _response: Response<proto::UnlinkChannel>,
2459    _session: Session,
2460) -> Result<()> {
2461    Err(anyhow!("unimplemented").into())
2462}
2463
2464async fn move_channel(
2465    request: proto::MoveChannel,
2466    response: Response<proto::MoveChannel>,
2467    session: Session,
2468) -> Result<()> {
2469    let db = session.db().await;
2470    let channel_id = ChannelId::from_proto(request.channel_id);
2471    let from_parent = ChannelId::from_proto(request.from);
2472    let to = ChannelId::from_proto(request.to);
2473
2474    let result = db
2475        .move_channel(channel_id, Some(from_parent), to, session.user_id)
2476        .await?;
2477    drop(db);
2478
2479    notify_channel_moved(result, session).await?;
2480
2481    response.send(Ack {})?;
2482    Ok(())
2483}
2484
2485async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2486    let Some(MoveChannelResult {
2487        participants_to_remove,
2488        participants_to_update,
2489        moved_channels,
2490    }) = result
2491    else {
2492        return Ok(());
2493    };
2494    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2495
2496    let connection_pool = session.connection_pool().await;
2497    for (user_id, channels) in participants_to_update {
2498        let mut update = build_channels_update(channels, vec![]);
2499        update.delete_channels = moved_channels.clone();
2500        for connection_id in connection_pool.user_connection_ids(user_id) {
2501            session.peer.send(connection_id, update.clone())?;
2502        }
2503    }
2504
2505    for user_id in participants_to_remove {
2506        let update = proto::UpdateChannels {
2507            delete_channels: moved_channels.clone(),
2508            ..Default::default()
2509        };
2510        for connection_id in connection_pool.user_connection_ids(user_id) {
2511            session.peer.send(connection_id, update.clone())?;
2512        }
2513    }
2514    Ok(())
2515}
2516
2517async fn get_channel_members(
2518    request: proto::GetChannelMembers,
2519    response: Response<proto::GetChannelMembers>,
2520    session: Session,
2521) -> Result<()> {
2522    let db = session.db().await;
2523    let channel_id = ChannelId::from_proto(request.channel_id);
2524    let members = db
2525        .get_channel_participant_details(channel_id, session.user_id)
2526        .await?;
2527    response.send(proto::GetChannelMembersResponse { members })?;
2528    Ok(())
2529}
2530
/// Accepts or declines a channel invitation on behalf of the requesting user.
///
/// If the database reports a membership change (`Some`), the user's connections receive the
/// full membership update via [`notify_membership_updated`]. Otherwise only the invitation is
/// removed from each of the user's connections. The request is acknowledged only after those
/// notifications succeed; a failed send aborts the handler before the `Ack`.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let result = db
        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
        .await?;

    if let Some(accept_invite_result) = result {
        notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
    } else {
        // No membership change (presumably the invite was declined): just clear the
        // invitation on every connection of the requesting user.
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        let connection_pool = session.connection_pool().await;
        for connection_id in connection_pool.user_connection_ids(session.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    response.send(proto::Ack {})?;

    Ok(())
}
2560
2561async fn join_channel(
2562    request: proto::JoinChannel,
2563    response: Response<proto::JoinChannel>,
2564    session: Session,
2565) -> Result<()> {
2566    let channel_id = ChannelId::from_proto(request.channel_id);
2567    join_channel_internal(channel_id, Box::new(response), session).await
2568}
2569
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` so
/// that [`join_channel_internal`] can serve either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Answers a `JoinChannel` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified call to the inherent `Response::send` for this request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Answers a `JoinRoom` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified call to the inherent `Response::send` for this request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2583
/// Moves the session into the room attached to `channel_id` and performs the follow-up
/// notifications.
///
/// The steps run in this order:
/// 1. Leave any room the session is currently in ([`leave_room_for_session`]).
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, mint a token. Guests get a `guest_token` with
///    `can_publish = false`; every other role gets a `room_token` with `can_publish = true`.
///    A token error is logged and results in no LiveKit connection info.
/// 4. Reply with the joined room.
/// 5. Send a membership update if the database returned one, then broadcast the room.
/// 6. Broadcast the channel's participant list and refresh the user's contact status.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The block scopes the `db` guard so it is dropped before `update_user_contacts`
    // acquires the database again below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, accept_invite_result, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `trace_err()?` inside the closure turns a token error into `None` (after logging it).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Presumably `Some` when joining also accepted a pending invitation — confirm in db.
        if let Some(accept_invite_result) = accept_invite_result {
            notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2658
2659async fn join_channel_buffer(
2660    request: proto::JoinChannelBuffer,
2661    response: Response<proto::JoinChannelBuffer>,
2662    session: Session,
2663) -> Result<()> {
2664    let db = session.db().await;
2665    let channel_id = ChannelId::from_proto(request.channel_id);
2666
2667    let open_response = db
2668        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2669        .await?;
2670
2671    let collaborators = open_response.collaborators.clone();
2672    response.send(open_response)?;
2673
2674    let update = UpdateChannelBufferCollaborators {
2675        channel_id: channel_id.to_proto(),
2676        collaborators: collaborators.clone(),
2677    };
2678    channel_buffer_updated(
2679        session.connection_id,
2680        collaborators
2681            .iter()
2682            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2683        &update,
2684        &session.peer,
2685    );
2686
2687    Ok(())
2688}
2689
/// Persists buffer operations sent by a collaborator and fans them out.
///
/// `db.update_channel_buffer` returns four values:
/// - the connections of the buffer's collaborators,
/// - the ids of users who are not collaborating,
/// - the buffer's epoch,
/// - the buffer's version.
///
/// The operations are forwarded to the collaborators via [`channel_buffer_updated`], with
/// the sender passed as the originating connection. Every connection of each
/// non-collaborating user is sent an `UpdateChannels` carrying an
/// `unseen_channel_buffer_changes` entry for this channel.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, non_collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id, &request.operations)
        .await?;

    channel_buffer_updated(
        session.connection_id,
        collaborators,
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    broadcast(
        None,
        non_collaborators
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
                        channel_id: channel_id.to_proto(),
                        // NOTE(review): plain `as` widening of the database epoch; assumes
                        // epochs are never negative — confirm.
                        epoch: epoch as u64,
                        version: version.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );

    Ok(())
}
2735
/// Re-attaches the session's connection to the channel buffers listed in the request.
///
/// For every buffer the database reports as rejoined, the updated collaborator list is sent
/// through [`channel_buffer_updated`] to the collaborators that have a peer id. Afterwards
/// the requester receives all rejoined buffers in a single `RejoinChannelBuffersResponse`.
async fn rejoin_channel_buffers(
    request: proto::RejoinChannelBuffers,
    response: Response<proto::RejoinChannelBuffers>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let buffers = db
        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
        .await?;

    for rejoined_buffer in &buffers {
        // Collaborators without a peer id are skipped.
        let collaborators_to_notify = rejoined_buffer
            .buffer
            .collaborators
            .iter()
            .filter_map(|c| Some(c.peer_id?.into()));
        channel_buffer_updated(
            session.connection_id,
            collaborators_to_notify,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: rejoined_buffer.buffer.channel_id,
                collaborators: rejoined_buffer.buffer.collaborators.clone(),
            },
            &session.peer,
        );
    }

    response.send(proto::RejoinChannelBuffersResponse {
        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
    })?;

    Ok(())
}
2769
/// Removes the session's connection from a channel buffer.
///
/// The request is acknowledged first. Then the connections returned by the database
/// (`left_buffer.connections`) receive the updated collaborator list.
async fn leave_channel_buffer(
    request: proto::LeaveChannelBuffer,
    response: Response<proto::LeaveChannelBuffer>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let left_buffer = db
        .leave_channel_buffer(channel_id, session.connection_id)
        .await?;

    response.send(Ack {})?;

    channel_buffer_updated(
        session.connection_id,
        left_buffer.connections,
        &proto::UpdateChannelBufferCollaborators {
            channel_id: channel_id.to_proto(),
            collaborators: left_buffer.collaborators,
        },
        &session.peer,
    );

    Ok(())
}
2796
2797fn channel_buffer_updated<T: EnvelopedMessage>(
2798    sender_id: ConnectionId,
2799    collaborators: impl IntoIterator<Item = ConnectionId>,
2800    message: &T,
2801    peer: &Peer,
2802) {
2803    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2804        peer.send(peer_id.into(), message.clone())
2805    });
2806}
2807
2808async fn send_channel_message(
2809    request: proto::SendChannelMessage,
2810    response: Response<proto::SendChannelMessage>,
2811    session: Session,
2812) -> Result<()> {
2813    // Validate the message body.
2814    let body = request.body.trim().to_string();
2815    if body.len() > MAX_MESSAGE_LEN {
2816        return Err(anyhow!("message is too long"))?;
2817    }
2818    if body.is_empty() {
2819        return Err(anyhow!("message can't be blank"))?;
2820    }
2821
2822    let timestamp = OffsetDateTime::now_utc();
2823    let nonce = request
2824        .nonce
2825        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2826
2827    let channel_id = ChannelId::from_proto(request.channel_id);
2828    let (message_id, connection_ids, non_participants) = session
2829        .db()
2830        .await
2831        .create_channel_message(
2832            channel_id,
2833            session.user_id,
2834            &body,
2835            timestamp,
2836            nonce.clone().into(),
2837        )
2838        .await?;
2839    let message = proto::ChannelMessage {
2840        sender_id: session.user_id.to_proto(),
2841        id: message_id.to_proto(),
2842        body,
2843        timestamp: timestamp.unix_timestamp() as u64,
2844        nonce: Some(nonce),
2845    };
2846    broadcast(Some(session.connection_id), connection_ids, |connection| {
2847        session.peer.send(
2848            connection,
2849            proto::ChannelMessageSent {
2850                channel_id: channel_id.to_proto(),
2851                message: Some(message.clone()),
2852            },
2853        )
2854    });
2855    response.send(proto::SendChannelMessageResponse {
2856        message: Some(message),
2857    })?;
2858
2859    let pool = &*session.connection_pool().await;
2860    broadcast(
2861        None,
2862        non_participants
2863            .iter()
2864            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2865        |peer_id| {
2866            session.peer.send(
2867                peer_id.into(),
2868                proto::UpdateChannels {
2869                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2870                        channel_id: channel_id.to_proto(),
2871                        message_id: message_id.to_proto(),
2872                    }],
2873                    ..Default::default()
2874                },
2875            )
2876        },
2877    );
2878
2879    Ok(())
2880}
2881
/// Deletes a channel message on behalf of the requesting user.
///
/// The database call receives the requester's user id (presumably to authorize the removal)
/// and returns the connections to notify. Each of those connections is sent a copy of the
/// original request; the requester is passed to `broadcast` as the originating connection.
/// The request is acknowledged after the broadcast.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let connection_ids = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id)
        .await?;
    broadcast(Some(session.connection_id), connection_ids, |connection| {
        session.peer.send(connection, request.clone())
    });
    response.send(proto::Ack {})?;
    Ok(())
}
2900
2901async fn acknowledge_channel_message(
2902    request: proto::AckChannelMessage,
2903    session: Session,
2904) -> Result<()> {
2905    let channel_id = ChannelId::from_proto(request.channel_id);
2906    let message_id = MessageId::from_proto(request.message_id);
2907    session
2908        .db()
2909        .await
2910        .observe_channel_message(channel_id, session.user_id, message_id)
2911        .await?;
2912    Ok(())
2913}
2914
2915async fn acknowledge_buffer_version(
2916    request: proto::AckBufferOperation,
2917    session: Session,
2918) -> Result<()> {
2919    let buffer_id = BufferId::from_proto(request.buffer_id);
2920    session
2921        .db()
2922        .await
2923        .observe_buffer_version(
2924            buffer_id,
2925            session.user_id,
2926            request.epoch as i32,
2927            &request.version,
2928        )
2929        .await?;
2930    Ok(())
2931}
2932
2933async fn join_channel_chat(
2934    request: proto::JoinChannelChat,
2935    response: Response<proto::JoinChannelChat>,
2936    session: Session,
2937) -> Result<()> {
2938    let channel_id = ChannelId::from_proto(request.channel_id);
2939
2940    let db = session.db().await;
2941    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2942        .await?;
2943    let messages = db
2944        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2945        .await?;
2946    response.send(proto::JoinChannelChatResponse {
2947        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2948        messages,
2949    })?;
2950    Ok(())
2951}
2952
2953async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2954    let channel_id = ChannelId::from_proto(request.channel_id);
2955    session
2956        .db()
2957        .await
2958        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2959        .await?;
2960    Ok(())
2961}
2962
2963async fn get_channel_messages(
2964    request: proto::GetChannelMessages,
2965    response: Response<proto::GetChannelMessages>,
2966    session: Session,
2967) -> Result<()> {
2968    let channel_id = ChannelId::from_proto(request.channel_id);
2969    let messages = session
2970        .db()
2971        .await
2972        .get_channel_messages(
2973            channel_id,
2974            session.user_id,
2975            MESSAGE_COUNT_PER_PAGE,
2976            Some(MessageId::from_proto(request.before_message_id)),
2977        )
2978        .await?;
2979    response.send(proto::GetChannelMessagesResponse {
2980        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2981        messages,
2982    })?;
2983    Ok(())
2984}
2985
2986async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2987    let project_id = ProjectId::from_proto(request.project_id);
2988    let project_connection_ids = session
2989        .db()
2990        .await
2991        .project_connection_ids(project_id, session.connection_id)
2992        .await?;
2993    broadcast(
2994        Some(session.connection_id),
2995        project_connection_ids.iter().copied(),
2996        |connection_id| {
2997            session
2998                .peer
2999                .forward_send(session.connection_id, connection_id, request.clone())
3000        },
3001    );
3002    Ok(())
3003}
3004
3005async fn get_private_user_info(
3006    _request: proto::GetPrivateUserInfo,
3007    response: Response<proto::GetPrivateUserInfo>,
3008    session: Session,
3009) -> Result<()> {
3010    let db = session.db().await;
3011
3012    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3013    let user = db
3014        .get_user_by_id(session.user_id)
3015        .await?
3016        .ok_or_else(|| anyhow!("user not found"))?;
3017    let flags = db.get_user_flags(session.user_id).await?;
3018
3019    response.send(proto::GetPrivateUserInfoResponse {
3020        metrics_id,
3021        staff: user.admin,
3022        flags,
3023    })?;
3024    Ok(())
3025}
3026
3027fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3028    match message {
3029        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3030        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3031        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3032        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3033        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3034            code: frame.code.into(),
3035            reason: frame.reason,
3036        })),
3037    }
3038}
3039
3040fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3041    match message {
3042        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3043        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3044        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3045        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3046        AxumMessage::Close(frame) => {
3047            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3048                code: frame.code.into(),
3049                reason: frame.reason,
3050            }))
3051        }
3052    }
3053}
3054
3055async fn notify_membership_updated(
3056    result: MembershipUpdated,
3057    user_id: UserId,
3058    session: &Session,
3059) -> Result<()> {
3060    let mut update = build_channels_update(result.new_channels, vec![]);
3061    update.delete_channels = result
3062        .removed_channels
3063        .into_iter()
3064        .map(|id| id.to_proto())
3065        .collect();
3066    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3067
3068    let connection_pool = session.connection_pool().await;
3069    for connection_id in connection_pool.user_connection_ids(user_id) {
3070        session.peer.send(connection_id, update.clone())?;
3071    }
3072    Ok(())
3073}
3074
3075fn build_channels_update(
3076    channels: ChannelsForUser,
3077    channel_invites: Vec<db::Channel>,
3078) -> proto::UpdateChannels {
3079    let mut update = proto::UpdateChannels::default();
3080
3081    for channel in channels.channels.channels {
3082        update.channels.push(proto::Channel {
3083            id: channel.id.to_proto(),
3084            name: channel.name,
3085            visibility: channel.visibility.into(),
3086            role: channel.role.into(),
3087        });
3088    }
3089
3090    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3091    update.unseen_channel_messages = channels.channel_messages;
3092    update.insert_edge = channels.channels.edges;
3093
3094    for (channel_id, participants) in channels.channel_participants {
3095        update
3096            .channel_participants
3097            .push(proto::ChannelParticipants {
3098                channel_id: channel_id.to_proto(),
3099                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3100            });
3101    }
3102
3103    for channel in channel_invites {
3104        update.channel_invitations.push(proto::Channel {
3105            id: channel.id.to_proto(),
3106            name: channel.name,
3107            visibility: channel.visibility.into(),
3108            role: channel.role.into(),
3109        });
3110    }
3111
3112    update
3113}
3114
3115fn build_initial_contacts_update(
3116    contacts: Vec<db::Contact>,
3117    pool: &ConnectionPool,
3118) -> proto::UpdateContacts {
3119    let mut update = proto::UpdateContacts::default();
3120
3121    for contact in contacts {
3122        match contact {
3123            db::Contact::Accepted {
3124                user_id,
3125                should_notify,
3126                busy,
3127            } => {
3128                update
3129                    .contacts
3130                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3131            }
3132            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3133            db::Contact::Incoming {
3134                user_id,
3135                should_notify,
3136            } => update
3137                .incoming_requests
3138                .push(proto::IncomingContactRequest {
3139                    requester_id: user_id.to_proto(),
3140                    should_notify,
3141                }),
3142        }
3143    }
3144
3145    update
3146}
3147
3148fn contact_for_user(
3149    user_id: UserId,
3150    should_notify: bool,
3151    busy: bool,
3152    pool: &ConnectionPool,
3153) -> proto::Contact {
3154    proto::Contact {
3155        user_id: user_id.to_proto(),
3156        online: pool.is_user_online(user_id),
3157        busy,
3158        should_notify,
3159    }
3160}
3161
3162fn room_updated(room: &proto::Room, peer: &Peer) {
3163    broadcast(
3164        None,
3165        room.participants
3166            .iter()
3167            .filter_map(|participant| Some(participant.peer_id?.into())),
3168        |peer_id| {
3169            peer.send(
3170                peer_id.into(),
3171                proto::RoomUpdated {
3172                    room: Some(room.clone()),
3173                },
3174            )
3175        },
3176    );
3177}
3178
3179fn channel_updated(
3180    channel_id: ChannelId,
3181    room: &proto::Room,
3182    channel_members: &[UserId],
3183    peer: &Peer,
3184    pool: &ConnectionPool,
3185) {
3186    let participants = room
3187        .participants
3188        .iter()
3189        .map(|p| p.user_id)
3190        .collect::<Vec<_>>();
3191
3192    broadcast(
3193        None,
3194        channel_members
3195            .iter()
3196            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3197        |peer_id| {
3198            peer.send(
3199                peer_id.into(),
3200                proto::UpdateChannels {
3201                    channel_participants: vec![proto::ChannelParticipants {
3202                        channel_id: channel_id.to_proto(),
3203                        participant_user_ids: participants.clone(),
3204                    }],
3205                    ..Default::default()
3206                },
3207            )
3208        },
3209    );
3210}
3211
3212async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3213    let db = session.db().await;
3214
3215    let contacts = db.get_contacts(user_id).await?;
3216    let busy = db.is_user_busy(user_id).await?;
3217
3218    let pool = session.connection_pool().await;
3219    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3220    for contact in contacts {
3221        if let db::Contact::Accepted {
3222            user_id: contact_user_id,
3223            ..
3224        } = contact
3225        {
3226            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3227                session
3228                    .peer
3229                    .send(
3230                        contact_conn_id,
3231                        proto::UpdateContacts {
3232                            contacts: vec![updated_contact.clone()],
3233                            remove_contacts: Default::default(),
3234                            incoming_requests: Default::default(),
3235                            remove_incoming_requests: Default::default(),
3236                            outgoing_requests: Default::default(),
3237                            remove_outgoing_requests: Default::default(),
3238                        },
3239                    )
3240                    .trace_err();
3241            }
3242        }
3243    }
3244    Ok(())
3245}
3246
3247async fn leave_room_for_session(session: &Session) -> Result<()> {
3248    let mut contacts_to_update = HashSet::default();
3249
3250    let room_id;
3251    let canceled_calls_to_user_ids;
3252    let live_kit_room;
3253    let delete_live_kit_room;
3254    let room;
3255    let channel_members;
3256    let channel_id;
3257
3258    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3259        contacts_to_update.insert(session.user_id);
3260
3261        for project in left_room.left_projects.values() {
3262            project_left(project, session);
3263        }
3264
3265        room_id = RoomId::from_proto(left_room.room.id);
3266        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3267        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3268        delete_live_kit_room = left_room.deleted;
3269        room = mem::take(&mut left_room.room);
3270        channel_members = mem::take(&mut left_room.channel_members);
3271        channel_id = left_room.channel_id;
3272
3273        room_updated(&room, &session.peer);
3274    } else {
3275        return Ok(());
3276    }
3277
3278    if let Some(channel_id) = channel_id {
3279        channel_updated(
3280            channel_id,
3281            &room,
3282            &channel_members,
3283            &session.peer,
3284            &*session.connection_pool().await,
3285        );
3286    }
3287
3288    {
3289        let pool = session.connection_pool().await;
3290        for canceled_user_id in canceled_calls_to_user_ids {
3291            for connection_id in pool.user_connection_ids(canceled_user_id) {
3292                session
3293                    .peer
3294                    .send(
3295                        connection_id,
3296                        proto::CallCanceled {
3297                            room_id: room_id.to_proto(),
3298                        },
3299                    )
3300                    .trace_err();
3301            }
3302            contacts_to_update.insert(canceled_user_id);
3303        }
3304    }
3305
3306    for contact_user_id in contacts_to_update {
3307        update_user_contacts(contact_user_id, &session).await?;
3308    }
3309
3310    if let Some(live_kit) = session.live_kit_client.as_ref() {
3311        live_kit
3312            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3313            .await
3314            .trace_err();
3315
3316        if delete_live_kit_room {
3317            live_kit.delete_room(live_kit_room).await.trace_err();
3318        }
3319    }
3320
3321    Ok(())
3322}
3323
3324async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3325    let left_channel_buffers = session
3326        .db()
3327        .await
3328        .leave_channel_buffers(session.connection_id)
3329        .await?;
3330
3331    for left_buffer in left_channel_buffers {
3332        channel_buffer_updated(
3333            session.connection_id,
3334            left_buffer.connections,
3335            &proto::UpdateChannelBufferCollaborators {
3336                channel_id: left_buffer.channel_id.to_proto(),
3337                collaborators: left_buffer.collaborators,
3338            },
3339            &session.peer,
3340        );
3341    }
3342
3343    Ok(())
3344}
3345
3346fn project_left(project: &db::LeftProject, session: &Session) {
3347    for connection_id in &project.connection_ids {
3348        if project.host_user_id == session.user_id {
3349            session
3350                .peer
3351                .send(
3352                    *connection_id,
3353                    proto::UnshareProject {
3354                        project_id: project.id.to_proto(),
3355                    },
3356                )
3357                .trace_err();
3358        } else {
3359            session
3360                .peer
3361                .send(
3362                    *connection_id,
3363                    proto::RemoveProjectCollaborator {
3364                        project_id: project.id.to_proto(),
3365                        peer_id: Some(session.connection_id.into()),
3366                    },
3367                )
3368                .trace_err();
3369        }
3370    }
3371}
3372
/// Extension trait for turning a fallible result into an `Option` after logging the error.
pub trait ResultExt {
    /// The success value type of the extended result.
    type Ok;

    /// Returns the success value, or `None` on error.
    ///
    /// The implementation for `Result` logs the error's `Debug` form with `tracing::error!`
    /// before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3378
3379impl<T, E> ResultExt for Result<T, E>
3380where
3381    E: std::fmt::Debug,
3382{
3383    type Ok = T;
3384
3385    fn trace_err(self) -> Option<T> {
3386        match self {
3387            Ok(value) => Some(value),
3388            Err(error) => {
3389                tracing::error!("{:?}", error);
3390                None
3391            }
3392        }
3393    }
3394}