rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, CreateChannelResult, Database, MessageId,
   7        MoveChannelResult, ProjectId, RenameChannelResult, RoomId, ServerId,
   8        SetChannelVisibilityResult, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{info_span, instrument, Instrument};
  67use util::channel::RELEASE_CHANNEL_NAME;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74
  75lazy_static! {
  76    static ref METRIC_CONNECTIONS: IntGauge =
  77        register_int_gauge!("connections", "number of connections").unwrap();
  78    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  79        "shared_projects",
  80        "number of open projects with one or more guests"
  81    )
  82    .unwrap();
  83}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
  88struct Response<R> {
  89    peer: Arc<Peer>,
  90    receipt: Receipt<R>,
  91    responded: Arc<AtomicBool>,
  92}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
 102#[derive(Clone)]
 103struct Session {
 104    user_id: UserId,
 105    connection_id: ConnectionId,
 106    db: Arc<tokio::sync::Mutex<DbHandle>>,
 107    peer: Arc<Peer>,
 108    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 109    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 110    executor: Executor,
 111}
 112
 113impl Session {
 114    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.db.lock().await;
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        guard
 121    }
 122
 123    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 124        #[cfg(test)]
 125        tokio::task::yield_now().await;
 126        let guard = self.connection_pool.lock();
 127        ConnectionPoolGuard {
 128            guard,
 129            _not_send: PhantomData,
 130        }
 131    }
 132}
 133
 134impl fmt::Debug for Session {
 135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 136        f.debug_struct("Session")
 137            .field("user_id", &self.user_id)
 138            .field("connection_id", &self.connection_id)
 139            .finish()
 140    }
 141}
 142
 143struct DbHandle(Arc<Database>);
 144
 145impl Deref for DbHandle {
 146    type Target = Database;
 147
 148    fn deref(&self) -> &Self::Target {
 149        self.0.as_ref()
 150    }
 151}
 152
 153pub struct Server {
 154    id: parking_lot::Mutex<ServerId>,
 155    peer: Arc<Peer>,
 156    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 157    app_state: Arc<AppState>,
 158    executor: Executor,
 159    handlers: HashMap<TypeId, MessageHandler>,
 160    teardown: watch::Sender<()>,
 161}
 162
 163pub(crate) struct ConnectionPoolGuard<'a> {
 164    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 165    _not_send: PhantomData<Rc<()>>,
 166}
 167
 168#[derive(Serialize)]
 169pub struct ServerSnapshot<'a> {
 170    peer: &'a Peer,
 171    #[serde(serialize_with = "serialize_deref")]
 172    connection_pool: ConnectionPoolGuard<'a>,
 173}
 174
 175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 176where
 177    S: Serializer,
 178    T: Deref<Target = U>,
 179    U: Serialize,
 180{
 181    Serialize::serialize(value.deref(), serializer)
 182}
 183
 184impl Server {
    /// Creates a server identified by `id` and registers every RPC handler.
    ///
    /// - Handlers registered with `add_request_handler` must reply through their
    ///   `Response`; otherwise the request fails with "handler did not send a
    ///   response".
    /// - Handlers registered with `add_message_handler` receive only the payload
    ///   and have no response to send.
    ///
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if ServerId's inner value can exceed u32 — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Liveness, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects, worktrees and language servers.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests dispatched through the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            // Following, diff bases and private user info.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            // Acknowledgements.
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 287
    /// Schedules cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. After
    /// `CLEANUP_TIMEOUT`, the task asks the database (`stale_server_resource_ids`)
    /// for room and channel ids tied to stale servers. For each stale channel it
    /// clears stale channel-buffer collaborators and sends the refreshed
    /// collaborator list to the remaining connections.
    ///
    /// For each stale room, it clears stale participants and broadcasts the
    /// updated room, plus the channel if the room belongs to one. It then:
    /// - sends `CallCanceled` to users whose calls were canceled;
    /// - pushes refreshed contact entries for affected users to their accepted contacts;
    /// - deletes the LiveKit room once no participants remain.
    ///
    /// Finally it deletes the stale server records.
    ///
    /// NOTE(review): "stale" presumably means servers other than `server_id` in
    /// this environment — confirm against the `db` implementation.
    ///
    /// Every fallible step inside the task goes through `trace_err` (presumably
    /// logging) and is otherwise ignored.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but only awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from channel buffers and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply when clearing the room fails: nothing to
                        // notify, and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all of
                        // their accepted contacts. The pool is locked only after the
                        // database awaits for that user complete.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 439
 440    pub fn teardown(&self) {
 441        self.peer.teardown();
 442        self.connection_pool.lock().reset();
 443        let _ = self.teardown.send(());
 444    }
 445
 446    #[cfg(test)]
 447    pub fn reset(&self, id: ServerId) {
 448        self.teardown();
 449        *self.id.lock() = id;
 450        self.peer.reset(id.0 as u32);
 451    }
 452
 453    #[cfg(test)]
 454    pub fn id(&self) -> ServerId {
 455        *self.id.lock()
 456    }
 457
 458    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 459    where
 460        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 461        Fut: 'static + Send + Future<Output = Result<()>>,
 462        M: EnvelopedMessage,
 463    {
 464        let prev_handler = self.handlers.insert(
 465            TypeId::of::<M>(),
 466            Box::new(move |envelope, session| {
 467                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 468                let span = info_span!(
 469                    "handle message",
 470                    payload_type = envelope.payload_type_name()
 471                );
 472                span.in_scope(|| {
 473                    tracing::info!(
 474                        payload_type = envelope.payload_type_name(),
 475                        "message received"
 476                    );
 477                });
 478                let start_time = Instant::now();
 479                let future = (handler)(*envelope, session);
 480                async move {
 481                    let result = future.await;
 482                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 483                    match result {
 484                        Err(error) => {
 485                            tracing::error!(%error, ?duration_ms, "error handling message")
 486                        }
 487                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 488                    }
 489                }
 490                .instrument(span)
 491                .boxed()
 492            }),
 493        );
 494        if prev_handler.is_some() {
 495            panic!("registered a handler for the same message twice");
 496        }
 497        self
 498    }
 499
 500    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 501    where
 502        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 503        Fut: 'static + Send + Future<Output = Result<()>>,
 504        M: EnvelopedMessage,
 505    {
 506        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 507        self
 508    }
 509
 510    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 511    where
 512        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 513        Fut: Send + Future<Output = Result<()>>,
 514        M: RequestMessage,
 515    {
 516        let handler = Arc::new(handler);
 517        self.add_handler(move |envelope, session| {
 518            let receipt = envelope.receipt();
 519            let handler = handler.clone();
 520            async move {
 521                let peer = session.peer.clone();
 522                let responded = Arc::new(AtomicBool::default());
 523                let response = Response {
 524                    peer: peer.clone(),
 525                    responded: responded.clone(),
 526                    receipt,
 527                };
 528                match (handler)(envelope.payload, response, session).await {
 529                    Ok(()) => {
 530                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 531                            Ok(())
 532                        } else {
 533                            Err(anyhow!("handler did not send a response"))?
 534                        }
 535                    }
 536                    Err(error) => {
 537                        peer.respond_with_error(
 538                            receipt,
 539                            proto::Error {
 540                                message: error.to_string(),
 541                            },
 542                        )?;
 543                        Err(error)
 544                    }
 545                }
 546            }
 547        })
 548    }
 549
    /// Drives one client connection until it closes or the server tears down.
    ///
    /// In order, the returned future:
    /// 1. Registers `connection` with the peer and sends `Hello` carrying the new
    ///    connection id. The id is also reported through `send_connection_id`, if given.
    /// 2. If the user has never connected before, sends `ShowContacts` and marks
    ///    the user as having connected once.
    /// 3. Loads contacts, channels and channel invites concurrently, adds the
    ///    connection to the pool, and sends the initial contacts and channels
    ///    updates, then any pending incoming call.
    /// 4. Builds the per-connection `Session` and calls `update_user_contacts`.
    /// 5. Dispatches incoming messages to registered handlers until the I/O
    ///    future finishes, the incoming stream ends, or teardown is signalled.
    ///
    /// When the I/O future finishes or the stream ends, pending foreground
    /// handlers are dropped and `connection_lost` is awaited; its error is only
    /// logged. When teardown is signalled, the future returns `Ok(())` right away
    /// without calling `connection_lost`.
    ///
    /// # Errors
    ///
    /// Returns an error if any setup step before the message loop fails.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future starts, so a teardown sent later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // Scoped so the pool lock is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is acquired
            // *before* the next message is read, so reading pauses at the limit. Each
            // permit is released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown
                // first, then I/O, then finished handlers, then new messages.
                futures::select_biased! {
                    // Server teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the handler future is attached to it
                                // with `.instrument(span)`.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; the rest
                                // are polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 684
 685    pub async fn invite_code_redeemed(
 686        self: &Arc<Self>,
 687        inviter_id: UserId,
 688        invitee_id: UserId,
 689    ) -> Result<()> {
 690        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 691            if let Some(code) = &user.invite_code {
 692                let pool = self.connection_pool.lock();
 693                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 694                for connection_id in pool.user_connection_ids(inviter_id) {
 695                    self.peer.send(
 696                        connection_id,
 697                        proto::UpdateContacts {
 698                            contacts: vec![invitee_contact.clone()],
 699                            ..Default::default()
 700                        },
 701                    )?;
 702                    self.peer.send(
 703                        connection_id,
 704                        proto::UpdateInviteInfo {
 705                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 706                            count: user.invite_count as u32,
 707                        },
 708                    )?;
 709                }
 710            }
 711        }
 712        Ok(())
 713    }
 714
 715    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 716        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 717            if let Some(invite_code) = &user.invite_code {
 718                let pool = self.connection_pool.lock();
 719                for connection_id in pool.user_connection_ids(user_id) {
 720                    self.peer.send(
 721                        connection_id,
 722                        proto::UpdateInviteInfo {
 723                            url: format!(
 724                                "{}{}",
 725                                self.app_state.config.invite_link_prefix, invite_code
 726                            ),
 727                            count: user.invite_count as u32,
 728                        },
 729                    )?;
 730                }
 731            }
 732        }
 733        Ok(())
 734    }
 735
 736    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 737        ServerSnapshot {
 738            connection_pool: ConnectionPoolGuard {
 739                guard: self.connection_pool.lock(),
 740                _not_send: PhantomData,
 741            },
 742            peer: &self.peer,
 743        }
 744    }
 745}
 746
 747impl<'a> Deref for ConnectionPoolGuard<'a> {
 748    type Target = ConnectionPool;
 749
 750    fn deref(&self) -> &Self::Target {
 751        &*self.guard
 752    }
 753}
 754
 755impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 756    fn deref_mut(&mut self) -> &mut Self::Target {
 757        &mut *self.guard
 758    }
 759}
 760
 761impl<'a> Drop for ConnectionPoolGuard<'a> {
 762    fn drop(&mut self) {
 763        #[cfg(test)]
 764        self.check_invariants();
 765    }
 766}
 767
 768fn broadcast<F>(
 769    sender_id: Option<ConnectionId>,
 770    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 771    mut f: F,
 772) where
 773    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 774{
 775    for receiver_id in receiver_ids {
 776        if Some(receiver_id) != sender_id {
 777            if let Err(error) = f(receiver_id) {
 778                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 779            }
 780        }
 781    }
 782}
 783
lazy_static! {
    // Name of the request header carrying the client's RPC protocol version;
    // returned by `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 787
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client
/// speaks. `handle_websocket_request` compares it with `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
 789
 790impl Header for ProtocolVersion {
 791    fn name() -> &'static HeaderName {
 792        &ZED_PROTOCOL_VERSION
 793    }
 794
 795    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 796    where
 797        Self: Sized,
 798        I: Iterator<Item = &'i axum::http::HeaderValue>,
 799    {
 800        let version = values
 801            .next()
 802            .ok_or_else(axum::headers::Error::invalid)?
 803            .to_str()
 804            .map_err(|_| axum::headers::Error::invalid())?
 805            .parse()
 806            .map_err(|_| axum::headers::Error::invalid())?;
 807        Ok(Self(version))
 808    }
 809
 810    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 811        values.extend([self.0.to_string().parse().unwrap()]);
 812    }
 813}
 814
 815pub fn routes(server: Arc<Server>) -> Router<Body> {
 816    Router::new()
 817        .route("/rpc", get(handle_websocket_request))
 818        .layer(
 819            ServiceBuilder::new()
 820                .layer(Extension(server.app_state.clone()))
 821                .layer(middleware::from_fn(auth::validate_header)),
 822        )
 823        .route("/metrics", get(handle_metrics))
 824        .layer(Extension(server))
 825}
 826
 827pub async fn handle_websocket_request(
 828    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 829    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 830    Extension(server): Extension<Arc<Server>>,
 831    Extension(user): Extension<User>,
 832    ws: WebSocketUpgrade,
 833) -> axum::response::Response {
 834    if protocol_version != rpc::PROTOCOL_VERSION {
 835        return (
 836            StatusCode::UPGRADE_REQUIRED,
 837            "client must be upgraded".to_string(),
 838        )
 839            .into_response();
 840    }
 841    let socket_address = socket_address.to_string();
 842    ws.on_upgrade(move |socket| {
 843        use util::ResultExt;
 844        let socket = socket
 845            .map_ok(to_tungstenite_message)
 846            .err_into()
 847            .with(|message| async move { Ok(to_axum_message(message)) });
 848        let connection = Connection::new(Box::pin(socket));
 849        async move {
 850            server
 851                .handle_connection(connection, socket_address, user, None, Executor::Production)
 852                .await
 853                .log_err();
 854        }
 855    })
 856}
 857
 858pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 859    let connections = server
 860        .connection_pool
 861        .lock()
 862        .connections()
 863        .filter(|connection| !connection.admin)
 864        .count();
 865
 866    METRIC_CONNECTIONS.set(connections as _);
 867
 868    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 869    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 870
 871    let encoder = prometheus::TextEncoder::new();
 872    let metric_families = prometheus::gather();
 873    let encoded_metrics = encoder
 874        .encode_to_string(&metric_families)
 875        .map_err(|err| anyhow!("{}", err))?;
 876    Ok(encoded_metrics)
 877}
 878
 879#[instrument(err, skip(executor))]
 880async fn connection_lost(
 881    session: Session,
 882    mut teardown: watch::Receiver<()>,
 883    executor: Executor,
 884) -> Result<()> {
 885    session.peer.disconnect(session.connection_id);
 886    session
 887        .connection_pool()
 888        .await
 889        .remove_connection(session.connection_id)?;
 890
 891    session
 892        .db()
 893        .await
 894        .connection_lost(session.connection_id)
 895        .await
 896        .trace_err();
 897
 898    futures::select_biased! {
 899        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 900            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 901            leave_room_for_session(&session).await.trace_err();
 902            leave_channel_buffers_for_session(&session)
 903                .await
 904                .trace_err();
 905
 906            if !session
 907                .connection_pool()
 908                .await
 909                .is_user_online(session.user_id)
 910            {
 911                let db = session.db().await;
 912                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 913                    room_updated(&room, &session.peer);
 914                }
 915            }
 916
 917            update_user_contacts(session.user_id, &session).await?;
 918        }
 919        _ = teardown.changed().fuse() => {}
 920    }
 921
 922    Ok(())
 923}
 924
 925async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 926    response.send(proto::Ack {})?;
 927    Ok(())
 928}
 929
 930async fn create_room(
 931    _request: proto::CreateRoom,
 932    response: Response<proto::CreateRoom>,
 933    session: Session,
 934) -> Result<()> {
 935    let live_kit_room = nanoid::nanoid!(30);
 936
 937    let live_kit_connection_info = {
 938        let live_kit_room = live_kit_room.clone();
 939        let live_kit = session.live_kit_client.as_ref();
 940
 941        util::async_iife!({
 942            let live_kit = live_kit?;
 943
 944            let token = live_kit
 945                .room_token(&live_kit_room, &session.user_id.to_string())
 946                .trace_err()?;
 947
 948            Some(proto::LiveKitConnectionInfo {
 949                server_url: live_kit.url().into(),
 950                token,
 951            })
 952        })
 953    }
 954    .await;
 955
 956    let room = session
 957        .db()
 958        .await
 959        .create_room(
 960            session.user_id,
 961            session.connection_id,
 962            &live_kit_room,
 963            RELEASE_CHANNEL_NAME.as_str(),
 964        )
 965        .await?;
 966
 967    response.send(proto::CreateRoomResponse {
 968        room: Some(room.clone()),
 969        live_kit_connection_info,
 970    })?;
 971
 972    update_user_contacts(session.user_id, &session).await?;
 973    Ok(())
 974}
 975
 976async fn join_room(
 977    request: proto::JoinRoom,
 978    response: Response<proto::JoinRoom>,
 979    session: Session,
 980) -> Result<()> {
 981    let room_id = RoomId::from_proto(request.id);
 982
 983    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 984
 985    if let Some(channel_id) = channel_id {
 986        return join_channel_internal(channel_id, Box::new(response), session).await;
 987    }
 988
 989    let joined_room = {
 990        let room = session
 991            .db()
 992            .await
 993            .join_room(
 994                room_id,
 995                session.user_id,
 996                session.connection_id,
 997                RELEASE_CHANNEL_NAME.as_str(),
 998            )
 999            .await?;
1000        room_updated(&room.room, &session.peer);
1001        room.into_inner()
1002    };
1003
1004    for connection_id in session
1005        .connection_pool()
1006        .await
1007        .user_connection_ids(session.user_id)
1008    {
1009        session
1010            .peer
1011            .send(
1012                connection_id,
1013                proto::CallCanceled {
1014                    room_id: room_id.to_proto(),
1015                },
1016            )
1017            .trace_err();
1018    }
1019
1020    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1021        if let Some(token) = live_kit
1022            .room_token(
1023                &joined_room.room.live_kit_room,
1024                &session.user_id.to_string(),
1025            )
1026            .trace_err()
1027        {
1028            Some(proto::LiveKitConnectionInfo {
1029                server_url: live_kit.url().into(),
1030                token,
1031            })
1032        } else {
1033            None
1034        }
1035    } else {
1036        None
1037    };
1038
1039    response.send(proto::JoinRoomResponse {
1040        room: Some(joined_room.room),
1041        channel_id: None,
1042        live_kit_connection_info,
1043    })?;
1044
1045    update_user_contacts(session.user_id, &session).await?;
1046    Ok(())
1047}
1048
1049async fn rejoin_room(
1050    request: proto::RejoinRoom,
1051    response: Response<proto::RejoinRoom>,
1052    session: Session,
1053) -> Result<()> {
1054    let room;
1055    let channel_id;
1056    let channel_members;
1057    {
1058        let mut rejoined_room = session
1059            .db()
1060            .await
1061            .rejoin_room(request, session.user_id, session.connection_id)
1062            .await?;
1063
1064        response.send(proto::RejoinRoomResponse {
1065            room: Some(rejoined_room.room.clone()),
1066            reshared_projects: rejoined_room
1067                .reshared_projects
1068                .iter()
1069                .map(|project| proto::ResharedProject {
1070                    id: project.id.to_proto(),
1071                    collaborators: project
1072                        .collaborators
1073                        .iter()
1074                        .map(|collaborator| collaborator.to_proto())
1075                        .collect(),
1076                })
1077                .collect(),
1078            rejoined_projects: rejoined_room
1079                .rejoined_projects
1080                .iter()
1081                .map(|rejoined_project| proto::RejoinedProject {
1082                    id: rejoined_project.id.to_proto(),
1083                    worktrees: rejoined_project
1084                        .worktrees
1085                        .iter()
1086                        .map(|worktree| proto::WorktreeMetadata {
1087                            id: worktree.id,
1088                            root_name: worktree.root_name.clone(),
1089                            visible: worktree.visible,
1090                            abs_path: worktree.abs_path.clone(),
1091                        })
1092                        .collect(),
1093                    collaborators: rejoined_project
1094                        .collaborators
1095                        .iter()
1096                        .map(|collaborator| collaborator.to_proto())
1097                        .collect(),
1098                    language_servers: rejoined_project.language_servers.clone(),
1099                })
1100                .collect(),
1101        })?;
1102        room_updated(&rejoined_room.room, &session.peer);
1103
1104        for project in &rejoined_room.reshared_projects {
1105            for collaborator in &project.collaborators {
1106                session
1107                    .peer
1108                    .send(
1109                        collaborator.connection_id,
1110                        proto::UpdateProjectCollaborator {
1111                            project_id: project.id.to_proto(),
1112                            old_peer_id: Some(project.old_connection_id.into()),
1113                            new_peer_id: Some(session.connection_id.into()),
1114                        },
1115                    )
1116                    .trace_err();
1117            }
1118
1119            broadcast(
1120                Some(session.connection_id),
1121                project
1122                    .collaborators
1123                    .iter()
1124                    .map(|collaborator| collaborator.connection_id),
1125                |connection_id| {
1126                    session.peer.forward_send(
1127                        session.connection_id,
1128                        connection_id,
1129                        proto::UpdateProject {
1130                            project_id: project.id.to_proto(),
1131                            worktrees: project.worktrees.clone(),
1132                        },
1133                    )
1134                },
1135            );
1136        }
1137
1138        for project in &rejoined_room.rejoined_projects {
1139            for collaborator in &project.collaborators {
1140                session
1141                    .peer
1142                    .send(
1143                        collaborator.connection_id,
1144                        proto::UpdateProjectCollaborator {
1145                            project_id: project.id.to_proto(),
1146                            old_peer_id: Some(project.old_connection_id.into()),
1147                            new_peer_id: Some(session.connection_id.into()),
1148                        },
1149                    )
1150                    .trace_err();
1151            }
1152        }
1153
1154        for project in &mut rejoined_room.rejoined_projects {
1155            for worktree in mem::take(&mut project.worktrees) {
1156                #[cfg(any(test, feature = "test-support"))]
1157                const MAX_CHUNK_SIZE: usize = 2;
1158                #[cfg(not(any(test, feature = "test-support")))]
1159                const MAX_CHUNK_SIZE: usize = 256;
1160
1161                // Stream this worktree's entries.
1162                let message = proto::UpdateWorktree {
1163                    project_id: project.id.to_proto(),
1164                    worktree_id: worktree.id,
1165                    abs_path: worktree.abs_path.clone(),
1166                    root_name: worktree.root_name,
1167                    updated_entries: worktree.updated_entries,
1168                    removed_entries: worktree.removed_entries,
1169                    scan_id: worktree.scan_id,
1170                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1171                    updated_repositories: worktree.updated_repositories,
1172                    removed_repositories: worktree.removed_repositories,
1173                };
1174                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1175                    session.peer.send(session.connection_id, update.clone())?;
1176                }
1177
1178                // Stream this worktree's diagnostics.
1179                for summary in worktree.diagnostic_summaries {
1180                    session.peer.send(
1181                        session.connection_id,
1182                        proto::UpdateDiagnosticSummary {
1183                            project_id: project.id.to_proto(),
1184                            worktree_id: worktree.id,
1185                            summary: Some(summary),
1186                        },
1187                    )?;
1188                }
1189
1190                for settings_file in worktree.settings_files {
1191                    session.peer.send(
1192                        session.connection_id,
1193                        proto::UpdateWorktreeSettings {
1194                            project_id: project.id.to_proto(),
1195                            worktree_id: worktree.id,
1196                            path: settings_file.path,
1197                            content: Some(settings_file.content),
1198                        },
1199                    )?;
1200                }
1201            }
1202
1203            for language_server in &project.language_servers {
1204                session.peer.send(
1205                    session.connection_id,
1206                    proto::UpdateLanguageServer {
1207                        project_id: project.id.to_proto(),
1208                        language_server_id: language_server.id,
1209                        variant: Some(
1210                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1211                                proto::LspDiskBasedDiagnosticsUpdated {},
1212                            ),
1213                        ),
1214                    },
1215                )?;
1216            }
1217        }
1218
1219        let rejoined_room = rejoined_room.into_inner();
1220
1221        room = rejoined_room.room;
1222        channel_id = rejoined_room.channel_id;
1223        channel_members = rejoined_room.channel_members;
1224    }
1225
1226    if let Some(channel_id) = channel_id {
1227        channel_updated(
1228            channel_id,
1229            &room,
1230            &channel_members,
1231            &session.peer,
1232            &*session.connection_pool().await,
1233        );
1234    }
1235
1236    update_user_contacts(session.user_id, &session).await?;
1237    Ok(())
1238}
1239
1240async fn leave_room(
1241    _: proto::LeaveRoom,
1242    response: Response<proto::LeaveRoom>,
1243    session: Session,
1244) -> Result<()> {
1245    leave_room_for_session(&session).await?;
1246    response.send(proto::Ack {})?;
1247    Ok(())
1248}
1249
1250async fn call(
1251    request: proto::Call,
1252    response: Response<proto::Call>,
1253    session: Session,
1254) -> Result<()> {
1255    let room_id = RoomId::from_proto(request.room_id);
1256    let calling_user_id = session.user_id;
1257    let calling_connection_id = session.connection_id;
1258    let called_user_id = UserId::from_proto(request.called_user_id);
1259    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1260    if !session
1261        .db()
1262        .await
1263        .has_contact(calling_user_id, called_user_id)
1264        .await?
1265    {
1266        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1267    }
1268
1269    let incoming_call = {
1270        let (room, incoming_call) = &mut *session
1271            .db()
1272            .await
1273            .call(
1274                room_id,
1275                calling_user_id,
1276                calling_connection_id,
1277                called_user_id,
1278                initial_project_id,
1279            )
1280            .await?;
1281        room_updated(&room, &session.peer);
1282        mem::take(incoming_call)
1283    };
1284    update_user_contacts(called_user_id, &session).await?;
1285
1286    let mut calls = session
1287        .connection_pool()
1288        .await
1289        .user_connection_ids(called_user_id)
1290        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1291        .collect::<FuturesUnordered<_>>();
1292
1293    while let Some(call_response) = calls.next().await {
1294        match call_response.as_ref() {
1295            Ok(_) => {
1296                response.send(proto::Ack {})?;
1297                return Ok(());
1298            }
1299            Err(_) => {
1300                call_response.trace_err();
1301            }
1302        }
1303    }
1304
1305    {
1306        let room = session
1307            .db()
1308            .await
1309            .call_failed(room_id, called_user_id)
1310            .await?;
1311        room_updated(&room, &session.peer);
1312    }
1313    update_user_contacts(called_user_id, &session).await?;
1314
1315    Err(anyhow!("failed to ring user"))?
1316}
1317
1318async fn cancel_call(
1319    request: proto::CancelCall,
1320    response: Response<proto::CancelCall>,
1321    session: Session,
1322) -> Result<()> {
1323    let called_user_id = UserId::from_proto(request.called_user_id);
1324    let room_id = RoomId::from_proto(request.room_id);
1325    {
1326        let room = session
1327            .db()
1328            .await
1329            .cancel_call(room_id, session.connection_id, called_user_id)
1330            .await?;
1331        room_updated(&room, &session.peer);
1332    }
1333
1334    for connection_id in session
1335        .connection_pool()
1336        .await
1337        .user_connection_ids(called_user_id)
1338    {
1339        session
1340            .peer
1341            .send(
1342                connection_id,
1343                proto::CallCanceled {
1344                    room_id: room_id.to_proto(),
1345                },
1346            )
1347            .trace_err();
1348    }
1349    response.send(proto::Ack {})?;
1350
1351    update_user_contacts(called_user_id, &session).await?;
1352    Ok(())
1353}
1354
1355async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1356    let room_id = RoomId::from_proto(message.room_id);
1357    {
1358        let room = session
1359            .db()
1360            .await
1361            .decline_call(Some(room_id), session.user_id)
1362            .await?
1363            .ok_or_else(|| anyhow!("failed to decline call"))?;
1364        room_updated(&room, &session.peer);
1365    }
1366
1367    for connection_id in session
1368        .connection_pool()
1369        .await
1370        .user_connection_ids(session.user_id)
1371    {
1372        session
1373            .peer
1374            .send(
1375                connection_id,
1376                proto::CallCanceled {
1377                    room_id: room_id.to_proto(),
1378                },
1379            )
1380            .trace_err();
1381    }
1382    update_user_contacts(session.user_id, &session).await?;
1383    Ok(())
1384}
1385
1386async fn update_participant_location(
1387    request: proto::UpdateParticipantLocation,
1388    response: Response<proto::UpdateParticipantLocation>,
1389    session: Session,
1390) -> Result<()> {
1391    let room_id = RoomId::from_proto(request.room_id);
1392    let location = request
1393        .location
1394        .ok_or_else(|| anyhow!("invalid location"))?;
1395
1396    let db = session.db().await;
1397    let room = db
1398        .update_room_participant_location(room_id, session.connection_id, location)
1399        .await?;
1400
1401    room_updated(&room, &session.peer);
1402    response.send(proto::Ack {})?;
1403    Ok(())
1404}
1405
1406async fn share_project(
1407    request: proto::ShareProject,
1408    response: Response<proto::ShareProject>,
1409    session: Session,
1410) -> Result<()> {
1411    let (project_id, room) = &*session
1412        .db()
1413        .await
1414        .share_project(
1415            RoomId::from_proto(request.room_id),
1416            session.connection_id,
1417            &request.worktrees,
1418        )
1419        .await?;
1420    response.send(proto::ShareProjectResponse {
1421        project_id: project_id.to_proto(),
1422    })?;
1423    room_updated(&room, &session.peer);
1424
1425    Ok(())
1426}
1427
1428async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1429    let project_id = ProjectId::from_proto(message.project_id);
1430
1431    let (room, guest_connection_ids) = &*session
1432        .db()
1433        .await
1434        .unshare_project(project_id, session.connection_id)
1435        .await?;
1436
1437    broadcast(
1438        Some(session.connection_id),
1439        guest_connection_ids.iter().copied(),
1440        |conn_id| session.peer.send(conn_id, message.clone()),
1441    );
1442    room_updated(&room, &session.peer);
1443
1444    Ok(())
1445}
1446
1447async fn join_project(
1448    request: proto::JoinProject,
1449    response: Response<proto::JoinProject>,
1450    session: Session,
1451) -> Result<()> {
1452    let project_id = ProjectId::from_proto(request.project_id);
1453    let guest_user_id = session.user_id;
1454
1455    tracing::info!(%project_id, "join project");
1456
1457    let (project, replica_id) = &mut *session
1458        .db()
1459        .await
1460        .join_project(project_id, session.connection_id)
1461        .await?;
1462
1463    let collaborators = project
1464        .collaborators
1465        .iter()
1466        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1467        .map(|collaborator| collaborator.to_proto())
1468        .collect::<Vec<_>>();
1469
1470    let worktrees = project
1471        .worktrees
1472        .iter()
1473        .map(|(id, worktree)| proto::WorktreeMetadata {
1474            id: *id,
1475            root_name: worktree.root_name.clone(),
1476            visible: worktree.visible,
1477            abs_path: worktree.abs_path.clone(),
1478        })
1479        .collect::<Vec<_>>();
1480
1481    for collaborator in &collaborators {
1482        session
1483            .peer
1484            .send(
1485                collaborator.peer_id.unwrap().into(),
1486                proto::AddProjectCollaborator {
1487                    project_id: project_id.to_proto(),
1488                    collaborator: Some(proto::Collaborator {
1489                        peer_id: Some(session.connection_id.into()),
1490                        replica_id: replica_id.0 as u32,
1491                        user_id: guest_user_id.to_proto(),
1492                    }),
1493                },
1494            )
1495            .trace_err();
1496    }
1497
1498    // First, we send the metadata associated with each worktree.
1499    response.send(proto::JoinProjectResponse {
1500        worktrees: worktrees.clone(),
1501        replica_id: replica_id.0 as u32,
1502        collaborators: collaborators.clone(),
1503        language_servers: project.language_servers.clone(),
1504    })?;
1505
1506    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1507        #[cfg(any(test, feature = "test-support"))]
1508        const MAX_CHUNK_SIZE: usize = 2;
1509        #[cfg(not(any(test, feature = "test-support")))]
1510        const MAX_CHUNK_SIZE: usize = 256;
1511
1512        // Stream this worktree's entries.
1513        let message = proto::UpdateWorktree {
1514            project_id: project_id.to_proto(),
1515            worktree_id,
1516            abs_path: worktree.abs_path.clone(),
1517            root_name: worktree.root_name,
1518            updated_entries: worktree.entries,
1519            removed_entries: Default::default(),
1520            scan_id: worktree.scan_id,
1521            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1522            updated_repositories: worktree.repository_entries.into_values().collect(),
1523            removed_repositories: Default::default(),
1524        };
1525        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1526            session.peer.send(session.connection_id, update.clone())?;
1527        }
1528
1529        // Stream this worktree's diagnostics.
1530        for summary in worktree.diagnostic_summaries {
1531            session.peer.send(
1532                session.connection_id,
1533                proto::UpdateDiagnosticSummary {
1534                    project_id: project_id.to_proto(),
1535                    worktree_id: worktree.id,
1536                    summary: Some(summary),
1537                },
1538            )?;
1539        }
1540
1541        for settings_file in worktree.settings_files {
1542            session.peer.send(
1543                session.connection_id,
1544                proto::UpdateWorktreeSettings {
1545                    project_id: project_id.to_proto(),
1546                    worktree_id: worktree.id,
1547                    path: settings_file.path,
1548                    content: Some(settings_file.content),
1549                },
1550            )?;
1551        }
1552    }
1553
1554    for language_server in &project.language_servers {
1555        session.peer.send(
1556            session.connection_id,
1557            proto::UpdateLanguageServer {
1558                project_id: project_id.to_proto(),
1559                language_server_id: language_server.id,
1560                variant: Some(
1561                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1562                        proto::LspDiskBasedDiagnosticsUpdated {},
1563                    ),
1564                ),
1565            },
1566        )?;
1567    }
1568
1569    Ok(())
1570}
1571
1572async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1573    let sender_id = session.connection_id;
1574    let project_id = ProjectId::from_proto(request.project_id);
1575
1576    let (room, project) = &*session
1577        .db()
1578        .await
1579        .leave_project(project_id, sender_id)
1580        .await?;
1581    tracing::info!(
1582        %project_id,
1583        host_user_id = %project.host_user_id,
1584        host_connection_id = %project.host_connection_id,
1585        "leave project"
1586    );
1587
1588    project_left(&project, &session);
1589    room_updated(&room, &session.peer);
1590
1591    Ok(())
1592}
1593
1594async fn update_project(
1595    request: proto::UpdateProject,
1596    response: Response<proto::UpdateProject>,
1597    session: Session,
1598) -> Result<()> {
1599    let project_id = ProjectId::from_proto(request.project_id);
1600    let (room, guest_connection_ids) = &*session
1601        .db()
1602        .await
1603        .update_project(project_id, session.connection_id, &request.worktrees)
1604        .await?;
1605    broadcast(
1606        Some(session.connection_id),
1607        guest_connection_ids.iter().copied(),
1608        |connection_id| {
1609            session
1610                .peer
1611                .forward_send(session.connection_id, connection_id, request.clone())
1612        },
1613    );
1614    room_updated(&room, &session.peer);
1615    response.send(proto::Ack {})?;
1616
1617    Ok(())
1618}
1619
1620async fn update_worktree(
1621    request: proto::UpdateWorktree,
1622    response: Response<proto::UpdateWorktree>,
1623    session: Session,
1624) -> Result<()> {
1625    let guest_connection_ids = session
1626        .db()
1627        .await
1628        .update_worktree(&request, session.connection_id)
1629        .await?;
1630
1631    broadcast(
1632        Some(session.connection_id),
1633        guest_connection_ids.iter().copied(),
1634        |connection_id| {
1635            session
1636                .peer
1637                .forward_send(session.connection_id, connection_id, request.clone())
1638        },
1639    );
1640    response.send(proto::Ack {})?;
1641    Ok(())
1642}
1643
1644async fn update_diagnostic_summary(
1645    message: proto::UpdateDiagnosticSummary,
1646    session: Session,
1647) -> Result<()> {
1648    let guest_connection_ids = session
1649        .db()
1650        .await
1651        .update_diagnostic_summary(&message, session.connection_id)
1652        .await?;
1653
1654    broadcast(
1655        Some(session.connection_id),
1656        guest_connection_ids.iter().copied(),
1657        |connection_id| {
1658            session
1659                .peer
1660                .forward_send(session.connection_id, connection_id, message.clone())
1661        },
1662    );
1663
1664    Ok(())
1665}
1666
1667async fn update_worktree_settings(
1668    message: proto::UpdateWorktreeSettings,
1669    session: Session,
1670) -> Result<()> {
1671    let guest_connection_ids = session
1672        .db()
1673        .await
1674        .update_worktree_settings(&message, session.connection_id)
1675        .await?;
1676
1677    broadcast(
1678        Some(session.connection_id),
1679        guest_connection_ids.iter().copied(),
1680        |connection_id| {
1681            session
1682                .peer
1683                .forward_send(session.connection_id, connection_id, message.clone())
1684        },
1685    );
1686
1687    Ok(())
1688}
1689
1690async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1691    broadcast_project_message(request.project_id, request, session).await
1692}
1693
1694async fn start_language_server(
1695    request: proto::StartLanguageServer,
1696    session: Session,
1697) -> Result<()> {
1698    let guest_connection_ids = session
1699        .db()
1700        .await
1701        .start_language_server(&request, session.connection_id)
1702        .await?;
1703
1704    broadcast(
1705        Some(session.connection_id),
1706        guest_connection_ids.iter().copied(),
1707        |connection_id| {
1708            session
1709                .peer
1710                .forward_send(session.connection_id, connection_id, request.clone())
1711        },
1712    );
1713    Ok(())
1714}
1715
1716async fn update_language_server(
1717    request: proto::UpdateLanguageServer,
1718    session: Session,
1719) -> Result<()> {
1720    session.executor.record_backtrace();
1721    let project_id = ProjectId::from_proto(request.project_id);
1722    let project_connection_ids = session
1723        .db()
1724        .await
1725        .project_connection_ids(project_id, session.connection_id)
1726        .await?;
1727    broadcast(
1728        Some(session.connection_id),
1729        project_connection_ids.iter().copied(),
1730        |connection_id| {
1731            session
1732                .peer
1733                .forward_send(session.connection_id, connection_id, request.clone())
1734        },
1735    );
1736    Ok(())
1737}
1738
1739async fn forward_project_request<T>(
1740    request: T,
1741    response: Response<T>,
1742    session: Session,
1743) -> Result<()>
1744where
1745    T: EntityMessage + RequestMessage,
1746{
1747    session.executor.record_backtrace();
1748    let project_id = ProjectId::from_proto(request.remote_entity_id());
1749    let host_connection_id = {
1750        let collaborators = session
1751            .db()
1752            .await
1753            .project_collaborators(project_id, session.connection_id)
1754            .await?;
1755        collaborators
1756            .iter()
1757            .find(|collaborator| collaborator.is_host)
1758            .ok_or_else(|| anyhow!("host not found"))?
1759            .connection_id
1760    };
1761
1762    let payload = session
1763        .peer
1764        .forward_request(session.connection_id, host_connection_id, request)
1765        .await?;
1766
1767    response.send(payload)?;
1768    Ok(())
1769}
1770
1771async fn create_buffer_for_peer(
1772    request: proto::CreateBufferForPeer,
1773    session: Session,
1774) -> Result<()> {
1775    session.executor.record_backtrace();
1776    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1777    session
1778        .peer
1779        .forward_send(session.connection_id, peer_id.into(), request)?;
1780    Ok(())
1781}
1782
1783async fn update_buffer(
1784    request: proto::UpdateBuffer,
1785    response: Response<proto::UpdateBuffer>,
1786    session: Session,
1787) -> Result<()> {
1788    session.executor.record_backtrace();
1789    let project_id = ProjectId::from_proto(request.project_id);
1790    let mut guest_connection_ids;
1791    let mut host_connection_id = None;
1792    {
1793        let collaborators = session
1794            .db()
1795            .await
1796            .project_collaborators(project_id, session.connection_id)
1797            .await?;
1798        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1799        for collaborator in collaborators.iter() {
1800            if collaborator.is_host {
1801                host_connection_id = Some(collaborator.connection_id);
1802            } else {
1803                guest_connection_ids.push(collaborator.connection_id);
1804            }
1805        }
1806    }
1807    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1808
1809    session.executor.record_backtrace();
1810    broadcast(
1811        Some(session.connection_id),
1812        guest_connection_ids,
1813        |connection_id| {
1814            session
1815                .peer
1816                .forward_send(session.connection_id, connection_id, request.clone())
1817        },
1818    );
1819    if host_connection_id != session.connection_id {
1820        session
1821            .peer
1822            .forward_request(session.connection_id, host_connection_id, request.clone())
1823            .await?;
1824    }
1825
1826    response.send(proto::Ack {})?;
1827    Ok(())
1828}
1829
1830async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1831    let project_id = ProjectId::from_proto(request.project_id);
1832    let project_connection_ids = session
1833        .db()
1834        .await
1835        .project_connection_ids(project_id, session.connection_id)
1836        .await?;
1837
1838    broadcast(
1839        Some(session.connection_id),
1840        project_connection_ids.iter().copied(),
1841        |connection_id| {
1842            session
1843                .peer
1844                .forward_send(session.connection_id, connection_id, request.clone())
1845        },
1846    );
1847    Ok(())
1848}
1849
1850async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1851    let project_id = ProjectId::from_proto(request.project_id);
1852    let project_connection_ids = session
1853        .db()
1854        .await
1855        .project_connection_ids(project_id, session.connection_id)
1856        .await?;
1857    broadcast(
1858        Some(session.connection_id),
1859        project_connection_ids.iter().copied(),
1860        |connection_id| {
1861            session
1862                .peer
1863                .forward_send(session.connection_id, connection_id, request.clone())
1864        },
1865    );
1866    Ok(())
1867}
1868
1869async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1870    broadcast_project_message(request.project_id, request, session).await
1871}
1872
1873async fn broadcast_project_message<T: EnvelopedMessage>(
1874    project_id: u64,
1875    request: T,
1876    session: Session,
1877) -> Result<()> {
1878    let project_id = ProjectId::from_proto(project_id);
1879    let project_connection_ids = session
1880        .db()
1881        .await
1882        .project_connection_ids(project_id, session.connection_id)
1883        .await?;
1884    broadcast(
1885        Some(session.connection_id),
1886        project_connection_ids.iter().copied(),
1887        |connection_id| {
1888            session
1889                .peer
1890                .forward_send(session.connection_id, connection_id, request.clone())
1891        },
1892    );
1893    Ok(())
1894}
1895
1896async fn follow(
1897    request: proto::Follow,
1898    response: Response<proto::Follow>,
1899    session: Session,
1900) -> Result<()> {
1901    let room_id = RoomId::from_proto(request.room_id);
1902    let project_id = request.project_id.map(ProjectId::from_proto);
1903    let leader_id = request
1904        .leader_id
1905        .ok_or_else(|| anyhow!("invalid leader id"))?
1906        .into();
1907    let follower_id = session.connection_id;
1908
1909    session
1910        .db()
1911        .await
1912        .check_room_participants(room_id, leader_id, session.connection_id)
1913        .await?;
1914
1915    let response_payload = session
1916        .peer
1917        .forward_request(session.connection_id, leader_id, request)
1918        .await?;
1919    response.send(response_payload)?;
1920
1921    if let Some(project_id) = project_id {
1922        let room = session
1923            .db()
1924            .await
1925            .follow(room_id, project_id, leader_id, follower_id)
1926            .await?;
1927        room_updated(&room, &session.peer);
1928    }
1929
1930    Ok(())
1931}
1932
1933async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1934    let room_id = RoomId::from_proto(request.room_id);
1935    let project_id = request.project_id.map(ProjectId::from_proto);
1936    let leader_id = request
1937        .leader_id
1938        .ok_or_else(|| anyhow!("invalid leader id"))?
1939        .into();
1940    let follower_id = session.connection_id;
1941
1942    session
1943        .db()
1944        .await
1945        .check_room_participants(room_id, leader_id, session.connection_id)
1946        .await?;
1947
1948    session
1949        .peer
1950        .forward_send(session.connection_id, leader_id, request)?;
1951
1952    if let Some(project_id) = project_id {
1953        let room = session
1954            .db()
1955            .await
1956            .unfollow(room_id, project_id, leader_id, follower_id)
1957            .await?;
1958        room_updated(&room, &session.peer);
1959    }
1960
1961    Ok(())
1962}
1963
1964async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1965    let room_id = RoomId::from_proto(request.room_id);
1966    let database = session.db.lock().await;
1967
1968    let connection_ids = if let Some(project_id) = request.project_id {
1969        let project_id = ProjectId::from_proto(project_id);
1970        database
1971            .project_connection_ids(project_id, session.connection_id)
1972            .await?
1973    } else {
1974        database
1975            .room_connection_ids(room_id, session.connection_id)
1976            .await?
1977    };
1978
1979    // For now, don't send view update messages back to that view's current leader.
1980    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1981        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1982        _ => None,
1983    });
1984
1985    for follower_peer_id in request.follower_ids.iter().copied() {
1986        let follower_connection_id = follower_peer_id.into();
1987        if Some(follower_peer_id) != connection_id_to_omit
1988            && connection_ids.contains(&follower_connection_id)
1989        {
1990            session.peer.forward_send(
1991                session.connection_id,
1992                follower_connection_id,
1993                request.clone(),
1994            )?;
1995        }
1996    }
1997    Ok(())
1998}
1999
2000async fn get_users(
2001    request: proto::GetUsers,
2002    response: Response<proto::GetUsers>,
2003    session: Session,
2004) -> Result<()> {
2005    let user_ids = request
2006        .user_ids
2007        .into_iter()
2008        .map(UserId::from_proto)
2009        .collect();
2010    let users = session
2011        .db()
2012        .await
2013        .get_users_by_ids(user_ids)
2014        .await?
2015        .into_iter()
2016        .map(|user| proto::User {
2017            id: user.id.to_proto(),
2018            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2019            github_login: user.github_login,
2020        })
2021        .collect();
2022    response.send(proto::UsersResponse { users })?;
2023    Ok(())
2024}
2025
2026async fn fuzzy_search_users(
2027    request: proto::FuzzySearchUsers,
2028    response: Response<proto::FuzzySearchUsers>,
2029    session: Session,
2030) -> Result<()> {
2031    let query = request.query;
2032    let users = match query.len() {
2033        0 => vec![],
2034        1 | 2 => session
2035            .db()
2036            .await
2037            .get_user_by_github_login(&query)
2038            .await?
2039            .into_iter()
2040            .collect(),
2041        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2042    };
2043    let users = users
2044        .into_iter()
2045        .filter(|user| user.id != session.user_id)
2046        .map(|user| proto::User {
2047            id: user.id.to_proto(),
2048            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2049            github_login: user.github_login,
2050        })
2051        .collect();
2052    response.send(proto::UsersResponse { users })?;
2053    Ok(())
2054}
2055
2056async fn request_contact(
2057    request: proto::RequestContact,
2058    response: Response<proto::RequestContact>,
2059    session: Session,
2060) -> Result<()> {
2061    let requester_id = session.user_id;
2062    let responder_id = UserId::from_proto(request.responder_id);
2063    if requester_id == responder_id {
2064        return Err(anyhow!("cannot add yourself as a contact"))?;
2065    }
2066
2067    session
2068        .db()
2069        .await
2070        .send_contact_request(requester_id, responder_id)
2071        .await?;
2072
2073    // Update outgoing contact requests of requester
2074    let mut update = proto::UpdateContacts::default();
2075    update.outgoing_requests.push(responder_id.to_proto());
2076    for connection_id in session
2077        .connection_pool()
2078        .await
2079        .user_connection_ids(requester_id)
2080    {
2081        session.peer.send(connection_id, update.clone())?;
2082    }
2083
2084    // Update incoming contact requests of responder
2085    let mut update = proto::UpdateContacts::default();
2086    update
2087        .incoming_requests
2088        .push(proto::IncomingContactRequest {
2089            requester_id: requester_id.to_proto(),
2090            should_notify: true,
2091        });
2092    for connection_id in session
2093        .connection_pool()
2094        .await
2095        .user_connection_ids(responder_id)
2096    {
2097        session.peer.send(connection_id, update.clone())?;
2098    }
2099
2100    response.send(proto::Ack {})?;
2101    Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105    request: proto::RespondToContactRequest,
2106    response: Response<proto::RespondToContactRequest>,
2107    session: Session,
2108) -> Result<()> {
2109    let responder_id = session.user_id;
2110    let requester_id = UserId::from_proto(request.requester_id);
2111    let db = session.db().await;
2112    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113        db.dismiss_contact_notification(responder_id, requester_id)
2114            .await?;
2115    } else {
2116        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118        db.respond_to_contact_request(responder_id, requester_id, accept)
2119            .await?;
2120        let requester_busy = db.is_user_busy(requester_id).await?;
2121        let responder_busy = db.is_user_busy(responder_id).await?;
2122
2123        let pool = session.connection_pool().await;
2124        // Update responder with new contact
2125        let mut update = proto::UpdateContacts::default();
2126        if accept {
2127            update
2128                .contacts
2129                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2130        }
2131        update
2132            .remove_incoming_requests
2133            .push(requester_id.to_proto());
2134        for connection_id in pool.user_connection_ids(responder_id) {
2135            session.peer.send(connection_id, update.clone())?;
2136        }
2137
2138        // Update requester with new contact
2139        let mut update = proto::UpdateContacts::default();
2140        if accept {
2141            update
2142                .contacts
2143                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2144        }
2145        update
2146            .remove_outgoing_requests
2147            .push(responder_id.to_proto());
2148        for connection_id in pool.user_connection_ids(requester_id) {
2149            session.peer.send(connection_id, update.clone())?;
2150        }
2151    }
2152
2153    response.send(proto::Ack {})?;
2154    Ok(())
2155}
2156
2157async fn remove_contact(
2158    request: proto::RemoveContact,
2159    response: Response<proto::RemoveContact>,
2160    session: Session,
2161) -> Result<()> {
2162    let requester_id = session.user_id;
2163    let responder_id = UserId::from_proto(request.user_id);
2164    let db = session.db().await;
2165    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2166
2167    let pool = session.connection_pool().await;
2168    // Update outgoing contact requests of requester
2169    let mut update = proto::UpdateContacts::default();
2170    if contact_accepted {
2171        update.remove_contacts.push(responder_id.to_proto());
2172    } else {
2173        update
2174            .remove_outgoing_requests
2175            .push(responder_id.to_proto());
2176    }
2177    for connection_id in pool.user_connection_ids(requester_id) {
2178        session.peer.send(connection_id, update.clone())?;
2179    }
2180
2181    // Update incoming contact requests of responder
2182    let mut update = proto::UpdateContacts::default();
2183    if contact_accepted {
2184        update.remove_contacts.push(requester_id.to_proto());
2185    } else {
2186        update
2187            .remove_incoming_requests
2188            .push(requester_id.to_proto());
2189    }
2190    for connection_id in pool.user_connection_ids(responder_id) {
2191        session.peer.send(connection_id, update.clone())?;
2192    }
2193
2194    response.send(proto::Ack {})?;
2195    Ok(())
2196}
2197
2198async fn create_channel(
2199    request: proto::CreateChannel,
2200    response: Response<proto::CreateChannel>,
2201    session: Session,
2202) -> Result<()> {
2203    let db = session.db().await;
2204
2205    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2206    let CreateChannelResult {
2207        channel,
2208        participants_to_update,
2209    } = db
2210        .create_channel(&request.name, parent_id, session.user_id)
2211        .await?;
2212
2213    response.send(proto::CreateChannelResponse {
2214        channel: Some(channel.to_proto()),
2215        parent_id: request.parent_id,
2216    })?;
2217
2218    let connection_pool = session.connection_pool().await;
2219    for (user_id, channels) in participants_to_update {
2220        let update = build_channels_update(channels, vec![]);
2221        for connection_id in connection_pool.user_connection_ids(user_id) {
2222            if user_id == session.user_id {
2223                continue;
2224            }
2225            session.peer.send(connection_id, update.clone())?;
2226        }
2227    }
2228
2229    Ok(())
2230}
2231
2232async fn delete_channel(
2233    request: proto::DeleteChannel,
2234    response: Response<proto::DeleteChannel>,
2235    session: Session,
2236) -> Result<()> {
2237    let db = session.db().await;
2238
2239    let channel_id = request.channel_id;
2240    let (removed_channels, member_ids) = db
2241        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2242        .await?;
2243    response.send(proto::Ack {})?;
2244
2245    // Notify members of removed channels
2246    let mut update = proto::UpdateChannels::default();
2247    update
2248        .delete_channels
2249        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2250
2251    let connection_pool = session.connection_pool().await;
2252    for member_id in member_ids {
2253        for connection_id in connection_pool.user_connection_ids(member_id) {
2254            session.peer.send(connection_id, update.clone())?;
2255        }
2256    }
2257
2258    Ok(())
2259}
2260
2261async fn invite_channel_member(
2262    request: proto::InviteChannelMember,
2263    response: Response<proto::InviteChannelMember>,
2264    session: Session,
2265) -> Result<()> {
2266    let db = session.db().await;
2267    let channel_id = ChannelId::from_proto(request.channel_id);
2268    let invitee_id = UserId::from_proto(request.user_id);
2269    db.invite_channel_member(
2270        channel_id,
2271        invitee_id,
2272        session.user_id,
2273        request.role().into(),
2274    )
2275    .await?;
2276
2277    let channel = db.get_channel(channel_id, session.user_id).await?;
2278
2279    let mut update = proto::UpdateChannels::default();
2280    update.channel_invitations.push(proto::Channel {
2281        id: channel.id.to_proto(),
2282        visibility: channel.visibility.into(),
2283        name: channel.name,
2284        role: request.role().into(),
2285    });
2286    for connection_id in session
2287        .connection_pool()
2288        .await
2289        .user_connection_ids(invitee_id)
2290    {
2291        session.peer.send(connection_id, update.clone())?;
2292    }
2293
2294    response.send(proto::Ack {})?;
2295    Ok(())
2296}
2297
2298async fn remove_channel_member(
2299    request: proto::RemoveChannelMember,
2300    response: Response<proto::RemoveChannelMember>,
2301    session: Session,
2302) -> Result<()> {
2303    let db = session.db().await;
2304    let channel_id = ChannelId::from_proto(request.channel_id);
2305    let member_id = UserId::from_proto(request.user_id);
2306
2307    db.remove_channel_member(channel_id, member_id, session.user_id)
2308        .await?;
2309
2310    let mut update = proto::UpdateChannels::default();
2311    update.delete_channels.push(channel_id.to_proto());
2312
2313    for connection_id in session
2314        .connection_pool()
2315        .await
2316        .user_connection_ids(member_id)
2317    {
2318        session.peer.send(connection_id, update.clone())?;
2319    }
2320
2321    response.send(proto::Ack {})?;
2322    Ok(())
2323}
2324
2325async fn set_channel_visibility(
2326    request: proto::SetChannelVisibility,
2327    response: Response<proto::SetChannelVisibility>,
2328    session: Session,
2329) -> Result<()> {
2330    let db = session.db().await;
2331    let channel_id = ChannelId::from_proto(request.channel_id);
2332    let visibility = request.visibility().into();
2333
2334    let SetChannelVisibilityResult {
2335        participants_to_update,
2336        participants_to_remove,
2337    } = db
2338        .set_channel_visibility(channel_id, visibility, session.user_id)
2339        .await?;
2340
2341    let connection_pool = session.connection_pool().await;
2342    for (user_id, channels) in participants_to_update {
2343        let update = build_channels_update(channels, vec![]);
2344        for connection_id in connection_pool.user_connection_ids(user_id) {
2345            session.peer.send(connection_id, update.clone())?;
2346        }
2347    }
2348    for user_id in participants_to_remove {
2349        let update = proto::UpdateChannels {
2350            delete_channels: vec![channel_id.to_proto()],
2351            ..Default::default()
2352        };
2353        for connection_id in connection_pool.user_connection_ids(user_id) {
2354            session.peer.send(connection_id, update.clone())?;
2355        }
2356    }
2357
2358    response.send(proto::Ack {})?;
2359    Ok(())
2360}
2361
2362async fn set_channel_member_role(
2363    request: proto::SetChannelMemberRole,
2364    response: Response<proto::SetChannelMemberRole>,
2365    session: Session,
2366) -> Result<()> {
2367    let db = session.db().await;
2368    let channel_id = ChannelId::from_proto(request.channel_id);
2369    let member_id = UserId::from_proto(request.user_id);
2370    let channel_member = db
2371        .set_channel_member_role(
2372            channel_id,
2373            session.user_id,
2374            member_id,
2375            request.role().into(),
2376        )
2377        .await?;
2378
2379    let mut update = proto::UpdateChannels::default();
2380    if channel_member.accepted {
2381        let channels = db.get_channel_for_user(channel_id, member_id).await?;
2382        update = build_channels_update(channels, vec![]);
2383    } else {
2384        let channel = db.get_channel(channel_id, session.user_id).await?;
2385        update.channel_invitations.push(proto::Channel {
2386            id: channel_id.to_proto(),
2387            visibility: channel.visibility.into(),
2388            name: channel.name,
2389            role: request.role().into(),
2390        });
2391    }
2392
2393    for connection_id in session
2394        .connection_pool()
2395        .await
2396        .user_connection_ids(member_id)
2397    {
2398        session.peer.send(connection_id, update.clone())?;
2399    }
2400
2401    response.send(proto::Ack {})?;
2402    Ok(())
2403}
2404
2405async fn rename_channel(
2406    request: proto::RenameChannel,
2407    response: Response<proto::RenameChannel>,
2408    session: Session,
2409) -> Result<()> {
2410    let db = session.db().await;
2411    let channel_id = ChannelId::from_proto(request.channel_id);
2412    let RenameChannelResult {
2413        channel,
2414        participants_to_update,
2415    } = db
2416        .rename_channel(channel_id, session.user_id, &request.name)
2417        .await?;
2418
2419    response.send(proto::RenameChannelResponse {
2420        channel: Some(channel.to_proto()),
2421    })?;
2422
2423    let connection_pool = session.connection_pool().await;
2424    for (user_id, channel) in participants_to_update {
2425        for connection_id in connection_pool.user_connection_ids(user_id) {
2426            let update = proto::UpdateChannels {
2427                channels: vec![channel.to_proto()],
2428                ..Default::default()
2429            };
2430
2431            session.peer.send(connection_id, update.clone())?;
2432        }
2433    }
2434
2435    Ok(())
2436}
2437
2438// TODO: Implement in terms of symlinks
2439// Current behavior of this is more like 'Move root channel'
2440async fn link_channel(
2441    request: proto::LinkChannel,
2442    response: Response<proto::LinkChannel>,
2443    session: Session,
2444) -> Result<()> {
2445    let db = session.db().await;
2446    let channel_id = ChannelId::from_proto(request.channel_id);
2447    let to = ChannelId::from_proto(request.to);
2448
2449    let result = db
2450        .move_channel(channel_id, None, to, session.user_id)
2451        .await?;
2452    drop(db);
2453
2454    notify_channel_moved(result, session).await?;
2455
2456    response.send(Ack {})?;
2457
2458    Ok(())
2459}
2460
2461// TODO: Implement in terms of symlinks
2462async fn unlink_channel(
2463    _request: proto::UnlinkChannel,
2464    _response: Response<proto::UnlinkChannel>,
2465    _session: Session,
2466) -> Result<()> {
2467    Err(anyhow!("unimplemented").into())
2468}
2469
2470async fn move_channel(
2471    request: proto::MoveChannel,
2472    response: Response<proto::MoveChannel>,
2473    session: Session,
2474) -> Result<()> {
2475    let db = session.db().await;
2476    let channel_id = ChannelId::from_proto(request.channel_id);
2477    let from_parent = ChannelId::from_proto(request.from);
2478    let to = ChannelId::from_proto(request.to);
2479
2480    let result = db
2481        .move_channel(channel_id, Some(from_parent), to, session.user_id)
2482        .await?;
2483    drop(db);
2484
2485    notify_channel_moved(result, session).await?;
2486
2487    response.send(Ack {})?;
2488    Ok(())
2489}
2490
2491async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2492    let Some(MoveChannelResult {
2493        participants_to_remove,
2494        participants_to_update,
2495        moved_channels,
2496    }) = result
2497    else {
2498        return Ok(());
2499    };
2500    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2501
2502    let connection_pool = session.connection_pool().await;
2503    for (user_id, channels) in participants_to_update {
2504        let mut update = build_channels_update(channels, vec![]);
2505        update.delete_channels = moved_channels.clone();
2506        for connection_id in connection_pool.user_connection_ids(user_id) {
2507            session.peer.send(connection_id, update.clone())?;
2508        }
2509    }
2510
2511    for user_id in participants_to_remove {
2512        let update = proto::UpdateChannels {
2513            delete_channels: moved_channels.clone(),
2514            ..Default::default()
2515        };
2516        for connection_id in connection_pool.user_connection_ids(user_id) {
2517            session.peer.send(connection_id, update.clone())?;
2518        }
2519    }
2520    Ok(())
2521}
2522
2523async fn get_channel_members(
2524    request: proto::GetChannelMembers,
2525    response: Response<proto::GetChannelMembers>,
2526    session: Session,
2527) -> Result<()> {
2528    let db = session.db().await;
2529    let channel_id = ChannelId::from_proto(request.channel_id);
2530    let members = db
2531        .get_channel_participant_details(channel_id, session.user_id)
2532        .await?;
2533    response.send(proto::GetChannelMembersResponse { members })?;
2534    Ok(())
2535}
2536
2537async fn respond_to_channel_invite(
2538    request: proto::RespondToChannelInvite,
2539    response: Response<proto::RespondToChannelInvite>,
2540    session: Session,
2541) -> Result<()> {
2542    let db = session.db().await;
2543    let channel_id = ChannelId::from_proto(request.channel_id);
2544    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2545        .await?;
2546
2547    if request.accept {
2548        channel_membership_updated(db, channel_id, &session).await?;
2549    } else {
2550        let mut update = proto::UpdateChannels::default();
2551        update
2552            .remove_channel_invitations
2553            .push(channel_id.to_proto());
2554        session.peer.send(session.connection_id, update)?;
2555    }
2556    response.send(proto::Ack {})?;
2557
2558    Ok(())
2559}
2560
2561async fn channel_membership_updated(
2562    db: tokio::sync::MutexGuard<'_, DbHandle>,
2563    channel_id: ChannelId,
2564    session: &Session,
2565) -> Result<(), crate::Error> {
2566    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2567    let mut update = build_channels_update(result, vec![]);
2568    update
2569        .remove_channel_invitations
2570        .push(channel_id.to_proto());
2571
2572    session.peer.send(session.connection_id, update)?;
2573    Ok(())
2574}
2575
2576async fn join_channel(
2577    request: proto::JoinChannel,
2578    response: Response<proto::JoinChannel>,
2579    session: Session,
2580) -> Result<()> {
2581    let channel_id = ChannelId::from_proto(request.channel_id);
2582    join_channel_internal(channel_id, Box::new(response), session).await
2583}
2584
2585trait JoinChannelInternalResponse {
2586    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2587}
2588impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2589    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2590        Response::<proto::JoinChannel>::send(self, result)
2591    }
2592}
2593impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2594    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2595        Response::<proto::JoinRoom>::send(self, result)
2596    }
2597}
2598
/// Moves the session's user into the room attached to `channel_id`.
///
/// In order, this:
/// 1. leaves whatever room the session is currently in (`leave_room_for_session`);
/// 2. joins the channel's room in the database, passing `RELEASE_CHANNEL_NAME`;
/// 3. mints a LiveKit room token when a LiveKit client is configured (a token error is
///    logged via `trace_err` and the join continues without connection info);
/// 4. answers `response`, sends the caller their channel membership when the database
///    reports a joined channel, and pushes the new room state to its participants;
/// 5. broadcasts the room's participant list to the channel's members and refreshes the
///    user's entry for their contacts.
///
/// `response` is any `JoinChannelInternalResponse`, which lets both `JoinChannel` and
/// `JoinRoom` requests share this path.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard only lives inside this block (or is moved into
    // `channel_membership_updated`), so it is no longer held when the connection pool is
    // locked for `channel_updated` below.
    let joined_room = {
        // Must run before `session.db()` is taken here: `leave_room_for_session` acquires the
        // same database guard itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // LiveKit is optional; if minting the token fails, the error is logged and the
        // client is answered without connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // NOTE(review): `joined_channel` is presumably `Some` only when this join also made
        // the user a channel member; confirm against `Database::join_channel`.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2657
2658async fn join_channel_buffer(
2659    request: proto::JoinChannelBuffer,
2660    response: Response<proto::JoinChannelBuffer>,
2661    session: Session,
2662) -> Result<()> {
2663    let db = session.db().await;
2664    let channel_id = ChannelId::from_proto(request.channel_id);
2665
2666    let open_response = db
2667        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2668        .await?;
2669
2670    let collaborators = open_response.collaborators.clone();
2671    response.send(open_response)?;
2672
2673    let update = UpdateChannelBufferCollaborators {
2674        channel_id: channel_id.to_proto(),
2675        collaborators: collaborators.clone(),
2676    };
2677    channel_buffer_updated(
2678        session.connection_id,
2679        collaborators
2680            .iter()
2681            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2682        &update,
2683        &session.peer,
2684    );
2685
2686    Ok(())
2687}
2688
2689async fn update_channel_buffer(
2690    request: proto::UpdateChannelBuffer,
2691    session: Session,
2692) -> Result<()> {
2693    let db = session.db().await;
2694    let channel_id = ChannelId::from_proto(request.channel_id);
2695
2696    let (collaborators, non_collaborators, epoch, version) = db
2697        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2698        .await?;
2699
2700    channel_buffer_updated(
2701        session.connection_id,
2702        collaborators,
2703        &proto::UpdateChannelBuffer {
2704            channel_id: channel_id.to_proto(),
2705            operations: request.operations,
2706        },
2707        &session.peer,
2708    );
2709
2710    let pool = &*session.connection_pool().await;
2711
2712    broadcast(
2713        None,
2714        non_collaborators
2715            .iter()
2716            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2717        |peer_id| {
2718            session.peer.send(
2719                peer_id.into(),
2720                proto::UpdateChannels {
2721                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2722                        channel_id: channel_id.to_proto(),
2723                        epoch: epoch as u64,
2724                        version: version.clone(),
2725                    }],
2726                    ..Default::default()
2727                },
2728            )
2729        },
2730    );
2731
2732    Ok(())
2733}
2734
2735async fn rejoin_channel_buffers(
2736    request: proto::RejoinChannelBuffers,
2737    response: Response<proto::RejoinChannelBuffers>,
2738    session: Session,
2739) -> Result<()> {
2740    let db = session.db().await;
2741    let buffers = db
2742        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2743        .await?;
2744
2745    for rejoined_buffer in &buffers {
2746        let collaborators_to_notify = rejoined_buffer
2747            .buffer
2748            .collaborators
2749            .iter()
2750            .filter_map(|c| Some(c.peer_id?.into()));
2751        channel_buffer_updated(
2752            session.connection_id,
2753            collaborators_to_notify,
2754            &proto::UpdateChannelBufferCollaborators {
2755                channel_id: rejoined_buffer.buffer.channel_id,
2756                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2757            },
2758            &session.peer,
2759        );
2760    }
2761
2762    response.send(proto::RejoinChannelBuffersResponse {
2763        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2764    })?;
2765
2766    Ok(())
2767}
2768
2769async fn leave_channel_buffer(
2770    request: proto::LeaveChannelBuffer,
2771    response: Response<proto::LeaveChannelBuffer>,
2772    session: Session,
2773) -> Result<()> {
2774    let db = session.db().await;
2775    let channel_id = ChannelId::from_proto(request.channel_id);
2776
2777    let left_buffer = db
2778        .leave_channel_buffer(channel_id, session.connection_id)
2779        .await?;
2780
2781    response.send(Ack {})?;
2782
2783    channel_buffer_updated(
2784        session.connection_id,
2785        left_buffer.connections,
2786        &proto::UpdateChannelBufferCollaborators {
2787            channel_id: channel_id.to_proto(),
2788            collaborators: left_buffer.collaborators,
2789        },
2790        &session.peer,
2791    );
2792
2793    Ok(())
2794}
2795
2796fn channel_buffer_updated<T: EnvelopedMessage>(
2797    sender_id: ConnectionId,
2798    collaborators: impl IntoIterator<Item = ConnectionId>,
2799    message: &T,
2800    peer: &Peer,
2801) {
2802    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2803        peer.send(peer_id.into(), message.clone())
2804    });
2805}
2806
2807async fn send_channel_message(
2808    request: proto::SendChannelMessage,
2809    response: Response<proto::SendChannelMessage>,
2810    session: Session,
2811) -> Result<()> {
2812    // Validate the message body.
2813    let body = request.body.trim().to_string();
2814    if body.len() > MAX_MESSAGE_LEN {
2815        return Err(anyhow!("message is too long"))?;
2816    }
2817    if body.is_empty() {
2818        return Err(anyhow!("message can't be blank"))?;
2819    }
2820
2821    let timestamp = OffsetDateTime::now_utc();
2822    let nonce = request
2823        .nonce
2824        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2825
2826    let channel_id = ChannelId::from_proto(request.channel_id);
2827    let (message_id, connection_ids, non_participants) = session
2828        .db()
2829        .await
2830        .create_channel_message(
2831            channel_id,
2832            session.user_id,
2833            &body,
2834            timestamp,
2835            nonce.clone().into(),
2836        )
2837        .await?;
2838    let message = proto::ChannelMessage {
2839        sender_id: session.user_id.to_proto(),
2840        id: message_id.to_proto(),
2841        body,
2842        timestamp: timestamp.unix_timestamp() as u64,
2843        nonce: Some(nonce),
2844    };
2845    broadcast(Some(session.connection_id), connection_ids, |connection| {
2846        session.peer.send(
2847            connection,
2848            proto::ChannelMessageSent {
2849                channel_id: channel_id.to_proto(),
2850                message: Some(message.clone()),
2851            },
2852        )
2853    });
2854    response.send(proto::SendChannelMessageResponse {
2855        message: Some(message),
2856    })?;
2857
2858    let pool = &*session.connection_pool().await;
2859    broadcast(
2860        None,
2861        non_participants
2862            .iter()
2863            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2864        |peer_id| {
2865            session.peer.send(
2866                peer_id.into(),
2867                proto::UpdateChannels {
2868                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2869                        channel_id: channel_id.to_proto(),
2870                        message_id: message_id.to_proto(),
2871                    }],
2872                    ..Default::default()
2873                },
2874            )
2875        },
2876    );
2877
2878    Ok(())
2879}
2880
2881async fn remove_channel_message(
2882    request: proto::RemoveChannelMessage,
2883    response: Response<proto::RemoveChannelMessage>,
2884    session: Session,
2885) -> Result<()> {
2886    let channel_id = ChannelId::from_proto(request.channel_id);
2887    let message_id = MessageId::from_proto(request.message_id);
2888    let connection_ids = session
2889        .db()
2890        .await
2891        .remove_channel_message(channel_id, message_id, session.user_id)
2892        .await?;
2893    broadcast(Some(session.connection_id), connection_ids, |connection| {
2894        session.peer.send(connection, request.clone())
2895    });
2896    response.send(proto::Ack {})?;
2897    Ok(())
2898}
2899
2900async fn acknowledge_channel_message(
2901    request: proto::AckChannelMessage,
2902    session: Session,
2903) -> Result<()> {
2904    let channel_id = ChannelId::from_proto(request.channel_id);
2905    let message_id = MessageId::from_proto(request.message_id);
2906    session
2907        .db()
2908        .await
2909        .observe_channel_message(channel_id, session.user_id, message_id)
2910        .await?;
2911    Ok(())
2912}
2913
2914async fn acknowledge_buffer_version(
2915    request: proto::AckBufferOperation,
2916    session: Session,
2917) -> Result<()> {
2918    let buffer_id = BufferId::from_proto(request.buffer_id);
2919    session
2920        .db()
2921        .await
2922        .observe_buffer_version(
2923            buffer_id,
2924            session.user_id,
2925            request.epoch as i32,
2926            &request.version,
2927        )
2928        .await?;
2929    Ok(())
2930}
2931
2932async fn join_channel_chat(
2933    request: proto::JoinChannelChat,
2934    response: Response<proto::JoinChannelChat>,
2935    session: Session,
2936) -> Result<()> {
2937    let channel_id = ChannelId::from_proto(request.channel_id);
2938
2939    let db = session.db().await;
2940    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2941        .await?;
2942    let messages = db
2943        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2944        .await?;
2945    response.send(proto::JoinChannelChatResponse {
2946        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2947        messages,
2948    })?;
2949    Ok(())
2950}
2951
2952async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2953    let channel_id = ChannelId::from_proto(request.channel_id);
2954    session
2955        .db()
2956        .await
2957        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2958        .await?;
2959    Ok(())
2960}
2961
2962async fn get_channel_messages(
2963    request: proto::GetChannelMessages,
2964    response: Response<proto::GetChannelMessages>,
2965    session: Session,
2966) -> Result<()> {
2967    let channel_id = ChannelId::from_proto(request.channel_id);
2968    let messages = session
2969        .db()
2970        .await
2971        .get_channel_messages(
2972            channel_id,
2973            session.user_id,
2974            MESSAGE_COUNT_PER_PAGE,
2975            Some(MessageId::from_proto(request.before_message_id)),
2976        )
2977        .await?;
2978    response.send(proto::GetChannelMessagesResponse {
2979        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2980        messages,
2981    })?;
2982    Ok(())
2983}
2984
2985async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2986    let project_id = ProjectId::from_proto(request.project_id);
2987    let project_connection_ids = session
2988        .db()
2989        .await
2990        .project_connection_ids(project_id, session.connection_id)
2991        .await?;
2992    broadcast(
2993        Some(session.connection_id),
2994        project_connection_ids.iter().copied(),
2995        |connection_id| {
2996            session
2997                .peer
2998                .forward_send(session.connection_id, connection_id, request.clone())
2999        },
3000    );
3001    Ok(())
3002}
3003
3004async fn get_private_user_info(
3005    _request: proto::GetPrivateUserInfo,
3006    response: Response<proto::GetPrivateUserInfo>,
3007    session: Session,
3008) -> Result<()> {
3009    let db = session.db().await;
3010
3011    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3012    let user = db
3013        .get_user_by_id(session.user_id)
3014        .await?
3015        .ok_or_else(|| anyhow!("user not found"))?;
3016    let flags = db.get_user_flags(session.user_id).await?;
3017
3018    response.send(proto::GetPrivateUserInfoResponse {
3019        metrics_id,
3020        staff: user.admin,
3021        flags,
3022    })?;
3023    Ok(())
3024}
3025
3026fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3027    match message {
3028        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3029        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3030        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3031        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3032        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3033            code: frame.code.into(),
3034            reason: frame.reason,
3035        })),
3036    }
3037}
3038
3039fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3040    match message {
3041        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3042        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3043        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3044        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3045        AxumMessage::Close(frame) => {
3046            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3047                code: frame.code.into(),
3048                reason: frame.reason,
3049            }))
3050        }
3051    }
3052}
3053
3054fn build_channels_update(
3055    channels: ChannelsForUser,
3056    channel_invites: Vec<db::Channel>,
3057) -> proto::UpdateChannels {
3058    let mut update = proto::UpdateChannels::default();
3059
3060    for channel in channels.channels.channels {
3061        update.channels.push(proto::Channel {
3062            id: channel.id.to_proto(),
3063            name: channel.name,
3064            visibility: channel.visibility.into(),
3065            role: channel.role.into(),
3066        });
3067    }
3068
3069    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3070    update.unseen_channel_messages = channels.channel_messages;
3071    update.insert_edge = channels.channels.edges;
3072
3073    for (channel_id, participants) in channels.channel_participants {
3074        update
3075            .channel_participants
3076            .push(proto::ChannelParticipants {
3077                channel_id: channel_id.to_proto(),
3078                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3079            });
3080    }
3081
3082    for channel in channel_invites {
3083        update.channel_invitations.push(proto::Channel {
3084            id: channel.id.to_proto(),
3085            name: channel.name,
3086            visibility: channel.visibility.into(),
3087            role: channel.role.into(),
3088        });
3089    }
3090
3091    update
3092}
3093
3094fn build_initial_contacts_update(
3095    contacts: Vec<db::Contact>,
3096    pool: &ConnectionPool,
3097) -> proto::UpdateContacts {
3098    let mut update = proto::UpdateContacts::default();
3099
3100    for contact in contacts {
3101        match contact {
3102            db::Contact::Accepted {
3103                user_id,
3104                should_notify,
3105                busy,
3106            } => {
3107                update
3108                    .contacts
3109                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3110            }
3111            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3112            db::Contact::Incoming {
3113                user_id,
3114                should_notify,
3115            } => update
3116                .incoming_requests
3117                .push(proto::IncomingContactRequest {
3118                    requester_id: user_id.to_proto(),
3119                    should_notify,
3120                }),
3121        }
3122    }
3123
3124    update
3125}
3126
3127fn contact_for_user(
3128    user_id: UserId,
3129    should_notify: bool,
3130    busy: bool,
3131    pool: &ConnectionPool,
3132) -> proto::Contact {
3133    proto::Contact {
3134        user_id: user_id.to_proto(),
3135        online: pool.is_user_online(user_id),
3136        busy,
3137        should_notify,
3138    }
3139}
3140
3141fn room_updated(room: &proto::Room, peer: &Peer) {
3142    broadcast(
3143        None,
3144        room.participants
3145            .iter()
3146            .filter_map(|participant| Some(participant.peer_id?.into())),
3147        |peer_id| {
3148            peer.send(
3149                peer_id.into(),
3150                proto::RoomUpdated {
3151                    room: Some(room.clone()),
3152                },
3153            )
3154        },
3155    );
3156}
3157
3158fn channel_updated(
3159    channel_id: ChannelId,
3160    room: &proto::Room,
3161    channel_members: &[UserId],
3162    peer: &Peer,
3163    pool: &ConnectionPool,
3164) {
3165    let participants = room
3166        .participants
3167        .iter()
3168        .map(|p| p.user_id)
3169        .collect::<Vec<_>>();
3170
3171    broadcast(
3172        None,
3173        channel_members
3174            .iter()
3175            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3176        |peer_id| {
3177            peer.send(
3178                peer_id.into(),
3179                proto::UpdateChannels {
3180                    channel_participants: vec![proto::ChannelParticipants {
3181                        channel_id: channel_id.to_proto(),
3182                        participant_user_ids: participants.clone(),
3183                    }],
3184                    ..Default::default()
3185                },
3186            )
3187        },
3188    );
3189}
3190
3191async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3192    let db = session.db().await;
3193
3194    let contacts = db.get_contacts(user_id).await?;
3195    let busy = db.is_user_busy(user_id).await?;
3196
3197    let pool = session.connection_pool().await;
3198    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3199    for contact in contacts {
3200        if let db::Contact::Accepted {
3201            user_id: contact_user_id,
3202            ..
3203        } = contact
3204        {
3205            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3206                session
3207                    .peer
3208                    .send(
3209                        contact_conn_id,
3210                        proto::UpdateContacts {
3211                            contacts: vec![updated_contact.clone()],
3212                            remove_contacts: Default::default(),
3213                            incoming_requests: Default::default(),
3214                            remove_incoming_requests: Default::default(),
3215                            outgoing_requests: Default::default(),
3216                            remove_outgoing_requests: Default::default(),
3217                        },
3218                    )
3219                    .trace_err();
3220            }
3221        }
3222    }
3223    Ok(())
3224}
3225
3226async fn leave_room_for_session(session: &Session) -> Result<()> {
3227    let mut contacts_to_update = HashSet::default();
3228
3229    let room_id;
3230    let canceled_calls_to_user_ids;
3231    let live_kit_room;
3232    let delete_live_kit_room;
3233    let room;
3234    let channel_members;
3235    let channel_id;
3236
3237    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3238        contacts_to_update.insert(session.user_id);
3239
3240        for project in left_room.left_projects.values() {
3241            project_left(project, session);
3242        }
3243
3244        room_id = RoomId::from_proto(left_room.room.id);
3245        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3246        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3247        delete_live_kit_room = left_room.deleted;
3248        room = mem::take(&mut left_room.room);
3249        channel_members = mem::take(&mut left_room.channel_members);
3250        channel_id = left_room.channel_id;
3251
3252        room_updated(&room, &session.peer);
3253    } else {
3254        return Ok(());
3255    }
3256
3257    if let Some(channel_id) = channel_id {
3258        channel_updated(
3259            channel_id,
3260            &room,
3261            &channel_members,
3262            &session.peer,
3263            &*session.connection_pool().await,
3264        );
3265    }
3266
3267    {
3268        let pool = session.connection_pool().await;
3269        for canceled_user_id in canceled_calls_to_user_ids {
3270            for connection_id in pool.user_connection_ids(canceled_user_id) {
3271                session
3272                    .peer
3273                    .send(
3274                        connection_id,
3275                        proto::CallCanceled {
3276                            room_id: room_id.to_proto(),
3277                        },
3278                    )
3279                    .trace_err();
3280            }
3281            contacts_to_update.insert(canceled_user_id);
3282        }
3283    }
3284
3285    for contact_user_id in contacts_to_update {
3286        update_user_contacts(contact_user_id, &session).await?;
3287    }
3288
3289    if let Some(live_kit) = session.live_kit_client.as_ref() {
3290        live_kit
3291            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3292            .await
3293            .trace_err();
3294
3295        if delete_live_kit_room {
3296            live_kit.delete_room(live_kit_room).await.trace_err();
3297        }
3298    }
3299
3300    Ok(())
3301}
3302
3303async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3304    let left_channel_buffers = session
3305        .db()
3306        .await
3307        .leave_channel_buffers(session.connection_id)
3308        .await?;
3309
3310    for left_buffer in left_channel_buffers {
3311        channel_buffer_updated(
3312            session.connection_id,
3313            left_buffer.connections,
3314            &proto::UpdateChannelBufferCollaborators {
3315                channel_id: left_buffer.channel_id.to_proto(),
3316                collaborators: left_buffer.collaborators,
3317            },
3318            &session.peer,
3319        );
3320    }
3321
3322    Ok(())
3323}
3324
3325fn project_left(project: &db::LeftProject, session: &Session) {
3326    for connection_id in &project.connection_ids {
3327        if project.host_user_id == session.user_id {
3328            session
3329                .peer
3330                .send(
3331                    *connection_id,
3332                    proto::UnshareProject {
3333                        project_id: project.id.to_proto(),
3334                    },
3335                )
3336                .trace_err();
3337        } else {
3338            session
3339                .peer
3340                .send(
3341                    *connection_id,
3342                    proto::RemoveProjectCollaborator {
3343                        project_id: project.id.to_proto(),
3344                        peer_id: Some(session.connection_id.into()),
3345                    },
3346                )
3347                .trace_err();
3348        }
3349    }
3350}
3351
3352pub trait ResultExt {
3353    type Ok;
3354
3355    fn trace_err(self) -> Option<Self::Ok>;
3356}
3357
3358impl<T, E> ResultExt for Result<T, E>
3359where
3360    E: std::fmt::Debug,
3361{
3362    type Ok = T;
3363
3364    fn trace_err(self) -> Option<T> {
3365        match self {
3366            Ok(value) => Some(value),
3367            Err(error) => {
3368                tracing::error!("{:?}", error);
3369                None
3370            }
3371        }
3372    }
3373}