rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69use util::channel::RELEASE_CHANNEL_NAME;
  70
  71pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  72pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  73
  74const MESSAGE_COUNT_PER_PAGE: usize = 100;
  75const MAX_MESSAGE_LEN: usize = 1024;
  76const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  77
  78lazy_static! {
  79    static ref METRIC_CONNECTIONS: IntGauge =
  80        register_int_gauge!("connections", "number of connections").unwrap();
  81    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  82        "shared_projects",
  83        "number of open projects with one or more guests"
  84    )
  85    .unwrap();
  86}
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
   /// Handle a request handler uses to send its reply to the requester. `send` consumes
   /// the handle, so at most one reply can be sent through it.
   struct Response<R> {
       /// Peer used to deliver the reply.
       peer: Arc<Peer>,
       /// Receipt of the request being answered; passed to `Peer::respond`.
       receipt: Receipt<R>,
       /// Set to `true` by `send`; checked after the handler returns so that a handler that
       /// never replied can be reported as an error.
       responded: Arc<AtomicBool>,
   }
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
  /// Per-connection context; a clone is handed to every message handler invoked for that
  /// connection.
  #[derive(Clone)]
  struct Session {
      /// User that owns this connection.
      user_id: UserId,
      /// Connection the messages arrive on.
      connection_id: ConnectionId,
      /// Database handle behind an async mutex; obtain it via `Session::db`.
      /// `handle_connection` creates a fresh mutex per connection, so this serializes
      /// database access only among the handlers of a single connection.
      db: Arc<tokio::sync::Mutex<DbHandle>>,
      /// Shared RPC peer used to send messages and responses.
      peer: Arc<Peer>,
      /// Pool of live connections shared with the `Server`; lock it via
      /// `Session::connection_pool`.
      connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
      /// LiveKit API client, when one is configured in the app state.
      live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
      /// Clone of the executor passed to `handle_connection`.
      executor: Executor,
  }
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
  /// The RPC server: owns the RPC peer, the pool of live connections and the table of
  /// registered message handlers.
  pub struct Server {
      /// Server id; kept behind a mutex because the test-only `reset` replaces it.
      id: parking_lot::Mutex<ServerId>,
      /// RPC peer that owns the connections and sends messages over them.
      peer: Arc<Peer>,
      /// All live connections; shared with every `Session`.
      pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
      app_state: Arc<AppState>,
      executor: Executor,
      /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
      handlers: HashMap<TypeId, MessageHandler>,
      /// Teardown broadcast; every `handle_connection` loop subscribes to it.
      teardown: watch::Sender<()>,
  }
 165
  /// Lock guard for the shared `ConnectionPool`.
  ///
  /// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result, a future that
  /// holds it across an `.await` is not `Send`, which keeps the synchronous `parking_lot`
  /// lock from being held across suspension points in futures that must be `Send`.
  pub(crate) struct ConnectionPoolGuard<'a> {
      guard: parking_lot::MutexGuard<'a, ConnectionPool>,
      _not_send: PhantomData<Rc<()>>,
  }
 170
  /// Serializable snapshot of the server's peer and connection pool.
  ///
  /// The snapshot holds the pool lock for as long as it is alive. Field order is the
  /// serialization order produced by `#[derive(Serialize)]`.
  #[derive(Serialize)]
  pub struct ServerSnapshot<'a> {
      peer: &'a Peer,
      /// Serialized through its `Deref` target (presumably the `ConnectionPool` itself).
      #[serde(serialize_with = "serialize_deref")]
      connection_pool: ConnectionPoolGuard<'a>,
  }
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
  /// Creates a server with the given `id` and registers every RPC handler.
  ///
  /// The server starts with a fresh [`Peer`] built from `id.0 as u32`, an empty connection
  /// pool and a teardown channel. Request handlers (`add_request_handler`) must reply
  /// through their `Response`. Message handlers (`add_message_handler`) receive only the
  /// payload and send no reply.
  ///
  /// # Panics
  ///
  /// Panics if two handlers are registered for the same message type.
  pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
      let mut server = Self {
          id: parking_lot::Mutex::new(id),
          peer: Peer::new(id.0 as u32),
          app_state,
          executor,
          connection_pool: Default::default(),
          handlers: Default::default(),
          // Only the sender is kept; each connection subscribes in `handle_connection`.
          teardown: watch::channel(()).0,
      };

      // Registration order is irrelevant: handlers are looked up by message `TypeId`.
      server
          .add_request_handler(ping)
          .add_request_handler(create_room)
          .add_request_handler(join_room)
          .add_request_handler(rejoin_room)
          .add_request_handler(leave_room)
          .add_request_handler(call)
          .add_request_handler(cancel_call)
          .add_message_handler(decline_call)
          .add_request_handler(update_participant_location)
          .add_request_handler(share_project)
          .add_message_handler(unshare_project)
          .add_request_handler(join_project)
          .add_message_handler(leave_project)
          .add_request_handler(update_project)
          .add_request_handler(update_worktree)
          .add_message_handler(start_language_server)
          .add_message_handler(update_language_server)
          .add_message_handler(update_diagnostic_summary)
          .add_message_handler(update_worktree_settings)
          .add_message_handler(refresh_inlay_hints)
          // Requests relayed by the generic `forward_project_request` handler.
          .add_request_handler(forward_project_request::<proto::GetHover>)
          .add_request_handler(forward_project_request::<proto::GetDefinition>)
          .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
          .add_request_handler(forward_project_request::<proto::GetReferences>)
          .add_request_handler(forward_project_request::<proto::SearchProject>)
          .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
          .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
          .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
          .add_request_handler(forward_project_request::<proto::OpenBufferById>)
          .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
          .add_request_handler(forward_project_request::<proto::GetCompletions>)
          .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
          .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
          .add_request_handler(forward_project_request::<proto::GetCodeActions>)
          .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
          .add_request_handler(forward_project_request::<proto::PrepareRename>)
          .add_request_handler(forward_project_request::<proto::PerformRename>)
          .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
          .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
          .add_request_handler(forward_project_request::<proto::FormatBuffers>)
          .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
          .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
          .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
          .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
          .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
          .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
          .add_request_handler(forward_project_request::<proto::InlayHints>)
          .add_message_handler(create_buffer_for_peer)
          .add_request_handler(update_buffer)
          .add_message_handler(update_buffer_file)
          .add_message_handler(buffer_reloaded)
          .add_message_handler(buffer_saved)
          .add_request_handler(forward_project_request::<proto::SaveBuffer>)
          .add_request_handler(get_users)
          .add_request_handler(fuzzy_search_users)
          .add_request_handler(request_contact)
          .add_request_handler(remove_contact)
          .add_request_handler(respond_to_contact_request)
          .add_request_handler(create_channel)
          .add_request_handler(delete_channel)
          .add_request_handler(invite_channel_member)
          .add_request_handler(remove_channel_member)
          .add_request_handler(set_channel_member_role)
          .add_request_handler(set_channel_visibility)
          .add_request_handler(rename_channel)
          .add_request_handler(join_channel_buffer)
          .add_request_handler(leave_channel_buffer)
          .add_message_handler(update_channel_buffer)
          .add_request_handler(rejoin_channel_buffers)
          .add_request_handler(get_channel_members)
          .add_request_handler(respond_to_channel_invite)
          .add_request_handler(join_channel)
          .add_request_handler(join_channel_chat)
          .add_message_handler(leave_channel_chat)
          .add_request_handler(send_channel_message)
          .add_request_handler(remove_channel_message)
          .add_request_handler(get_channel_messages)
          .add_request_handler(get_channel_messages_by_id)
          .add_request_handler(get_notifications)
          .add_request_handler(mark_notification_as_read)
          .add_request_handler(move_channel)
          .add_request_handler(follow)
          .add_message_handler(unfollow)
          .add_message_handler(update_followers)
          .add_message_handler(update_diff_base)
          .add_request_handler(get_private_user_info)
          .add_message_handler(acknowledge_channel_message)
          .add_message_handler(acknowledge_buffer_version);

      Arc::new(server)
  }
 291
  /// Schedules cleanup of state that the database attributes to stale servers, then
  /// returns immediately with `Ok(())`.
  ///
  /// A detached task first waits [`CLEANUP_TIMEOUT`]. It then asks the database for the
  /// room and channel-buffer ids recorded as stale for this environment and server id,
  /// and processes them as follows:
  /// - For each channel buffer, it clears the stale collaborators and sends the refreshed
  ///   collaborator list to the buffer's remaining connections.
  /// - For each room, it clears the stale participants and broadcasts the room update
  ///   (plus the channel update for channel rooms). It then sends `CallCanceled` to users
  ///   whose calls were canceled and pushes a refreshed contact entry for every affected
  ///   user to their accepted contacts' connections. Finally, it deletes the LiveKit room
  ///   if no participants remain.
  /// - Last, it deletes the stale server records.
  ///
  /// Failures inside the task go through `trace_err` (presumably logged) and do not abort
  /// the remaining steps.
  pub async fn start(&self) -> Result<()> {
      let server_id = *self.id.lock();
      let app_state = self.app_state.clone();
      let peer = self.peer.clone();
      // The sleep future is created here but only awaited inside the spawned task.
      let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
      let pool = self.connection_pool.clone();
      let live_kit_client = self.app_state.live_kit_client.clone();

      let span = info_span!("start server");
      self.executor.spawn_detached(
          async move {
              tracing::info!("waiting for cleanup timeout");
              timeout.await;
              tracing::info!("cleanup timeout expired, retrieving stale rooms");
              // `trace_err` turns a failure into `None`, so a failed lookup skips straight
              // to deleting the stale servers below.
              if let Some((room_ids, channel_ids)) = app_state
                  .db
                  .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                  .await
                  .trace_err()
              {
                  tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                  tracing::info!(
                      stale_channel_buffer_count = channel_ids.len(),
                      "retrieved stale channel buffers"
                  );

                  // Drop stale collaborators from each channel buffer and tell the remaining
                  // connections about the new collaborator list.
                  for channel_id in channel_ids {
                      if let Some(refreshed_channel_buffer) = app_state
                          .db
                          .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                          .await
                          .trace_err()
                      {
                          for connection_id in refreshed_channel_buffer.connection_ids {
                              peer.send(
                                  connection_id,
                                  proto::UpdateChannelBufferCollaborators {
                                      channel_id: channel_id.to_proto(),
                                      collaborators: refreshed_channel_buffer
                                          .collaborators
                                          .clone(),
                                  },
                              )
                              .trace_err();
                          }
                      }
                  }

                  for room_id in room_ids {
                      // These defaults mean "nothing further to do" if clearing the room fails.
                      let mut contacts_to_update = HashSet::default();
                      let mut canceled_calls_to_user_ids = Vec::new();
                      let mut live_kit_room = String::new();
                      let mut delete_live_kit_room = false;

                      if let Some(mut refreshed_room) = app_state
                          .db
                          .clear_stale_room_participants(room_id, server_id)
                          .await
                          .trace_err()
                      {
                          tracing::info!(
                              room_id = room_id.0,
                              new_participant_count = refreshed_room.room.participants.len(),
                              "refreshed room"
                          );
                          room_updated(&refreshed_room.room, &peer);
                          if let Some(channel_id) = refreshed_room.channel_id {
                              channel_updated(
                                  channel_id,
                                  &refreshed_room.room,
                                  &refreshed_room.channel_members,
                                  &peer,
                                  &*pool.lock(),
                              );
                          }
                          // Removed participants and users whose calls were canceled both get a
                          // refreshed contact entry (their busy status is recomputed below).
                          contacts_to_update
                              .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                          contacts_to_update
                              .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                          canceled_calls_to_user_ids =
                              mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                          live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                          delete_live_kit_room = refreshed_room.room.participants.is_empty();
                      }

                      // Scoped so the synchronous pool lock is released before the `.await`s
                      // below.
                      {
                          let pool = pool.lock();
                          for canceled_user_id in canceled_calls_to_user_ids {
                              for connection_id in pool.user_connection_ids(canceled_user_id) {
                                  peer.send(
                                      connection_id,
                                      proto::CallCanceled {
                                          room_id: room_id.to_proto(),
                                      },
                                  )
                                  .trace_err();
                              }
                          }
                      }

                      // Push each affected user's refreshed contact entry to the connections of
                      // all their accepted contacts. Users whose lookups fail are skipped.
                      for user_id in contacts_to_update {
                          let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                          let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                          if let Some((busy, contacts)) = busy.zip(contacts) {
                              let pool = pool.lock();
                              let updated_contact = contact_for_user(user_id, busy, &pool);
                              for contact in contacts {
                                  if let db::Contact::Accepted {
                                      user_id: contact_user_id,
                                      ..
                                  } = contact
                                  {
                                      for contact_conn_id in
                                          pool.user_connection_ids(contact_user_id)
                                      {
                                          peer.send(
                                              contact_conn_id,
                                              proto::UpdateContacts {
                                                  contacts: vec![updated_contact.clone()],
                                                  remove_contacts: Default::default(),
                                                  incoming_requests: Default::default(),
                                                  remove_incoming_requests: Default::default(),
                                                  outgoing_requests: Default::default(),
                                                  remove_outgoing_requests: Default::default(),
                                              },
                                          )
                                          .trace_err();
                                      }
                                  }
                              }
                          }
                      }

                      // Delete the LiveKit room only if no participants remain after clearing.
                      if let Some(live_kit) = live_kit_client.as_ref() {
                          if delete_live_kit_room {
                              live_kit.delete_room(live_kit_room).await.trace_err();
                          }
                      }
                  }
              }

              app_state
                  .db
                  .delete_stale_servers(&app_state.config.zed_environment, server_id)
                  .await
                  .trace_err();
          }
          .instrument(span),
      );
      Ok(())
  }
 443
 444    pub fn teardown(&self) {
 445        self.peer.teardown();
 446        self.connection_pool.lock().reset();
 447        let _ = self.teardown.send(());
 448    }
 449
 450    #[cfg(test)]
 451    pub fn reset(&self, id: ServerId) {
 452        self.teardown();
 453        *self.id.lock() = id;
 454        self.peer.reset(id.0 as u32);
 455    }
 456
 457    #[cfg(test)]
 458    pub fn id(&self) -> ServerId {
 459        *self.id.lock()
 460    }
 461
 462    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 463    where
 464        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 465        Fut: 'static + Send + Future<Output = Result<()>>,
 466        M: EnvelopedMessage,
 467    {
 468        let prev_handler = self.handlers.insert(
 469            TypeId::of::<M>(),
 470            Box::new(move |envelope, session| {
 471                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 472                let span = info_span!(
 473                    "handle message",
 474                    payload_type = envelope.payload_type_name()
 475                );
 476                span.in_scope(|| {
 477                    tracing::info!(
 478                        payload_type = envelope.payload_type_name(),
 479                        "message received"
 480                    );
 481                });
 482                let start_time = Instant::now();
 483                let future = (handler)(*envelope, session);
 484                async move {
 485                    let result = future.await;
 486                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 487                    match result {
 488                        Err(error) => {
 489                            tracing::error!(%error, ?duration_ms, "error handling message")
 490                        }
 491                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 492                    }
 493                }
 494                .instrument(span)
 495                .boxed()
 496            }),
 497        );
 498        if prev_handler.is_some() {
 499            panic!("registered a handler for the same message twice");
 500        }
 501        self
 502    }
 503
 504    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 505    where
 506        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 507        Fut: 'static + Send + Future<Output = Result<()>>,
 508        M: EnvelopedMessage,
 509    {
 510        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 511        self
 512    }
 513
 514    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 515    where
 516        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 517        Fut: Send + Future<Output = Result<()>>,
 518        M: RequestMessage,
 519    {
 520        let handler = Arc::new(handler);
 521        self.add_handler(move |envelope, session| {
 522            let receipt = envelope.receipt();
 523            let handler = handler.clone();
 524            async move {
 525                let peer = session.peer.clone();
 526                let responded = Arc::new(AtomicBool::default());
 527                let response = Response {
 528                    peer: peer.clone(),
 529                    responded: responded.clone(),
 530                    receipt,
 531                };
 532                match (handler)(envelope.payload, response, session).await {
 533                    Ok(()) => {
 534                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 535                            Ok(())
 536                        } else {
 537                            Err(anyhow!("handler did not send a response"))?
 538                        }
 539                    }
 540                    Err(error) => {
 541                        peer.respond_with_error(
 542                            receipt,
 543                            proto::Error {
 544                                message: error.to_string(),
 545                            },
 546                        )?;
 547                        Err(error)
 548                    }
 549                }
 550            }
 551        })
 552    }
 553
  /// Drives a newly accepted connection for `user` until it closes or the server tears
  /// down.
  ///
  /// Returns a future, instrumented with a "handle connection" span, that performs these
  /// steps:
  /// 1. Registers the connection with the peer and sends `Hello`. If `send_connection_id`
  ///    is provided, the new `ConnectionId` is reported through it.
  /// 2. On the user's first connection (`!user.connected_once`), sends `ShowContacts` and
  ///    records that the user has connected.
  /// 3. Loads contacts, channels and channel invites concurrently. It then adds the
  ///    connection to the pool, sends the initial contacts and channels updates plus any
  ///    pending incoming call, builds the `Session` and updates the user's contacts.
  /// 4. Dispatches incoming messages to the registered handlers. Background messages are
  ///    spawned detached; foreground ones run concurrently inside this task. This
  ///    continues until I/O ends or the incoming stream closes, after which the user is
  ///    signed out via `connection_lost`.
  ///
  /// A server teardown makes the future resolve to `Ok(())` immediately, without calling
  /// `connection_lost`.
  ///
  /// # Errors
  ///
  /// Errors during setup (steps 1–3) are returned. Errors from handlers and from signing
  /// out are only logged.
  pub fn handle_connection(
      self: &Arc<Self>,
      connection: Connection,
      address: String,
      user: User,
      mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
      executor: Executor,
  ) -> impl Future<Output = Result<()>> {
      let this = self.clone();
      let user_id = user.id;
      let login = user.github_login;
      let span = info_span!("handle connection", %user_id, %login, %address);
      // Subscribed eagerly, before the returned future is first polled — presumably so a
      // teardown issued in between is still observed.
      let mut teardown = self.teardown.subscribe();
      async move {
          let (connection_id, handle_io, mut incoming_rx) = this
              .peer
              .add_connection(connection, {
                  let executor = executor.clone();
                  move |duration| executor.sleep(duration)
              });

          tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
          this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
          tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

          if let Some(send_connection_id) = send_connection_id.take() {
              // The receiver may already be gone; that is not an error here.
              let _ = send_connection_id.send(connection_id);
          }

          if !user.connected_once {
              this.peer.send(connection_id, proto::ShowContacts {})?;
              this.app_state.db.set_user_connected_once(user_id, true).await?;
          }

          let (contacts, channels_for_user, channel_invites) = future::try_join3(
              this.app_state.db.get_contacts(user_id),
              this.app_state.db.get_channels_for_user(user_id),
              this.app_state.db.get_channel_invites_for_user(user_id),
          ).await?;

          // NOTE(review): every `?` from here until the message loop returns early without
          // calling `connection_lost`. That appears to leave `connection_id` registered in
          // the pool — confirm, and consider removing it on these error paths.
          {
              let mut pool = this.connection_pool.lock();
              pool.add_connection(connection_id, user_id, user.admin);
              this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
              this.peer.send(connection_id, build_channels_update(
                  channels_for_user,
                  channel_invites
              ))?;
          }

          if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
              this.peer.send(connection_id, incoming_call)?;
          }

          let session = Session {
              user_id,
              connection_id,
              db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
              peer: this.peer.clone(),
              connection_pool: this.connection_pool.clone(),
              live_kit_client: this.app_state.live_kit_client.clone(),
              executor: executor.clone(),
          };
          update_user_contacts(user_id, &session).await?;

          let handle_io = handle_io.fuse();
          futures::pin_mut!(handle_io);

          // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
          // This prevents deadlocks when, for example, client A sends a request to client B
          // while client B sends a request to client A. If each client stopped processing
          // further messages until its own request completed, neither would get a chance to
          // respond to the other's request, and both would deadlock.
          //
          // This arrangement processes earlier messages first where possible, but falls back
          // to processing later messages so that the connection keeps making progress.
          let mut foreground_message_handlers = FuturesUnordered::new();
          // At most 256 handlers are in flight for this connection. A permit is taken before
          // reading each message and released when its handler finishes, so reading pauses
          // while all permits are in use.
          let concurrent_handlers = Arc::new(Semaphore::new(256));
          loop {
              let next_message = async {
                  let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                  let message = incoming_rx.next().await;
                  (permit, message)
              }.fuse();
              futures::pin_mut!(next_message);
              // Branches are polled in order: teardown, I/O, progress on in-flight foreground
              // handlers, then the next incoming message.
              futures::select_biased! {
                  _ = teardown.changed().fuse() => return Ok(()),
                  result = handle_io => {
                      if let Err(error) = result {
                          tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                      }
                      break;
                  }
                  _ = foreground_message_handlers.next() => {}
                  next_message = next_message => {
                      let (permit, message) = next_message;
                      if let Some(message) = message {
                          let type_name = message.payload_type_name();
                          let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                          let span_enter = span.enter();
                          if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                              let is_background = message.is_background();
                              let handle_message = (handler)(message, session.clone());
                              // Leave the span before it is moved into `instrument` below.
                              drop(span_enter);

                              // The permit is released only once the handler has finished.
                              let handle_message = async move {
                                  handle_message.await;
                                  drop(permit);
                              }.instrument(span);
                              if is_background {
                                  executor.spawn_detached(handle_message);
                              } else {
                                  foreground_message_handlers.push(handle_message);
                              }
                          } else {
                              tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                          }
                      } else {
                          tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                          break;
                      }
                  }
              }
          }

          // Dropping the set cancels any foreground handlers still in flight.
          drop(foreground_message_handlers);
          tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
          if let Err(error) = connection_lost(session, teardown, executor).await {
              tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
          }

          Ok(())
      }.instrument(span)
  }
 688
 689    pub async fn invite_code_redeemed(
 690        self: &Arc<Self>,
 691        inviter_id: UserId,
 692        invitee_id: UserId,
 693    ) -> Result<()> {
 694        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 695            if let Some(code) = &user.invite_code {
 696                let pool = self.connection_pool.lock();
 697                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 698                for connection_id in pool.user_connection_ids(inviter_id) {
 699                    self.peer.send(
 700                        connection_id,
 701                        proto::UpdateContacts {
 702                            contacts: vec![invitee_contact.clone()],
 703                            ..Default::default()
 704                        },
 705                    )?;
 706                    self.peer.send(
 707                        connection_id,
 708                        proto::UpdateInviteInfo {
 709                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 710                            count: user.invite_count as u32,
 711                        },
 712                    )?;
 713                }
 714            }
 715        }
 716        Ok(())
 717    }
 718
 719    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 720        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 721            if let Some(invite_code) = &user.invite_code {
 722                let pool = self.connection_pool.lock();
 723                for connection_id in pool.user_connection_ids(user_id) {
 724                    self.peer.send(
 725                        connection_id,
 726                        proto::UpdateInviteInfo {
 727                            url: format!(
 728                                "{}{}",
 729                                self.app_state.config.invite_link_prefix, invite_code
 730                            ),
 731                            count: user.invite_count as u32,
 732                        },
 733                    )?;
 734                }
 735            }
 736        }
 737        Ok(())
 738    }
 739
 740    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 741        ServerSnapshot {
 742            connection_pool: ConnectionPoolGuard {
 743                guard: self.connection_pool.lock(),
 744                _not_send: PhantomData,
 745            },
 746            peer: &self.peer,
 747        }
 748    }
 749}
 750
 751impl<'a> Deref for ConnectionPoolGuard<'a> {
 752    type Target = ConnectionPool;
 753
 754    fn deref(&self) -> &Self::Target {
 755        &*self.guard
 756    }
 757}
 758
 759impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 760    fn deref_mut(&mut self) -> &mut Self::Target {
 761        &mut *self.guard
 762    }
 763}
 764
/// In test builds, validates the pool's invariants each time the guard is released.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `Drop::drop` runs before the `guard` field itself is dropped, so the pool
        // lock is still held while the invariants are checked. In non-test builds
        // this body compiles to nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
 771
 772fn broadcast<F>(
 773    sender_id: Option<ConnectionId>,
 774    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 775    mut f: F,
 776) where
 777    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 778{
 779    for receiver_id in receiver_ids {
 780        if Some(receiver_id) != sender_id {
 781            if let Err(error) = f(receiver_id) {
 782                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 783            }
 784        }
 785    }
 786}
 787
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed `x-zed-protocol-version` header holding the client's RPC protocol version.
pub struct ProtocolVersion(u32);
 793
 794impl Header for ProtocolVersion {
 795    fn name() -> &'static HeaderName {
 796        &ZED_PROTOCOL_VERSION
 797    }
 798
 799    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 800    where
 801        Self: Sized,
 802        I: Iterator<Item = &'i axum::http::HeaderValue>,
 803    {
 804        let version = values
 805            .next()
 806            .ok_or_else(axum::headers::Error::invalid)?
 807            .to_str()
 808            .map_err(|_| axum::headers::Error::invalid())?
 809            .parse()
 810            .map_err(|_| axum::headers::Error::invalid())?;
 811        Ok(Self(version))
 812    }
 813
 814    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 815        values.extend([self.0.to_string().parse().unwrap()]);
 816    }
 817}
 818
/// Builds the router serving the RPC websocket and the metrics endpoint.
///
/// Layer order matters here: `Router::layer` only wraps routes registered
/// before the call. So the `AppState` extension and the
/// `auth::validate_header` middleware apply to `GET /rpc` but not to
/// `GET /metrics`. The final `Extension(server)` layer is added after both
/// routes and makes the `Server` available to both handlers.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 830
/// Upgrades an `/rpc` request to a websocket and serves an RPC connection on it.
///
/// Responds with `426 Upgrade Required` ("client must be upgraded") when the
/// client's `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION`. Otherwise the websocket is adapted to
/// tungstenite messages, wrapped in a [`Connection`], and handed to
/// `Server::handle_connection` with the production executor. An error from that
/// call is only logged (`log_err`), never returned to the client.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket: incoming axum messages are mapped to
        // tungstenite messages, and outgoing tungstenite messages are mapped back
        // to axum messages, giving `Connection::new` the transport it expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 861
 862pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 863    let connections = server
 864        .connection_pool
 865        .lock()
 866        .connections()
 867        .filter(|connection| !connection.admin)
 868        .count();
 869
 870    METRIC_CONNECTIONS.set(connections as _);
 871
 872    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 873    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 874
 875    let encoder = prometheus::TextEncoder::new();
 876    let metric_families = prometheus::gather();
 877    let encoded_metrics = encoder
 878        .encode_to_string(&metric_families)
 879        .map_err(|err| anyhow!("{}", err))?;
 880    Ok(encoded_metrics)
 881}
 882
/// Tears down server-side state after a connection drops.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here is returned);
/// - records the lost connection in the database (errors are only traced).
///
/// Then it waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses. The user is taken out of their room and
///   channel buffers. If they have no other online connection, any pending call
///   to them is declined and that room is re-broadcast. Finally their contacts
///   are refreshed.
/// - `teardown` changes. Nothing more is done.
///
/// `select_biased!` polls the timeout branch first when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // NOTE(review): this uses `log::info!` while the rest of the file logs via
            // `tracing`; confirm whether this message should carry the span context.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 928
 929async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 930    response.send(proto::Ack {})?;
 931    Ok(())
 932}
 933
 934async fn create_room(
 935    _request: proto::CreateRoom,
 936    response: Response<proto::CreateRoom>,
 937    session: Session,
 938) -> Result<()> {
 939    let live_kit_room = nanoid::nanoid!(30);
 940
 941    let live_kit_connection_info = {
 942        let live_kit_room = live_kit_room.clone();
 943        let live_kit = session.live_kit_client.as_ref();
 944
 945        util::async_iife!({
 946            let live_kit = live_kit?;
 947
 948            let token = live_kit
 949                .room_token(&live_kit_room, &session.user_id.to_string())
 950                .trace_err()?;
 951
 952            Some(proto::LiveKitConnectionInfo {
 953                server_url: live_kit.url().into(),
 954                token,
 955                can_publish: true,
 956            })
 957        })
 958    }
 959    .await;
 960
 961    let room = session
 962        .db()
 963        .await
 964        .create_room(
 965            session.user_id,
 966            session.connection_id,
 967            &live_kit_room,
 968            RELEASE_CHANNEL_NAME.as_str(),
 969        )
 970        .await?;
 971
 972    response.send(proto::CreateRoomResponse {
 973        room: Some(room.clone()),
 974        live_kit_connection_info,
 975    })?;
 976
 977    update_user_contacts(session.user_id, &session).await?;
 978    Ok(())
 979}
 980
 981async fn join_room(
 982    request: proto::JoinRoom,
 983    response: Response<proto::JoinRoom>,
 984    session: Session,
 985) -> Result<()> {
 986    let room_id = RoomId::from_proto(request.id);
 987
 988    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 989
 990    if let Some(channel_id) = channel_id {
 991        return join_channel_internal(channel_id, Box::new(response), session).await;
 992    }
 993
 994    let joined_room = {
 995        let room = session
 996            .db()
 997            .await
 998            .join_room(
 999                room_id,
1000                session.user_id,
1001                session.connection_id,
1002                RELEASE_CHANNEL_NAME.as_str(),
1003            )
1004            .await?;
1005        room_updated(&room.room, &session.peer);
1006        room.into_inner()
1007    };
1008
1009    for connection_id in session
1010        .connection_pool()
1011        .await
1012        .user_connection_ids(session.user_id)
1013    {
1014        session
1015            .peer
1016            .send(
1017                connection_id,
1018                proto::CallCanceled {
1019                    room_id: room_id.to_proto(),
1020                },
1021            )
1022            .trace_err();
1023    }
1024
1025    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1026        if let Some(token) = live_kit
1027            .room_token(
1028                &joined_room.room.live_kit_room,
1029                &session.user_id.to_string(),
1030            )
1031            .trace_err()
1032        {
1033            Some(proto::LiveKitConnectionInfo {
1034                server_url: live_kit.url().into(),
1035                token,
1036                can_publish: true,
1037            })
1038        } else {
1039            None
1040        }
1041    } else {
1042        None
1043    };
1044
1045    response.send(proto::JoinRoomResponse {
1046        room: Some(joined_room.room),
1047        channel_id: None,
1048        live_kit_connection_info,
1049    })?;
1050
1051    update_user_contacts(session.user_id, &session).await?;
1052    Ok(())
1053}
1054
1055async fn rejoin_room(
1056    request: proto::RejoinRoom,
1057    response: Response<proto::RejoinRoom>,
1058    session: Session,
1059) -> Result<()> {
1060    let room;
1061    let channel_id;
1062    let channel_members;
1063    {
1064        let mut rejoined_room = session
1065            .db()
1066            .await
1067            .rejoin_room(request, session.user_id, session.connection_id)
1068            .await?;
1069
1070        response.send(proto::RejoinRoomResponse {
1071            room: Some(rejoined_room.room.clone()),
1072            reshared_projects: rejoined_room
1073                .reshared_projects
1074                .iter()
1075                .map(|project| proto::ResharedProject {
1076                    id: project.id.to_proto(),
1077                    collaborators: project
1078                        .collaborators
1079                        .iter()
1080                        .map(|collaborator| collaborator.to_proto())
1081                        .collect(),
1082                })
1083                .collect(),
1084            rejoined_projects: rejoined_room
1085                .rejoined_projects
1086                .iter()
1087                .map(|rejoined_project| proto::RejoinedProject {
1088                    id: rejoined_project.id.to_proto(),
1089                    worktrees: rejoined_project
1090                        .worktrees
1091                        .iter()
1092                        .map(|worktree| proto::WorktreeMetadata {
1093                            id: worktree.id,
1094                            root_name: worktree.root_name.clone(),
1095                            visible: worktree.visible,
1096                            abs_path: worktree.abs_path.clone(),
1097                        })
1098                        .collect(),
1099                    collaborators: rejoined_project
1100                        .collaborators
1101                        .iter()
1102                        .map(|collaborator| collaborator.to_proto())
1103                        .collect(),
1104                    language_servers: rejoined_project.language_servers.clone(),
1105                })
1106                .collect(),
1107        })?;
1108        room_updated(&rejoined_room.room, &session.peer);
1109
1110        for project in &rejoined_room.reshared_projects {
1111            for collaborator in &project.collaborators {
1112                session
1113                    .peer
1114                    .send(
1115                        collaborator.connection_id,
1116                        proto::UpdateProjectCollaborator {
1117                            project_id: project.id.to_proto(),
1118                            old_peer_id: Some(project.old_connection_id.into()),
1119                            new_peer_id: Some(session.connection_id.into()),
1120                        },
1121                    )
1122                    .trace_err();
1123            }
1124
1125            broadcast(
1126                Some(session.connection_id),
1127                project
1128                    .collaborators
1129                    .iter()
1130                    .map(|collaborator| collaborator.connection_id),
1131                |connection_id| {
1132                    session.peer.forward_send(
1133                        session.connection_id,
1134                        connection_id,
1135                        proto::UpdateProject {
1136                            project_id: project.id.to_proto(),
1137                            worktrees: project.worktrees.clone(),
1138                        },
1139                    )
1140                },
1141            );
1142        }
1143
1144        for project in &rejoined_room.rejoined_projects {
1145            for collaborator in &project.collaborators {
1146                session
1147                    .peer
1148                    .send(
1149                        collaborator.connection_id,
1150                        proto::UpdateProjectCollaborator {
1151                            project_id: project.id.to_proto(),
1152                            old_peer_id: Some(project.old_connection_id.into()),
1153                            new_peer_id: Some(session.connection_id.into()),
1154                        },
1155                    )
1156                    .trace_err();
1157            }
1158        }
1159
1160        for project in &mut rejoined_room.rejoined_projects {
1161            for worktree in mem::take(&mut project.worktrees) {
1162                #[cfg(any(test, feature = "test-support"))]
1163                const MAX_CHUNK_SIZE: usize = 2;
1164                #[cfg(not(any(test, feature = "test-support")))]
1165                const MAX_CHUNK_SIZE: usize = 256;
1166
1167                // Stream this worktree's entries.
1168                let message = proto::UpdateWorktree {
1169                    project_id: project.id.to_proto(),
1170                    worktree_id: worktree.id,
1171                    abs_path: worktree.abs_path.clone(),
1172                    root_name: worktree.root_name,
1173                    updated_entries: worktree.updated_entries,
1174                    removed_entries: worktree.removed_entries,
1175                    scan_id: worktree.scan_id,
1176                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1177                    updated_repositories: worktree.updated_repositories,
1178                    removed_repositories: worktree.removed_repositories,
1179                };
1180                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1181                    session.peer.send(session.connection_id, update.clone())?;
1182                }
1183
1184                // Stream this worktree's diagnostics.
1185                for summary in worktree.diagnostic_summaries {
1186                    session.peer.send(
1187                        session.connection_id,
1188                        proto::UpdateDiagnosticSummary {
1189                            project_id: project.id.to_proto(),
1190                            worktree_id: worktree.id,
1191                            summary: Some(summary),
1192                        },
1193                    )?;
1194                }
1195
1196                for settings_file in worktree.settings_files {
1197                    session.peer.send(
1198                        session.connection_id,
1199                        proto::UpdateWorktreeSettings {
1200                            project_id: project.id.to_proto(),
1201                            worktree_id: worktree.id,
1202                            path: settings_file.path,
1203                            content: Some(settings_file.content),
1204                        },
1205                    )?;
1206                }
1207            }
1208
1209            for language_server in &project.language_servers {
1210                session.peer.send(
1211                    session.connection_id,
1212                    proto::UpdateLanguageServer {
1213                        project_id: project.id.to_proto(),
1214                        language_server_id: language_server.id,
1215                        variant: Some(
1216                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1217                                proto::LspDiskBasedDiagnosticsUpdated {},
1218                            ),
1219                        ),
1220                    },
1221                )?;
1222            }
1223        }
1224
1225        let rejoined_room = rejoined_room.into_inner();
1226
1227        room = rejoined_room.room;
1228        channel_id = rejoined_room.channel_id;
1229        channel_members = rejoined_room.channel_members;
1230    }
1231
1232    if let Some(channel_id) = channel_id {
1233        channel_updated(
1234            channel_id,
1235            &room,
1236            &channel_members,
1237            &session.peer,
1238            &*session.connection_pool().await,
1239        );
1240    }
1241
1242    update_user_contacts(session.user_id, &session).await?;
1243    Ok(())
1244}
1245
1246async fn leave_room(
1247    _: proto::LeaveRoom,
1248    response: Response<proto::LeaveRoom>,
1249    session: Session,
1250) -> Result<()> {
1251    leave_room_for_session(&session).await?;
1252    response.send(proto::Ack {})?;
1253    Ok(())
1254}
1255
1256async fn call(
1257    request: proto::Call,
1258    response: Response<proto::Call>,
1259    session: Session,
1260) -> Result<()> {
1261    let room_id = RoomId::from_proto(request.room_id);
1262    let calling_user_id = session.user_id;
1263    let calling_connection_id = session.connection_id;
1264    let called_user_id = UserId::from_proto(request.called_user_id);
1265    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1266    if !session
1267        .db()
1268        .await
1269        .has_contact(calling_user_id, called_user_id)
1270        .await?
1271    {
1272        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1273    }
1274
1275    let incoming_call = {
1276        let (room, incoming_call) = &mut *session
1277            .db()
1278            .await
1279            .call(
1280                room_id,
1281                calling_user_id,
1282                calling_connection_id,
1283                called_user_id,
1284                initial_project_id,
1285            )
1286            .await?;
1287        room_updated(&room, &session.peer);
1288        mem::take(incoming_call)
1289    };
1290    update_user_contacts(called_user_id, &session).await?;
1291
1292    let mut calls = session
1293        .connection_pool()
1294        .await
1295        .user_connection_ids(called_user_id)
1296        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1297        .collect::<FuturesUnordered<_>>();
1298
1299    while let Some(call_response) = calls.next().await {
1300        match call_response.as_ref() {
1301            Ok(_) => {
1302                response.send(proto::Ack {})?;
1303                return Ok(());
1304            }
1305            Err(_) => {
1306                call_response.trace_err();
1307            }
1308        }
1309    }
1310
1311    {
1312        let room = session
1313            .db()
1314            .await
1315            .call_failed(room_id, called_user_id)
1316            .await?;
1317        room_updated(&room, &session.peer);
1318    }
1319    update_user_contacts(called_user_id, &session).await?;
1320
1321    Err(anyhow!("failed to ring user"))?
1322}
1323
1324async fn cancel_call(
1325    request: proto::CancelCall,
1326    response: Response<proto::CancelCall>,
1327    session: Session,
1328) -> Result<()> {
1329    let called_user_id = UserId::from_proto(request.called_user_id);
1330    let room_id = RoomId::from_proto(request.room_id);
1331    {
1332        let room = session
1333            .db()
1334            .await
1335            .cancel_call(room_id, session.connection_id, called_user_id)
1336            .await?;
1337        room_updated(&room, &session.peer);
1338    }
1339
1340    for connection_id in session
1341        .connection_pool()
1342        .await
1343        .user_connection_ids(called_user_id)
1344    {
1345        session
1346            .peer
1347            .send(
1348                connection_id,
1349                proto::CallCanceled {
1350                    room_id: room_id.to_proto(),
1351                },
1352            )
1353            .trace_err();
1354    }
1355    response.send(proto::Ack {})?;
1356
1357    update_user_contacts(called_user_id, &session).await?;
1358    Ok(())
1359}
1360
1361async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1362    let room_id = RoomId::from_proto(message.room_id);
1363    {
1364        let room = session
1365            .db()
1366            .await
1367            .decline_call(Some(room_id), session.user_id)
1368            .await?
1369            .ok_or_else(|| anyhow!("failed to decline call"))?;
1370        room_updated(&room, &session.peer);
1371    }
1372
1373    for connection_id in session
1374        .connection_pool()
1375        .await
1376        .user_connection_ids(session.user_id)
1377    {
1378        session
1379            .peer
1380            .send(
1381                connection_id,
1382                proto::CallCanceled {
1383                    room_id: room_id.to_proto(),
1384                },
1385            )
1386            .trace_err();
1387    }
1388    update_user_contacts(session.user_id, &session).await?;
1389    Ok(())
1390}
1391
1392async fn update_participant_location(
1393    request: proto::UpdateParticipantLocation,
1394    response: Response<proto::UpdateParticipantLocation>,
1395    session: Session,
1396) -> Result<()> {
1397    let room_id = RoomId::from_proto(request.room_id);
1398    let location = request
1399        .location
1400        .ok_or_else(|| anyhow!("invalid location"))?;
1401
1402    let db = session.db().await;
1403    let room = db
1404        .update_room_participant_location(room_id, session.connection_id, location)
1405        .await?;
1406
1407    room_updated(&room, &session.peer);
1408    response.send(proto::Ack {})?;
1409    Ok(())
1410}
1411
1412async fn share_project(
1413    request: proto::ShareProject,
1414    response: Response<proto::ShareProject>,
1415    session: Session,
1416) -> Result<()> {
1417    let (project_id, room) = &*session
1418        .db()
1419        .await
1420        .share_project(
1421            RoomId::from_proto(request.room_id),
1422            session.connection_id,
1423            &request.worktrees,
1424        )
1425        .await?;
1426    response.send(proto::ShareProjectResponse {
1427        project_id: project_id.to_proto(),
1428    })?;
1429    room_updated(&room, &session.peer);
1430
1431    Ok(())
1432}
1433
1434async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1435    let project_id = ProjectId::from_proto(message.project_id);
1436
1437    let (room, guest_connection_ids) = &*session
1438        .db()
1439        .await
1440        .unshare_project(project_id, session.connection_id)
1441        .await?;
1442
1443    broadcast(
1444        Some(session.connection_id),
1445        guest_connection_ids.iter().copied(),
1446        |conn_id| session.peer.send(conn_id, message.clone()),
1447    );
1448    room_updated(&room, &session.peer);
1449
1450    Ok(())
1451}
1452
1453async fn join_project(
1454    request: proto::JoinProject,
1455    response: Response<proto::JoinProject>,
1456    session: Session,
1457) -> Result<()> {
1458    let project_id = ProjectId::from_proto(request.project_id);
1459    let guest_user_id = session.user_id;
1460
1461    tracing::info!(%project_id, "join project");
1462
1463    let (project, replica_id) = &mut *session
1464        .db()
1465        .await
1466        .join_project(project_id, session.connection_id)
1467        .await?;
1468
1469    let collaborators = project
1470        .collaborators
1471        .iter()
1472        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1473        .map(|collaborator| collaborator.to_proto())
1474        .collect::<Vec<_>>();
1475
1476    let worktrees = project
1477        .worktrees
1478        .iter()
1479        .map(|(id, worktree)| proto::WorktreeMetadata {
1480            id: *id,
1481            root_name: worktree.root_name.clone(),
1482            visible: worktree.visible,
1483            abs_path: worktree.abs_path.clone(),
1484        })
1485        .collect::<Vec<_>>();
1486
1487    for collaborator in &collaborators {
1488        session
1489            .peer
1490            .send(
1491                collaborator.peer_id.unwrap().into(),
1492                proto::AddProjectCollaborator {
1493                    project_id: project_id.to_proto(),
1494                    collaborator: Some(proto::Collaborator {
1495                        peer_id: Some(session.connection_id.into()),
1496                        replica_id: replica_id.0 as u32,
1497                        user_id: guest_user_id.to_proto(),
1498                    }),
1499                },
1500            )
1501            .trace_err();
1502    }
1503
1504    // First, we send the metadata associated with each worktree.
1505    response.send(proto::JoinProjectResponse {
1506        worktrees: worktrees.clone(),
1507        replica_id: replica_id.0 as u32,
1508        collaborators: collaborators.clone(),
1509        language_servers: project.language_servers.clone(),
1510    })?;
1511
1512    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1513        #[cfg(any(test, feature = "test-support"))]
1514        const MAX_CHUNK_SIZE: usize = 2;
1515        #[cfg(not(any(test, feature = "test-support")))]
1516        const MAX_CHUNK_SIZE: usize = 256;
1517
1518        // Stream this worktree's entries.
1519        let message = proto::UpdateWorktree {
1520            project_id: project_id.to_proto(),
1521            worktree_id,
1522            abs_path: worktree.abs_path.clone(),
1523            root_name: worktree.root_name,
1524            updated_entries: worktree.entries,
1525            removed_entries: Default::default(),
1526            scan_id: worktree.scan_id,
1527            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1528            updated_repositories: worktree.repository_entries.into_values().collect(),
1529            removed_repositories: Default::default(),
1530        };
1531        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1532            session.peer.send(session.connection_id, update.clone())?;
1533        }
1534
1535        // Stream this worktree's diagnostics.
1536        for summary in worktree.diagnostic_summaries {
1537            session.peer.send(
1538                session.connection_id,
1539                proto::UpdateDiagnosticSummary {
1540                    project_id: project_id.to_proto(),
1541                    worktree_id: worktree.id,
1542                    summary: Some(summary),
1543                },
1544            )?;
1545        }
1546
1547        for settings_file in worktree.settings_files {
1548            session.peer.send(
1549                session.connection_id,
1550                proto::UpdateWorktreeSettings {
1551                    project_id: project_id.to_proto(),
1552                    worktree_id: worktree.id,
1553                    path: settings_file.path,
1554                    content: Some(settings_file.content),
1555                },
1556            )?;
1557        }
1558    }
1559
1560    for language_server in &project.language_servers {
1561        session.peer.send(
1562            session.connection_id,
1563            proto::UpdateLanguageServer {
1564                project_id: project_id.to_proto(),
1565                language_server_id: language_server.id,
1566                variant: Some(
1567                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1568                        proto::LspDiskBasedDiagnosticsUpdated {},
1569                    ),
1570                ),
1571            },
1572        )?;
1573    }
1574
1575    Ok(())
1576}
1577
1578async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1579    let sender_id = session.connection_id;
1580    let project_id = ProjectId::from_proto(request.project_id);
1581
1582    let (room, project) = &*session
1583        .db()
1584        .await
1585        .leave_project(project_id, sender_id)
1586        .await?;
1587    tracing::info!(
1588        %project_id,
1589        host_user_id = %project.host_user_id,
1590        host_connection_id = %project.host_connection_id,
1591        "leave project"
1592    );
1593
1594    project_left(&project, &session);
1595    room_updated(&room, &session.peer);
1596
1597    Ok(())
1598}
1599
1600async fn update_project(
1601    request: proto::UpdateProject,
1602    response: Response<proto::UpdateProject>,
1603    session: Session,
1604) -> Result<()> {
1605    let project_id = ProjectId::from_proto(request.project_id);
1606    let (room, guest_connection_ids) = &*session
1607        .db()
1608        .await
1609        .update_project(project_id, session.connection_id, &request.worktrees)
1610        .await?;
1611    broadcast(
1612        Some(session.connection_id),
1613        guest_connection_ids.iter().copied(),
1614        |connection_id| {
1615            session
1616                .peer
1617                .forward_send(session.connection_id, connection_id, request.clone())
1618        },
1619    );
1620    room_updated(&room, &session.peer);
1621    response.send(proto::Ack {})?;
1622
1623    Ok(())
1624}
1625
1626async fn update_worktree(
1627    request: proto::UpdateWorktree,
1628    response: Response<proto::UpdateWorktree>,
1629    session: Session,
1630) -> Result<()> {
1631    let guest_connection_ids = session
1632        .db()
1633        .await
1634        .update_worktree(&request, session.connection_id)
1635        .await?;
1636
1637    broadcast(
1638        Some(session.connection_id),
1639        guest_connection_ids.iter().copied(),
1640        |connection_id| {
1641            session
1642                .peer
1643                .forward_send(session.connection_id, connection_id, request.clone())
1644        },
1645    );
1646    response.send(proto::Ack {})?;
1647    Ok(())
1648}
1649
1650async fn update_diagnostic_summary(
1651    message: proto::UpdateDiagnosticSummary,
1652    session: Session,
1653) -> Result<()> {
1654    let guest_connection_ids = session
1655        .db()
1656        .await
1657        .update_diagnostic_summary(&message, session.connection_id)
1658        .await?;
1659
1660    broadcast(
1661        Some(session.connection_id),
1662        guest_connection_ids.iter().copied(),
1663        |connection_id| {
1664            session
1665                .peer
1666                .forward_send(session.connection_id, connection_id, message.clone())
1667        },
1668    );
1669
1670    Ok(())
1671}
1672
1673async fn update_worktree_settings(
1674    message: proto::UpdateWorktreeSettings,
1675    session: Session,
1676) -> Result<()> {
1677    let guest_connection_ids = session
1678        .db()
1679        .await
1680        .update_worktree_settings(&message, session.connection_id)
1681        .await?;
1682
1683    broadcast(
1684        Some(session.connection_id),
1685        guest_connection_ids.iter().copied(),
1686        |connection_id| {
1687            session
1688                .peer
1689                .forward_send(session.connection_id, connection_id, message.clone())
1690        },
1691    );
1692
1693    Ok(())
1694}
1695
1696async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1697    broadcast_project_message(request.project_id, request, session).await
1698}
1699
1700async fn start_language_server(
1701    request: proto::StartLanguageServer,
1702    session: Session,
1703) -> Result<()> {
1704    let guest_connection_ids = session
1705        .db()
1706        .await
1707        .start_language_server(&request, session.connection_id)
1708        .await?;
1709
1710    broadcast(
1711        Some(session.connection_id),
1712        guest_connection_ids.iter().copied(),
1713        |connection_id| {
1714            session
1715                .peer
1716                .forward_send(session.connection_id, connection_id, request.clone())
1717        },
1718    );
1719    Ok(())
1720}
1721
1722async fn update_language_server(
1723    request: proto::UpdateLanguageServer,
1724    session: Session,
1725) -> Result<()> {
1726    session.executor.record_backtrace();
1727    let project_id = ProjectId::from_proto(request.project_id);
1728    let project_connection_ids = session
1729        .db()
1730        .await
1731        .project_connection_ids(project_id, session.connection_id)
1732        .await?;
1733    broadcast(
1734        Some(session.connection_id),
1735        project_connection_ids.iter().copied(),
1736        |connection_id| {
1737            session
1738                .peer
1739                .forward_send(session.connection_id, connection_id, request.clone())
1740        },
1741    );
1742    Ok(())
1743}
1744
1745async fn forward_project_request<T>(
1746    request: T,
1747    response: Response<T>,
1748    session: Session,
1749) -> Result<()>
1750where
1751    T: EntityMessage + RequestMessage,
1752{
1753    session.executor.record_backtrace();
1754    let project_id = ProjectId::from_proto(request.remote_entity_id());
1755    let host_connection_id = {
1756        let collaborators = session
1757            .db()
1758            .await
1759            .project_collaborators(project_id, session.connection_id)
1760            .await?;
1761        collaborators
1762            .iter()
1763            .find(|collaborator| collaborator.is_host)
1764            .ok_or_else(|| anyhow!("host not found"))?
1765            .connection_id
1766    };
1767
1768    let payload = session
1769        .peer
1770        .forward_request(session.connection_id, host_connection_id, request)
1771        .await?;
1772
1773    response.send(payload)?;
1774    Ok(())
1775}
1776
1777async fn create_buffer_for_peer(
1778    request: proto::CreateBufferForPeer,
1779    session: Session,
1780) -> Result<()> {
1781    session.executor.record_backtrace();
1782    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1783    session
1784        .peer
1785        .forward_send(session.connection_id, peer_id.into(), request)?;
1786    Ok(())
1787}
1788
1789async fn update_buffer(
1790    request: proto::UpdateBuffer,
1791    response: Response<proto::UpdateBuffer>,
1792    session: Session,
1793) -> Result<()> {
1794    session.executor.record_backtrace();
1795    let project_id = ProjectId::from_proto(request.project_id);
1796    let mut guest_connection_ids;
1797    let mut host_connection_id = None;
1798    {
1799        let collaborators = session
1800            .db()
1801            .await
1802            .project_collaborators(project_id, session.connection_id)
1803            .await?;
1804        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1805        for collaborator in collaborators.iter() {
1806            if collaborator.is_host {
1807                host_connection_id = Some(collaborator.connection_id);
1808            } else {
1809                guest_connection_ids.push(collaborator.connection_id);
1810            }
1811        }
1812    }
1813    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1814
1815    session.executor.record_backtrace();
1816    broadcast(
1817        Some(session.connection_id),
1818        guest_connection_ids,
1819        |connection_id| {
1820            session
1821                .peer
1822                .forward_send(session.connection_id, connection_id, request.clone())
1823        },
1824    );
1825    if host_connection_id != session.connection_id {
1826        session
1827            .peer
1828            .forward_request(session.connection_id, host_connection_id, request.clone())
1829            .await?;
1830    }
1831
1832    response.send(proto::Ack {})?;
1833    Ok(())
1834}
1835
1836async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1837    let project_id = ProjectId::from_proto(request.project_id);
1838    let project_connection_ids = session
1839        .db()
1840        .await
1841        .project_connection_ids(project_id, session.connection_id)
1842        .await?;
1843
1844    broadcast(
1845        Some(session.connection_id),
1846        project_connection_ids.iter().copied(),
1847        |connection_id| {
1848            session
1849                .peer
1850                .forward_send(session.connection_id, connection_id, request.clone())
1851        },
1852    );
1853    Ok(())
1854}
1855
1856async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1857    let project_id = ProjectId::from_proto(request.project_id);
1858    let project_connection_ids = session
1859        .db()
1860        .await
1861        .project_connection_ids(project_id, session.connection_id)
1862        .await?;
1863    broadcast(
1864        Some(session.connection_id),
1865        project_connection_ids.iter().copied(),
1866        |connection_id| {
1867            session
1868                .peer
1869                .forward_send(session.connection_id, connection_id, request.clone())
1870        },
1871    );
1872    Ok(())
1873}
1874
1875async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1876    broadcast_project_message(request.project_id, request, session).await
1877}
1878
1879async fn broadcast_project_message<T: EnvelopedMessage>(
1880    project_id: u64,
1881    request: T,
1882    session: Session,
1883) -> Result<()> {
1884    let project_id = ProjectId::from_proto(project_id);
1885    let project_connection_ids = session
1886        .db()
1887        .await
1888        .project_connection_ids(project_id, session.connection_id)
1889        .await?;
1890    broadcast(
1891        Some(session.connection_id),
1892        project_connection_ids.iter().copied(),
1893        |connection_id| {
1894            session
1895                .peer
1896                .forward_send(session.connection_id, connection_id, request.clone())
1897        },
1898    );
1899    Ok(())
1900}
1901
1902async fn follow(
1903    request: proto::Follow,
1904    response: Response<proto::Follow>,
1905    session: Session,
1906) -> Result<()> {
1907    let room_id = RoomId::from_proto(request.room_id);
1908    let project_id = request.project_id.map(ProjectId::from_proto);
1909    let leader_id = request
1910        .leader_id
1911        .ok_or_else(|| anyhow!("invalid leader id"))?
1912        .into();
1913    let follower_id = session.connection_id;
1914
1915    session
1916        .db()
1917        .await
1918        .check_room_participants(room_id, leader_id, session.connection_id)
1919        .await?;
1920
1921    let response_payload = session
1922        .peer
1923        .forward_request(session.connection_id, leader_id, request)
1924        .await?;
1925    response.send(response_payload)?;
1926
1927    if let Some(project_id) = project_id {
1928        let room = session
1929            .db()
1930            .await
1931            .follow(room_id, project_id, leader_id, follower_id)
1932            .await?;
1933        room_updated(&room, &session.peer);
1934    }
1935
1936    Ok(())
1937}
1938
1939async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1940    let room_id = RoomId::from_proto(request.room_id);
1941    let project_id = request.project_id.map(ProjectId::from_proto);
1942    let leader_id = request
1943        .leader_id
1944        .ok_or_else(|| anyhow!("invalid leader id"))?
1945        .into();
1946    let follower_id = session.connection_id;
1947
1948    session
1949        .db()
1950        .await
1951        .check_room_participants(room_id, leader_id, session.connection_id)
1952        .await?;
1953
1954    session
1955        .peer
1956        .forward_send(session.connection_id, leader_id, request)?;
1957
1958    if let Some(project_id) = project_id {
1959        let room = session
1960            .db()
1961            .await
1962            .unfollow(room_id, project_id, leader_id, follower_id)
1963            .await?;
1964        room_updated(&room, &session.peer);
1965    }
1966
1967    Ok(())
1968}
1969
1970async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1971    let room_id = RoomId::from_proto(request.room_id);
1972    let database = session.db.lock().await;
1973
1974    let connection_ids = if let Some(project_id) = request.project_id {
1975        let project_id = ProjectId::from_proto(project_id);
1976        database
1977            .project_connection_ids(project_id, session.connection_id)
1978            .await?
1979    } else {
1980        database
1981            .room_connection_ids(room_id, session.connection_id)
1982            .await?
1983    };
1984
1985    // For now, don't send view update messages back to that view's current leader.
1986    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1987        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1988        _ => None,
1989    });
1990
1991    for follower_peer_id in request.follower_ids.iter().copied() {
1992        let follower_connection_id = follower_peer_id.into();
1993        if Some(follower_peer_id) != connection_id_to_omit
1994            && connection_ids.contains(&follower_connection_id)
1995        {
1996            session.peer.forward_send(
1997                session.connection_id,
1998                follower_connection_id,
1999                request.clone(),
2000            )?;
2001        }
2002    }
2003    Ok(())
2004}
2005
2006async fn get_users(
2007    request: proto::GetUsers,
2008    response: Response<proto::GetUsers>,
2009    session: Session,
2010) -> Result<()> {
2011    let user_ids = request
2012        .user_ids
2013        .into_iter()
2014        .map(UserId::from_proto)
2015        .collect();
2016    let users = session
2017        .db()
2018        .await
2019        .get_users_by_ids(user_ids)
2020        .await?
2021        .into_iter()
2022        .map(|user| proto::User {
2023            id: user.id.to_proto(),
2024            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2025            github_login: user.github_login,
2026        })
2027        .collect();
2028    response.send(proto::UsersResponse { users })?;
2029    Ok(())
2030}
2031
2032async fn fuzzy_search_users(
2033    request: proto::FuzzySearchUsers,
2034    response: Response<proto::FuzzySearchUsers>,
2035    session: Session,
2036) -> Result<()> {
2037    let query = request.query;
2038    let users = match query.len() {
2039        0 => vec![],
2040        1 | 2 => session
2041            .db()
2042            .await
2043            .get_user_by_github_login(&query)
2044            .await?
2045            .into_iter()
2046            .collect(),
2047        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2048    };
2049    let users = users
2050        .into_iter()
2051        .filter(|user| user.id != session.user_id)
2052        .map(|user| proto::User {
2053            id: user.id.to_proto(),
2054            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2055            github_login: user.github_login,
2056        })
2057        .collect();
2058    response.send(proto::UsersResponse { users })?;
2059    Ok(())
2060}
2061
2062async fn request_contact(
2063    request: proto::RequestContact,
2064    response: Response<proto::RequestContact>,
2065    session: Session,
2066) -> Result<()> {
2067    let requester_id = session.user_id;
2068    let responder_id = UserId::from_proto(request.responder_id);
2069    if requester_id == responder_id {
2070        return Err(anyhow!("cannot add yourself as a contact"))?;
2071    }
2072
2073    let notifications = session
2074        .db()
2075        .await
2076        .send_contact_request(requester_id, responder_id)
2077        .await?;
2078
2079    // Update outgoing contact requests of requester
2080    let mut update = proto::UpdateContacts::default();
2081    update.outgoing_requests.push(responder_id.to_proto());
2082    for connection_id in session
2083        .connection_pool()
2084        .await
2085        .user_connection_ids(requester_id)
2086    {
2087        session.peer.send(connection_id, update.clone())?;
2088    }
2089
2090    // Update incoming contact requests of responder
2091    let mut update = proto::UpdateContacts::default();
2092    update
2093        .incoming_requests
2094        .push(proto::IncomingContactRequest {
2095            requester_id: requester_id.to_proto(),
2096        });
2097    let connection_pool = session.connection_pool().await;
2098    for connection_id in connection_pool.user_connection_ids(responder_id) {
2099        session.peer.send(connection_id, update.clone())?;
2100    }
2101
2102    send_notifications(&*connection_pool, &session.peer, notifications);
2103
2104    response.send(proto::Ack {})?;
2105    Ok(())
2106}
2107
2108async fn respond_to_contact_request(
2109    request: proto::RespondToContactRequest,
2110    response: Response<proto::RespondToContactRequest>,
2111    session: Session,
2112) -> Result<()> {
2113    let responder_id = session.user_id;
2114    let requester_id = UserId::from_proto(request.requester_id);
2115    let db = session.db().await;
2116    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2117        db.dismiss_contact_notification(responder_id, requester_id)
2118            .await?;
2119    } else {
2120        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2121
2122        let notifications = db
2123            .respond_to_contact_request(responder_id, requester_id, accept)
2124            .await?;
2125        let requester_busy = db.is_user_busy(requester_id).await?;
2126        let responder_busy = db.is_user_busy(responder_id).await?;
2127
2128        let pool = session.connection_pool().await;
2129        // Update responder with new contact
2130        let mut update = proto::UpdateContacts::default();
2131        if accept {
2132            update
2133                .contacts
2134                .push(contact_for_user(requester_id, requester_busy, &pool));
2135        }
2136        update
2137            .remove_incoming_requests
2138            .push(requester_id.to_proto());
2139        for connection_id in pool.user_connection_ids(responder_id) {
2140            session.peer.send(connection_id, update.clone())?;
2141        }
2142
2143        // Update requester with new contact
2144        let mut update = proto::UpdateContacts::default();
2145        if accept {
2146            update
2147                .contacts
2148                .push(contact_for_user(responder_id, responder_busy, &pool));
2149        }
2150        update
2151            .remove_outgoing_requests
2152            .push(responder_id.to_proto());
2153
2154        for connection_id in pool.user_connection_ids(requester_id) {
2155            session.peer.send(connection_id, update.clone())?;
2156        }
2157
2158        send_notifications(&*pool, &session.peer, notifications);
2159    }
2160
2161    response.send(proto::Ack {})?;
2162    Ok(())
2163}
2164
2165async fn remove_contact(
2166    request: proto::RemoveContact,
2167    response: Response<proto::RemoveContact>,
2168    session: Session,
2169) -> Result<()> {
2170    let requester_id = session.user_id;
2171    let responder_id = UserId::from_proto(request.user_id);
2172    let db = session.db().await;
2173    let (contact_accepted, deleted_notification_id) =
2174        db.remove_contact(requester_id, responder_id).await?;
2175
2176    let pool = session.connection_pool().await;
2177    // Update outgoing contact requests of requester
2178    let mut update = proto::UpdateContacts::default();
2179    if contact_accepted {
2180        update.remove_contacts.push(responder_id.to_proto());
2181    } else {
2182        update
2183            .remove_outgoing_requests
2184            .push(responder_id.to_proto());
2185    }
2186    for connection_id in pool.user_connection_ids(requester_id) {
2187        session.peer.send(connection_id, update.clone())?;
2188    }
2189
2190    // Update incoming contact requests of responder
2191    let mut update = proto::UpdateContacts::default();
2192    if contact_accepted {
2193        update.remove_contacts.push(requester_id.to_proto());
2194    } else {
2195        update
2196            .remove_incoming_requests
2197            .push(requester_id.to_proto());
2198    }
2199    for connection_id in pool.user_connection_ids(responder_id) {
2200        session.peer.send(connection_id, update.clone())?;
2201        if let Some(notification_id) = deleted_notification_id {
2202            session.peer.send(
2203                connection_id,
2204                proto::DeleteNotification {
2205                    notification_id: notification_id.to_proto(),
2206                },
2207            )?;
2208        }
2209    }
2210
2211    response.send(proto::Ack {})?;
2212    Ok(())
2213}
2214
2215async fn create_channel(
2216    request: proto::CreateChannel,
2217    response: Response<proto::CreateChannel>,
2218    session: Session,
2219) -> Result<()> {
2220    let db = session.db().await;
2221
2222    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2223    let CreateChannelResult {
2224        channel,
2225        participants_to_update,
2226    } = db
2227        .create_channel(&request.name, parent_id, session.user_id)
2228        .await?;
2229
2230    response.send(proto::CreateChannelResponse {
2231        channel: Some(channel.to_proto()),
2232        parent_id: request.parent_id,
2233    })?;
2234
2235    let connection_pool = session.connection_pool().await;
2236    for (user_id, channels) in participants_to_update {
2237        let update = build_channels_update(channels, vec![]);
2238        for connection_id in connection_pool.user_connection_ids(user_id) {
2239            if user_id == session.user_id {
2240                continue;
2241            }
2242            session.peer.send(connection_id, update.clone())?;
2243        }
2244    }
2245
2246    Ok(())
2247}
2248
2249async fn delete_channel(
2250    request: proto::DeleteChannel,
2251    response: Response<proto::DeleteChannel>,
2252    session: Session,
2253) -> Result<()> {
2254    let db = session.db().await;
2255
2256    let channel_id = request.channel_id;
2257    let (removed_channels, member_ids) = db
2258        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2259        .await?;
2260    response.send(proto::Ack {})?;
2261
2262    // Notify members of removed channels
2263    let mut update = proto::UpdateChannels::default();
2264    update
2265        .delete_channels
2266        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2267
2268    let connection_pool = session.connection_pool().await;
2269    for member_id in member_ids {
2270        for connection_id in connection_pool.user_connection_ids(member_id) {
2271            session.peer.send(connection_id, update.clone())?;
2272        }
2273    }
2274
2275    Ok(())
2276}
2277
2278async fn invite_channel_member(
2279    request: proto::InviteChannelMember,
2280    response: Response<proto::InviteChannelMember>,
2281    session: Session,
2282) -> Result<()> {
2283    let db = session.db().await;
2284    let channel_id = ChannelId::from_proto(request.channel_id);
2285    let invitee_id = UserId::from_proto(request.user_id);
2286    let InviteMemberResult {
2287        channel,
2288        notifications,
2289    } = db
2290        .invite_channel_member(
2291            channel_id,
2292            invitee_id,
2293            session.user_id,
2294            request.role().into(),
2295        )
2296        .await?;
2297
2298    let update = proto::UpdateChannels {
2299        channel_invitations: vec![channel.to_proto()],
2300        ..Default::default()
2301    };
2302
2303    let connection_pool = session.connection_pool().await;
2304    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2305        session.peer.send(connection_id, update.clone())?;
2306    }
2307
2308    send_notifications(&*connection_pool, &session.peer, notifications);
2309
2310    response.send(proto::Ack {})?;
2311    Ok(())
2312}
2313
2314async fn remove_channel_member(
2315    request: proto::RemoveChannelMember,
2316    response: Response<proto::RemoveChannelMember>,
2317    session: Session,
2318) -> Result<()> {
2319    let db = session.db().await;
2320    let channel_id = ChannelId::from_proto(request.channel_id);
2321    let member_id = UserId::from_proto(request.user_id);
2322
2323    let RemoveChannelMemberResult {
2324        membership_update,
2325        notification_id,
2326    } = db
2327        .remove_channel_member(channel_id, member_id, session.user_id)
2328        .await?;
2329
2330    let connection_pool = &session.connection_pool().await;
2331    notify_membership_updated(
2332        &connection_pool,
2333        membership_update,
2334        member_id,
2335        &session.peer,
2336    );
2337    for connection_id in connection_pool.user_connection_ids(member_id) {
2338        if let Some(notification_id) = notification_id {
2339            session
2340                .peer
2341                .send(
2342                    connection_id,
2343                    proto::DeleteNotification {
2344                        notification_id: notification_id.to_proto(),
2345                    },
2346                )
2347                .trace_err();
2348        }
2349    }
2350
2351    response.send(proto::Ack {})?;
2352    Ok(())
2353}
2354
2355async fn set_channel_visibility(
2356    request: proto::SetChannelVisibility,
2357    response: Response<proto::SetChannelVisibility>,
2358    session: Session,
2359) -> Result<()> {
2360    let db = session.db().await;
2361    let channel_id = ChannelId::from_proto(request.channel_id);
2362    let visibility = request.visibility().into();
2363
2364    let SetChannelVisibilityResult {
2365        participants_to_update,
2366        participants_to_remove,
2367        channels_to_remove,
2368    } = db
2369        .set_channel_visibility(channel_id, visibility, session.user_id)
2370        .await?;
2371
2372    let connection_pool = session.connection_pool().await;
2373    for (user_id, channels) in participants_to_update {
2374        let update = build_channels_update(channels, vec![]);
2375        for connection_id in connection_pool.user_connection_ids(user_id) {
2376            session.peer.send(connection_id, update.clone())?;
2377        }
2378    }
2379    for user_id in participants_to_remove {
2380        let update = proto::UpdateChannels {
2381            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2382            ..Default::default()
2383        };
2384        for connection_id in connection_pool.user_connection_ids(user_id) {
2385            session.peer.send(connection_id, update.clone())?;
2386        }
2387    }
2388
2389    response.send(proto::Ack {})?;
2390    Ok(())
2391}
2392
2393async fn set_channel_member_role(
2394    request: proto::SetChannelMemberRole,
2395    response: Response<proto::SetChannelMemberRole>,
2396    session: Session,
2397) -> Result<()> {
2398    let db = session.db().await;
2399    let channel_id = ChannelId::from_proto(request.channel_id);
2400    let member_id = UserId::from_proto(request.user_id);
2401    let result = db
2402        .set_channel_member_role(
2403            channel_id,
2404            session.user_id,
2405            member_id,
2406            request.role().into(),
2407        )
2408        .await?;
2409
2410    match result {
2411        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2412            let connection_pool = session.connection_pool().await;
2413            notify_membership_updated(
2414                &connection_pool,
2415                membership_update,
2416                member_id,
2417                &session.peer,
2418            )
2419        }
2420        db::SetMemberRoleResult::InviteUpdated(channel) => {
2421            let update = proto::UpdateChannels {
2422                channel_invitations: vec![channel.to_proto()],
2423                ..Default::default()
2424            };
2425
2426            for connection_id in session
2427                .connection_pool()
2428                .await
2429                .user_connection_ids(member_id)
2430            {
2431                session.peer.send(connection_id, update.clone())?;
2432            }
2433        }
2434    }
2435
2436    response.send(proto::Ack {})?;
2437    Ok(())
2438}
2439
2440async fn rename_channel(
2441    request: proto::RenameChannel,
2442    response: Response<proto::RenameChannel>,
2443    session: Session,
2444) -> Result<()> {
2445    let db = session.db().await;
2446    let channel_id = ChannelId::from_proto(request.channel_id);
2447    let RenameChannelResult {
2448        channel,
2449        participants_to_update,
2450    } = db
2451        .rename_channel(channel_id, session.user_id, &request.name)
2452        .await?;
2453
2454    response.send(proto::RenameChannelResponse {
2455        channel: Some(channel.to_proto()),
2456    })?;
2457
2458    let connection_pool = session.connection_pool().await;
2459    for (user_id, channel) in participants_to_update {
2460        for connection_id in connection_pool.user_connection_ids(user_id) {
2461            let update = proto::UpdateChannels {
2462                channels: vec![channel.to_proto()],
2463                ..Default::default()
2464            };
2465
2466            session.peer.send(connection_id, update.clone())?;
2467        }
2468    }
2469
2470    Ok(())
2471}
2472
2473async fn move_channel(
2474    request: proto::MoveChannel,
2475    response: Response<proto::MoveChannel>,
2476    session: Session,
2477) -> Result<()> {
2478    let channel_id = ChannelId::from_proto(request.channel_id);
2479    let to = request.to.map(ChannelId::from_proto);
2480
2481    let result = session
2482        .db()
2483        .await
2484        .move_channel(channel_id, to, session.user_id)
2485        .await?;
2486
2487    notify_channel_moved(result, session).await?;
2488
2489    response.send(Ack {})?;
2490    Ok(())
2491}
2492
2493async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2494    let Some(MoveChannelResult {
2495        participants_to_remove,
2496        participants_to_update,
2497        moved_channels,
2498    }) = result
2499    else {
2500        return Ok(());
2501    };
2502    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2503
2504    let connection_pool = session.connection_pool().await;
2505    for (user_id, channels) in participants_to_update {
2506        let mut update = build_channels_update(channels, vec![]);
2507        update.delete_channels = moved_channels.clone();
2508        for connection_id in connection_pool.user_connection_ids(user_id) {
2509            session.peer.send(connection_id, update.clone())?;
2510        }
2511    }
2512
2513    for user_id in participants_to_remove {
2514        let update = proto::UpdateChannels {
2515            delete_channels: moved_channels.clone(),
2516            ..Default::default()
2517        };
2518        for connection_id in connection_pool.user_connection_ids(user_id) {
2519            session.peer.send(connection_id, update.clone())?;
2520        }
2521    }
2522    Ok(())
2523}
2524
2525async fn get_channel_members(
2526    request: proto::GetChannelMembers,
2527    response: Response<proto::GetChannelMembers>,
2528    session: Session,
2529) -> Result<()> {
2530    let db = session.db().await;
2531    let channel_id = ChannelId::from_proto(request.channel_id);
2532    let members = db
2533        .get_channel_participant_details(channel_id, session.user_id)
2534        .await?;
2535    response.send(proto::GetChannelMembersResponse { members })?;
2536    Ok(())
2537}
2538
2539async fn respond_to_channel_invite(
2540    request: proto::RespondToChannelInvite,
2541    response: Response<proto::RespondToChannelInvite>,
2542    session: Session,
2543) -> Result<()> {
2544    let db = session.db().await;
2545    let channel_id = ChannelId::from_proto(request.channel_id);
2546    let RespondToChannelInvite {
2547        membership_update,
2548        notifications,
2549    } = db
2550        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2551        .await?;
2552
2553    let connection_pool = session.connection_pool().await;
2554    if let Some(membership_update) = membership_update {
2555        notify_membership_updated(
2556            &connection_pool,
2557            membership_update,
2558            session.user_id,
2559            &session.peer,
2560        );
2561    } else {
2562        let update = proto::UpdateChannels {
2563            remove_channel_invitations: vec![channel_id.to_proto()],
2564            ..Default::default()
2565        };
2566
2567        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2568            session.peer.send(connection_id, update.clone())?;
2569        }
2570    };
2571
2572    send_notifications(&*connection_pool, &session.peer, notifications);
2573
2574    response.send(proto::Ack {})?;
2575
2576    Ok(())
2577}
2578
2579async fn join_channel(
2580    request: proto::JoinChannel,
2581    response: Response<proto::JoinChannel>,
2582    session: Session,
2583) -> Result<()> {
2584    let channel_id = ChannelId::from_proto(request.channel_id);
2585    join_channel_internal(channel_id, Box::new(response), session).await
2586}
2587
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// It lets `join_channel_internal` reply to either a `JoinChannel` or a
/// `JoinRoom` request, since both are implemented below.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`, written explicitly
        // so it is never confused with this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send` (see above impl).
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2601
/// Moves the session into the room backing `channel_id`.
///
/// Steps, in order:
/// 1. leave whatever room the session is currently in;
/// 2. join the channel's room in the database;
/// 3. mint LiveKit credentials when a LiveKit client is configured;
/// 4. reply to the caller with the joined room;
/// 5. notify the user of any membership change, and the room's participants
///    of the new room state;
/// 6. broadcast the channel's participant list to channel members and
///    refresh the user's contacts.
///
/// `response` is boxed so the same code path can answer both `JoinChannel`
/// and `JoinRoom` requests (see `JoinChannelInternalResponse`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the `db` and `connection_pool` guards. They are
    // released before `channel_updated` and `update_user_contacts` below
    // re-acquire them.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // Guests get a guest token and cannot publish; every other role gets
        // a full room token. If token creation fails, the error is logged by
        // `trace_err` and the reply carries no LiveKit info (the `?` returns
        // `None` from this closure) instead of failing the whole join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2682
2683async fn join_channel_buffer(
2684    request: proto::JoinChannelBuffer,
2685    response: Response<proto::JoinChannelBuffer>,
2686    session: Session,
2687) -> Result<()> {
2688    let db = session.db().await;
2689    let channel_id = ChannelId::from_proto(request.channel_id);
2690
2691    let open_response = db
2692        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2693        .await?;
2694
2695    let collaborators = open_response.collaborators.clone();
2696    response.send(open_response)?;
2697
2698    let update = UpdateChannelBufferCollaborators {
2699        channel_id: channel_id.to_proto(),
2700        collaborators: collaborators.clone(),
2701    };
2702    channel_buffer_updated(
2703        session.connection_id,
2704        collaborators
2705            .iter()
2706            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2707        &update,
2708        &session.peer,
2709    );
2710
2711    Ok(())
2712}
2713
2714async fn update_channel_buffer(
2715    request: proto::UpdateChannelBuffer,
2716    session: Session,
2717) -> Result<()> {
2718    let db = session.db().await;
2719    let channel_id = ChannelId::from_proto(request.channel_id);
2720
2721    let (collaborators, non_collaborators, epoch, version) = db
2722        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2723        .await?;
2724
2725    channel_buffer_updated(
2726        session.connection_id,
2727        collaborators,
2728        &proto::UpdateChannelBuffer {
2729            channel_id: channel_id.to_proto(),
2730            operations: request.operations,
2731        },
2732        &session.peer,
2733    );
2734
2735    let pool = &*session.connection_pool().await;
2736
2737    broadcast(
2738        None,
2739        non_collaborators
2740            .iter()
2741            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2742        |peer_id| {
2743            session.peer.send(
2744                peer_id.into(),
2745                proto::UpdateChannels {
2746                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2747                        channel_id: channel_id.to_proto(),
2748                        epoch: epoch as u64,
2749                        version: version.clone(),
2750                    }],
2751                    ..Default::default()
2752                },
2753            )
2754        },
2755    );
2756
2757    Ok(())
2758}
2759
2760async fn rejoin_channel_buffers(
2761    request: proto::RejoinChannelBuffers,
2762    response: Response<proto::RejoinChannelBuffers>,
2763    session: Session,
2764) -> Result<()> {
2765    let db = session.db().await;
2766    let buffers = db
2767        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2768        .await?;
2769
2770    for rejoined_buffer in &buffers {
2771        let collaborators_to_notify = rejoined_buffer
2772            .buffer
2773            .collaborators
2774            .iter()
2775            .filter_map(|c| Some(c.peer_id?.into()));
2776        channel_buffer_updated(
2777            session.connection_id,
2778            collaborators_to_notify,
2779            &proto::UpdateChannelBufferCollaborators {
2780                channel_id: rejoined_buffer.buffer.channel_id,
2781                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2782            },
2783            &session.peer,
2784        );
2785    }
2786
2787    response.send(proto::RejoinChannelBuffersResponse {
2788        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2789    })?;
2790
2791    Ok(())
2792}
2793
2794async fn leave_channel_buffer(
2795    request: proto::LeaveChannelBuffer,
2796    response: Response<proto::LeaveChannelBuffer>,
2797    session: Session,
2798) -> Result<()> {
2799    let db = session.db().await;
2800    let channel_id = ChannelId::from_proto(request.channel_id);
2801
2802    let left_buffer = db
2803        .leave_channel_buffer(channel_id, session.connection_id)
2804        .await?;
2805
2806    response.send(Ack {})?;
2807
2808    channel_buffer_updated(
2809        session.connection_id,
2810        left_buffer.connections,
2811        &proto::UpdateChannelBufferCollaborators {
2812            channel_id: channel_id.to_proto(),
2813            collaborators: left_buffer.collaborators,
2814        },
2815        &session.peer,
2816    );
2817
2818    Ok(())
2819}
2820
2821fn channel_buffer_updated<T: EnvelopedMessage>(
2822    sender_id: ConnectionId,
2823    collaborators: impl IntoIterator<Item = ConnectionId>,
2824    message: &T,
2825    peer: &Peer,
2826) {
2827    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2828        peer.send(peer_id.into(), message.clone())
2829    });
2830}
2831
2832fn send_notifications(
2833    connection_pool: &ConnectionPool,
2834    peer: &Peer,
2835    notifications: db::NotificationBatch,
2836) {
2837    for (user_id, notification) in notifications {
2838        for connection_id in connection_pool.user_connection_ids(user_id) {
2839            if let Err(error) = peer.send(
2840                connection_id,
2841                proto::AddNotification {
2842                    notification: Some(notification.clone()),
2843                },
2844            ) {
2845                tracing::error!(
2846                    "failed to send notification to {:?} {}",
2847                    connection_id,
2848                    error
2849                );
2850            }
2851        }
2852    }
2853}
2854
2855async fn send_channel_message(
2856    request: proto::SendChannelMessage,
2857    response: Response<proto::SendChannelMessage>,
2858    session: Session,
2859) -> Result<()> {
2860    // Validate the message body.
2861    let body = request.body.trim().to_string();
2862    if body.len() > MAX_MESSAGE_LEN {
2863        return Err(anyhow!("message is too long"))?;
2864    }
2865    if body.is_empty() {
2866        return Err(anyhow!("message can't be blank"))?;
2867    }
2868
2869    // TODO: adjust mentions if body is trimmed
2870
2871    let timestamp = OffsetDateTime::now_utc();
2872    let nonce = request
2873        .nonce
2874        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2875
2876    let channel_id = ChannelId::from_proto(request.channel_id);
2877    let CreatedChannelMessage {
2878        message_id,
2879        participant_connection_ids,
2880        channel_members,
2881        notifications,
2882    } = session
2883        .db()
2884        .await
2885        .create_channel_message(
2886            channel_id,
2887            session.user_id,
2888            &body,
2889            &request.mentions,
2890            timestamp,
2891            nonce.clone().into(),
2892        )
2893        .await?;
2894    let message = proto::ChannelMessage {
2895        sender_id: session.user_id.to_proto(),
2896        id: message_id.to_proto(),
2897        body,
2898        mentions: request.mentions,
2899        timestamp: timestamp.unix_timestamp() as u64,
2900        nonce: Some(nonce),
2901    };
2902    broadcast(
2903        Some(session.connection_id),
2904        participant_connection_ids,
2905        |connection| {
2906            session.peer.send(
2907                connection,
2908                proto::ChannelMessageSent {
2909                    channel_id: channel_id.to_proto(),
2910                    message: Some(message.clone()),
2911                },
2912            )
2913        },
2914    );
2915    response.send(proto::SendChannelMessageResponse {
2916        message: Some(message),
2917    })?;
2918
2919    let pool = &*session.connection_pool().await;
2920    broadcast(
2921        None,
2922        channel_members
2923            .iter()
2924            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2925        |peer_id| {
2926            session.peer.send(
2927                peer_id.into(),
2928                proto::UpdateChannels {
2929                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2930                        channel_id: channel_id.to_proto(),
2931                        message_id: message_id.to_proto(),
2932                    }],
2933                    ..Default::default()
2934                },
2935            )
2936        },
2937    );
2938    send_notifications(pool, &session.peer, notifications);
2939
2940    Ok(())
2941}
2942
2943async fn remove_channel_message(
2944    request: proto::RemoveChannelMessage,
2945    response: Response<proto::RemoveChannelMessage>,
2946    session: Session,
2947) -> Result<()> {
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let message_id = MessageId::from_proto(request.message_id);
2950    let connection_ids = session
2951        .db()
2952        .await
2953        .remove_channel_message(channel_id, message_id, session.user_id)
2954        .await?;
2955    broadcast(Some(session.connection_id), connection_ids, |connection| {
2956        session.peer.send(connection, request.clone())
2957    });
2958    response.send(proto::Ack {})?;
2959    Ok(())
2960}
2961
2962async fn acknowledge_channel_message(
2963    request: proto::AckChannelMessage,
2964    session: Session,
2965) -> Result<()> {
2966    let channel_id = ChannelId::from_proto(request.channel_id);
2967    let message_id = MessageId::from_proto(request.message_id);
2968    let notifications = session
2969        .db()
2970        .await
2971        .observe_channel_message(channel_id, session.user_id, message_id)
2972        .await?;
2973    send_notifications(
2974        &*session.connection_pool().await,
2975        &session.peer,
2976        notifications,
2977    );
2978    Ok(())
2979}
2980
2981async fn acknowledge_buffer_version(
2982    request: proto::AckBufferOperation,
2983    session: Session,
2984) -> Result<()> {
2985    let buffer_id = BufferId::from_proto(request.buffer_id);
2986    session
2987        .db()
2988        .await
2989        .observe_buffer_version(
2990            buffer_id,
2991            session.user_id,
2992            request.epoch as i32,
2993            &request.version,
2994        )
2995        .await?;
2996    Ok(())
2997}
2998
2999async fn join_channel_chat(
3000    request: proto::JoinChannelChat,
3001    response: Response<proto::JoinChannelChat>,
3002    session: Session,
3003) -> Result<()> {
3004    let channel_id = ChannelId::from_proto(request.channel_id);
3005
3006    let db = session.db().await;
3007    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3008        .await?;
3009    let messages = db
3010        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3011        .await?;
3012    response.send(proto::JoinChannelChatResponse {
3013        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3014        messages,
3015    })?;
3016    Ok(())
3017}
3018
3019async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3020    let channel_id = ChannelId::from_proto(request.channel_id);
3021    session
3022        .db()
3023        .await
3024        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3025        .await?;
3026    Ok(())
3027}
3028
3029async fn get_channel_messages(
3030    request: proto::GetChannelMessages,
3031    response: Response<proto::GetChannelMessages>,
3032    session: Session,
3033) -> Result<()> {
3034    let channel_id = ChannelId::from_proto(request.channel_id);
3035    let messages = session
3036        .db()
3037        .await
3038        .get_channel_messages(
3039            channel_id,
3040            session.user_id,
3041            MESSAGE_COUNT_PER_PAGE,
3042            Some(MessageId::from_proto(request.before_message_id)),
3043        )
3044        .await?;
3045    response.send(proto::GetChannelMessagesResponse {
3046        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3047        messages,
3048    })?;
3049    Ok(())
3050}
3051
3052async fn get_channel_messages_by_id(
3053    request: proto::GetChannelMessagesById,
3054    response: Response<proto::GetChannelMessagesById>,
3055    session: Session,
3056) -> Result<()> {
3057    let message_ids = request
3058        .message_ids
3059        .iter()
3060        .map(|id| MessageId::from_proto(*id))
3061        .collect::<Vec<_>>();
3062    let messages = session
3063        .db()
3064        .await
3065        .get_channel_messages_by_id(session.user_id, &message_ids)
3066        .await?;
3067    response.send(proto::GetChannelMessagesResponse {
3068        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3069        messages,
3070    })?;
3071    Ok(())
3072}
3073
3074async fn get_notifications(
3075    request: proto::GetNotifications,
3076    response: Response<proto::GetNotifications>,
3077    session: Session,
3078) -> Result<()> {
3079    let notifications = session
3080        .db()
3081        .await
3082        .get_notifications(
3083            session.user_id,
3084            NOTIFICATION_COUNT_PER_PAGE,
3085            request
3086                .before_id
3087                .map(|id| db::NotificationId::from_proto(id)),
3088        )
3089        .await?;
3090    response.send(proto::GetNotificationsResponse {
3091        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3092        notifications,
3093    })?;
3094    Ok(())
3095}
3096
3097async fn mark_notification_as_read(
3098    request: proto::MarkNotificationRead,
3099    response: Response<proto::MarkNotificationRead>,
3100    session: Session,
3101) -> Result<()> {
3102    let database = &session.db().await;
3103    let notifications = database
3104        .mark_notification_as_read_by_id(
3105            session.user_id,
3106            NotificationId::from_proto(request.notification_id),
3107        )
3108        .await?;
3109    send_notifications(
3110        &*session.connection_pool().await,
3111        &session.peer,
3112        notifications,
3113    );
3114    response.send(proto::Ack {})?;
3115    Ok(())
3116}
3117
3118async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3119    let project_id = ProjectId::from_proto(request.project_id);
3120    let project_connection_ids = session
3121        .db()
3122        .await
3123        .project_connection_ids(project_id, session.connection_id)
3124        .await?;
3125    broadcast(
3126        Some(session.connection_id),
3127        project_connection_ids.iter().copied(),
3128        |connection_id| {
3129            session
3130                .peer
3131                .forward_send(session.connection_id, connection_id, request.clone())
3132        },
3133    );
3134    Ok(())
3135}
3136
3137async fn get_private_user_info(
3138    _request: proto::GetPrivateUserInfo,
3139    response: Response<proto::GetPrivateUserInfo>,
3140    session: Session,
3141) -> Result<()> {
3142    let db = session.db().await;
3143
3144    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3145    let user = db
3146        .get_user_by_id(session.user_id)
3147        .await?
3148        .ok_or_else(|| anyhow!("user not found"))?;
3149    let flags = db.get_user_flags(session.user_id).await?;
3150
3151    response.send(proto::GetPrivateUserInfoResponse {
3152        metrics_id,
3153        staff: user.admin,
3154        flags,
3155    })?;
3156    Ok(())
3157}
3158
3159fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3160    match message {
3161        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3162        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3163        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3164        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3165        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3166            code: frame.code.into(),
3167            reason: frame.reason,
3168        })),
3169    }
3170}
3171
3172fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3173    match message {
3174        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3175        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3176        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3177        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3178        AxumMessage::Close(frame) => {
3179            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3180                code: frame.code.into(),
3181                reason: frame.reason,
3182            }))
3183        }
3184    }
3185}
3186
3187fn notify_membership_updated(
3188    connection_pool: &ConnectionPool,
3189    result: MembershipUpdated,
3190    user_id: UserId,
3191    peer: &Peer,
3192) {
3193    let mut update = build_channels_update(result.new_channels, vec![]);
3194    update.delete_channels = result
3195        .removed_channels
3196        .into_iter()
3197        .map(|id| id.to_proto())
3198        .collect();
3199    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3200
3201    for connection_id in connection_pool.user_connection_ids(user_id) {
3202        peer.send(connection_id, update.clone()).trace_err();
3203    }
3204}
3205
3206fn build_channels_update(
3207    channels: ChannelsForUser,
3208    channel_invites: Vec<db::Channel>,
3209) -> proto::UpdateChannels {
3210    let mut update = proto::UpdateChannels::default();
3211
3212    for channel in channels.channels {
3213        update.channels.push(channel.to_proto());
3214    }
3215
3216    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3217    update.unseen_channel_messages = channels.channel_messages;
3218
3219    for (channel_id, participants) in channels.channel_participants {
3220        update
3221            .channel_participants
3222            .push(proto::ChannelParticipants {
3223                channel_id: channel_id.to_proto(),
3224                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3225            });
3226    }
3227
3228    for channel in channel_invites {
3229        update.channel_invitations.push(channel.to_proto());
3230    }
3231
3232    update
3233}
3234
3235fn build_initial_contacts_update(
3236    contacts: Vec<db::Contact>,
3237    pool: &ConnectionPool,
3238) -> proto::UpdateContacts {
3239    let mut update = proto::UpdateContacts::default();
3240
3241    for contact in contacts {
3242        match contact {
3243            db::Contact::Accepted { user_id, busy } => {
3244                update.contacts.push(contact_for_user(user_id, busy, &pool));
3245            }
3246            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3247            db::Contact::Incoming { user_id } => {
3248                update
3249                    .incoming_requests
3250                    .push(proto::IncomingContactRequest {
3251                        requester_id: user_id.to_proto(),
3252                    })
3253            }
3254        }
3255    }
3256
3257    update
3258}
3259
3260fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3261    proto::Contact {
3262        user_id: user_id.to_proto(),
3263        online: pool.is_user_online(user_id),
3264        busy,
3265    }
3266}
3267
3268fn room_updated(room: &proto::Room, peer: &Peer) {
3269    broadcast(
3270        None,
3271        room.participants
3272            .iter()
3273            .filter_map(|participant| Some(participant.peer_id?.into())),
3274        |peer_id| {
3275            peer.send(
3276                peer_id.into(),
3277                proto::RoomUpdated {
3278                    room: Some(room.clone()),
3279                },
3280            )
3281        },
3282    );
3283}
3284
3285fn channel_updated(
3286    channel_id: ChannelId,
3287    room: &proto::Room,
3288    channel_members: &[UserId],
3289    peer: &Peer,
3290    pool: &ConnectionPool,
3291) {
3292    let participants = room
3293        .participants
3294        .iter()
3295        .map(|p| p.user_id)
3296        .collect::<Vec<_>>();
3297
3298    broadcast(
3299        None,
3300        channel_members
3301            .iter()
3302            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3303        |peer_id| {
3304            peer.send(
3305                peer_id.into(),
3306                proto::UpdateChannels {
3307                    channel_participants: vec![proto::ChannelParticipants {
3308                        channel_id: channel_id.to_proto(),
3309                        participant_user_ids: participants.clone(),
3310                    }],
3311                    ..Default::default()
3312                },
3313            )
3314        },
3315    );
3316}
3317
3318async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3319    let db = session.db().await;
3320
3321    let contacts = db.get_contacts(user_id).await?;
3322    let busy = db.is_user_busy(user_id).await?;
3323
3324    let pool = session.connection_pool().await;
3325    let updated_contact = contact_for_user(user_id, busy, &pool);
3326    for contact in contacts {
3327        if let db::Contact::Accepted {
3328            user_id: contact_user_id,
3329            ..
3330        } = contact
3331        {
3332            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3333                session
3334                    .peer
3335                    .send(
3336                        contact_conn_id,
3337                        proto::UpdateContacts {
3338                            contacts: vec![updated_contact.clone()],
3339                            remove_contacts: Default::default(),
3340                            incoming_requests: Default::default(),
3341                            remove_incoming_requests: Default::default(),
3342                            outgoing_requests: Default::default(),
3343                            remove_outgoing_requests: Default::default(),
3344                        },
3345                    )
3346                    .trace_err();
3347            }
3348        }
3349    }
3350    Ok(())
3351}
3352
/// Removes the session's connection from whatever room it is in, if any, and
/// performs the resulting cleanup and notifications.
///
/// If the database reports no room to leave, this returns `Ok(())` without
/// doing anything. Otherwise, in order:
/// 1. each project the connection left is handled by `project_left`;
/// 2. the remaining room participants receive the updated room;
/// 3. if the room belongs to a channel, channel members get the new
///    participant list;
/// 4. users whose pending calls were canceled receive `CallCanceled`;
/// 5. contact status is refreshed for the leaving user and those call
///    recipients;
/// 6. the user is removed from the LiveKit room, which is deleted if the
///    database marked the room as deleted. LiveKit errors are only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the `Some` branch
    // below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the `session.db()` guard is a temporary in the `if let`
    // scrutinee. Under edition 2021 and earlier it stays alive until the end
    // of this `if let` statement, so the database lock is held while
    // `project_left` and `room_updated` run. Confirm before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the broadcast room has
        // `live_kit_room` reset to its default value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before
    // `update_user_contacts` acquires its own guards.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // An error here returns early and skips the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3429
3430async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3431    let left_channel_buffers = session
3432        .db()
3433        .await
3434        .leave_channel_buffers(session.connection_id)
3435        .await?;
3436
3437    for left_buffer in left_channel_buffers {
3438        channel_buffer_updated(
3439            session.connection_id,
3440            left_buffer.connections,
3441            &proto::UpdateChannelBufferCollaborators {
3442                channel_id: left_buffer.channel_id.to_proto(),
3443                collaborators: left_buffer.collaborators,
3444            },
3445            &session.peer,
3446        );
3447    }
3448
3449    Ok(())
3450}
3451
3452fn project_left(project: &db::LeftProject, session: &Session) {
3453    for connection_id in &project.connection_ids {
3454        if project.host_user_id == session.user_id {
3455            session
3456                .peer
3457                .send(
3458                    *connection_id,
3459                    proto::UnshareProject {
3460                        project_id: project.id.to_proto(),
3461                    },
3462                )
3463                .trace_err();
3464        } else {
3465            session
3466                .peer
3467                .send(
3468                    *connection_id,
3469                    proto::RemoveProjectCollaborator {
3470                        project_id: project.id.to_proto(),
3471                        peer_id: Some(session.connection_id.into()),
3472                    },
3473                )
3474                .trace_err();
3475        }
3476    }
3477}
3478
3479pub trait ResultExt {
3480    type Ok;
3481
3482    fn trace_err(self) -> Option<Self::Ok>;
3483}
3484
3485impl<T, E> ResultExt for Result<T, E>
3486where
3487    E: std::fmt::Debug,
3488{
3489    type Ok = T;
3490
3491    fn trace_err(self) -> Option<T> {
3492        match self {
3493            Ok(value) => Some(value),
3494            Err(error) => {
3495                tracing::error!("{:?}", error);
3496                None
3497            }
3498        }
3499    }
3500}