rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69use util::channel::RELEASE_CHANNEL_NAME;
  70
  71pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  72pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  73
  74const MESSAGE_COUNT_PER_PAGE: usize = 100;
  75const MAX_MESSAGE_LEN: usize = 1024;
  76const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  77
  78lazy_static! {
  79    static ref METRIC_CONNECTIONS: IntGauge =
  80        register_int_gauge!("connections", "number of connections").unwrap();
  81    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  82        "shared_projects",
  83        "number of open projects with one or more guests"
  84    )
  85    .unwrap();
  86}
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone)]
 106struct Session {
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    _executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
 188    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 189        let mut server = Self {
 190            id: parking_lot::Mutex::new(id),
 191            peer: Peer::new(id.0 as u32),
 192            app_state,
 193            executor,
 194            connection_pool: Default::default(),
 195            handlers: Default::default(),
 196            teardown: watch::channel(()).0,
 197        };
 198
 199        server
 200            .add_request_handler(ping)
 201            .add_request_handler(create_room)
 202            .add_request_handler(join_room)
 203            .add_request_handler(rejoin_room)
 204            .add_request_handler(leave_room)
 205            .add_request_handler(call)
 206            .add_request_handler(cancel_call)
 207            .add_message_handler(decline_call)
 208            .add_request_handler(update_participant_location)
 209            .add_request_handler(share_project)
 210            .add_message_handler(unshare_project)
 211            .add_request_handler(join_project)
 212            .add_message_handler(leave_project)
 213            .add_request_handler(update_project)
 214            .add_request_handler(update_worktree)
 215            .add_message_handler(start_language_server)
 216            .add_message_handler(update_language_server)
 217            .add_message_handler(update_diagnostic_summary)
 218            .add_message_handler(update_worktree_settings)
 219            .add_message_handler(refresh_inlay_hints)
 220            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 221            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 222            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 223            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 224            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 225            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 226            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 227            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 228            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 229            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 230            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 231            .add_request_handler(forward_mutating_project_request::<proto::OpenBufferByPath>)
 232            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 233            .add_request_handler(
 234                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 235            )
 236            .add_request_handler(
 237                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 238            )
 239            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 240            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 241            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 242            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 243            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 244            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 245            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 246            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 247            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 248            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 249            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 250            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 251            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 252            .add_message_handler(create_buffer_for_peer)
 253            .add_request_handler(update_buffer)
 254            .add_message_handler(update_buffer_file)
 255            .add_message_handler(buffer_reloaded)
 256            .add_message_handler(buffer_saved)
 257            .add_request_handler(get_users)
 258            .add_request_handler(fuzzy_search_users)
 259            .add_request_handler(request_contact)
 260            .add_request_handler(remove_contact)
 261            .add_request_handler(respond_to_contact_request)
 262            .add_request_handler(create_channel)
 263            .add_request_handler(delete_channel)
 264            .add_request_handler(invite_channel_member)
 265            .add_request_handler(remove_channel_member)
 266            .add_request_handler(set_channel_member_role)
 267            .add_request_handler(set_channel_visibility)
 268            .add_request_handler(rename_channel)
 269            .add_request_handler(join_channel_buffer)
 270            .add_request_handler(leave_channel_buffer)
 271            .add_message_handler(update_channel_buffer)
 272            .add_request_handler(rejoin_channel_buffers)
 273            .add_request_handler(get_channel_members)
 274            .add_request_handler(respond_to_channel_invite)
 275            .add_request_handler(join_channel)
 276            .add_request_handler(join_channel_chat)
 277            .add_message_handler(leave_channel_chat)
 278            .add_request_handler(send_channel_message)
 279            .add_request_handler(remove_channel_message)
 280            .add_request_handler(get_channel_messages)
 281            .add_request_handler(get_channel_messages_by_id)
 282            .add_request_handler(get_notifications)
 283            .add_request_handler(mark_notification_as_read)
 284            .add_request_handler(move_channel)
 285            .add_request_handler(follow)
 286            .add_message_handler(unfollow)
 287            .add_message_handler(update_followers)
 288            .add_message_handler(update_diff_base)
 289            .add_request_handler(get_private_user_info)
 290            .add_message_handler(acknowledge_channel_message)
 291            .add_message_handler(acknowledge_buffer_version);
 292
 293        Arc::new(server)
 294    }
 295
    /// Schedules cleanup of resources left behind by stale servers, then returns `Ok(())`.
    ///
    /// The work runs in a detached task, under a "start server" tracing span, that:
    /// 1. waits `CLEANUP_TIMEOUT`;
    /// 2. asks the database for the rooms and channel buffers associated with stale servers
    ///    in this environment (`stale_server_resource_ids`);
    /// 3. clears stale channel-buffer collaborators and stale room participants, and notifies
    ///    the affected clients;
    /// 4. calls `delete_stale_servers`.
    ///
    /// Any database call that fails becomes `None` through `trace_err` (presumably after the
    /// error is logged), and that step is skipped. Step 4 runs regardless.
    pub async fn start(&self) -> Result<()> {
        // Take owned handles up front so the `async move` task owns everything it uses.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // For each affected channel buffer, clear stale collaborators and send the
                    // refreshed collaborator list to every connection still in the buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // For each affected room: clear stale participants, broadcast the refreshed
                    // room (and its channel, if any), notify users whose pending calls were
                    // canceled, refresh contact status, and delete the LiveKit room once empty.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Removed participants and users whose calls were canceled both
                            // get their contact entry re-broadcast below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool guard is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to the connections of
                        // their accepted contacts. The pool is locked only after both awaits.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only when a LiveKit client is configured and the room has no
                        // participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 447
 448    pub fn teardown(&self) {
 449        self.peer.teardown();
 450        self.connection_pool.lock().reset();
 451        let _ = self.teardown.send(());
 452    }
 453
 454    #[cfg(test)]
 455    pub fn reset(&self, id: ServerId) {
 456        self.teardown();
 457        *self.id.lock() = id;
 458        self.peer.reset(id.0 as u32);
 459    }
 460
 461    #[cfg(test)]
 462    pub fn id(&self) -> ServerId {
 463        *self.id.lock()
 464    }
 465
 466    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 467    where
 468        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 469        Fut: 'static + Send + Future<Output = Result<()>>,
 470        M: EnvelopedMessage,
 471    {
 472        let prev_handler = self.handlers.insert(
 473            TypeId::of::<M>(),
 474            Box::new(move |envelope, session| {
 475                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 476                let span = info_span!(
 477                    "handle message",
 478                    payload_type = envelope.payload_type_name()
 479                );
 480                span.in_scope(|| {
 481                    tracing::info!(
 482                        payload_type = envelope.payload_type_name(),
 483                        "message received"
 484                    );
 485                });
 486                let start_time = Instant::now();
 487                let future = (handler)(*envelope, session);
 488                async move {
 489                    let result = future.await;
 490                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 491                    match result {
 492                        Err(error) => {
 493                            tracing::error!(%error, ?duration_ms, "error handling message")
 494                        }
 495                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 496                    }
 497                }
 498                .instrument(span)
 499                .boxed()
 500            }),
 501        );
 502        if prev_handler.is_some() {
 503            panic!("registered a handler for the same message twice");
 504        }
 505        self
 506    }
 507
 508    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 511        Fut: 'static + Send + Future<Output = Result<()>>,
 512        M: EnvelopedMessage,
 513    {
 514        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 515        self
 516    }
 517
 518    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 519    where
 520        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 521        Fut: Send + Future<Output = Result<()>>,
 522        M: RequestMessage,
 523    {
 524        let handler = Arc::new(handler);
 525        self.add_handler(move |envelope, session| {
 526            let receipt = envelope.receipt();
 527            let handler = handler.clone();
 528            async move {
 529                let peer = session.peer.clone();
 530                let responded = Arc::new(AtomicBool::default());
 531                let response = Response {
 532                    peer: peer.clone(),
 533                    responded: responded.clone(),
 534                    receipt,
 535                };
 536                match (handler)(envelope.payload, response, session).await {
 537                    Ok(()) => {
 538                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 539                            Ok(())
 540                        } else {
 541                            Err(anyhow!("handler did not send a response"))?
 542                        }
 543                    }
 544                    Err(error) => {
 545                        peer.respond_with_error(
 546                            receipt,
 547                            proto::Error {
 548                                message: error.to_string(),
 549                            },
 550                        )?;
 551                        Err(error)
 552                    }
 553                }
 554            }
 555        })
 556    }
 557
    /// Returns a future that drives one client connection from handshake until it ends.
    ///
    /// The future registers the connection with the peer and sends `Hello`. On the user's
    /// first connection it also sends `ShowContacts` and records `connected_once`. It then
    /// adds the connection to the pool, seeds the client with its contacts, channels, channel
    /// invites and any pending incoming call, and builds the [`Session`] that handlers receive.
    /// Finally it dispatches incoming messages to the registered handlers until the connection
    /// closes, I/O fails, or the server is torn down.
    ///
    /// `send_connection_id`, if provided, receives the assigned `ConnectionId` right after the
    /// `Hello` message is sent.
    ///
    /// NOTE(review): the connection is added to the pool before several fallible steps (the
    /// initial contacts/channels sends, `incoming_call_for_user`, `update_user_contacts`).
    /// A `?` on any of them returns early without reaching `connection_lost`, so the pool
    /// entry is presumably left behind. Confirm, and consider cleaning up on that path.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future starts so a teardown signalled later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three initial-state queries run concurrently; the first error aborts.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // The pool lock is held only inside this block, with no `.await` in it.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers at 256. A permit is acquired *before* the next message is
            // read, so reading pauses while all permits are held by running handlers.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, I/O,
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // `Server::teardown` has already reset the pool, so sign-out is skipped.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only when the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers are spawned detached; they are not in
                                // `foreground_message_handlers`, so they are not dropped with it
                                // when the loop exits.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 692
 693    pub async fn invite_code_redeemed(
 694        self: &Arc<Self>,
 695        inviter_id: UserId,
 696        invitee_id: UserId,
 697    ) -> Result<()> {
 698        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 699            if let Some(code) = &user.invite_code {
 700                let pool = self.connection_pool.lock();
 701                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 702                for connection_id in pool.user_connection_ids(inviter_id) {
 703                    self.peer.send(
 704                        connection_id,
 705                        proto::UpdateContacts {
 706                            contacts: vec![invitee_contact.clone()],
 707                            ..Default::default()
 708                        },
 709                    )?;
 710                    self.peer.send(
 711                        connection_id,
 712                        proto::UpdateInviteInfo {
 713                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 714                            count: user.invite_count as u32,
 715                        },
 716                    )?;
 717                }
 718            }
 719        }
 720        Ok(())
 721    }
 722
 723    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 724        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 725            if let Some(invite_code) = &user.invite_code {
 726                let pool = self.connection_pool.lock();
 727                for connection_id in pool.user_connection_ids(user_id) {
 728                    self.peer.send(
 729                        connection_id,
 730                        proto::UpdateInviteInfo {
 731                            url: format!(
 732                                "{}{}",
 733                                self.app_state.config.invite_link_prefix, invite_code
 734                            ),
 735                            count: user.invite_count as u32,
 736                        },
 737                    )?;
 738                }
 739            }
 740        }
 741        Ok(())
 742    }
 743
 744    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 745        ServerSnapshot {
 746            connection_pool: ConnectionPoolGuard {
 747                guard: self.connection_pool.lock(),
 748                _not_send: PhantomData,
 749            },
 750            peer: &self.peer,
 751        }
 752    }
 753}
 754
 755impl<'a> Deref for ConnectionPoolGuard<'a> {
 756    type Target = ConnectionPool;
 757
 758    fn deref(&self) -> &Self::Target {
 759        &*self.guard
 760    }
 761}
 762
 763impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 764    fn deref_mut(&mut self) -> &mut Self::Target {
 765        &mut *self.guard
 766    }
 767}
 768
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants whenever a guard goes away; in other
    /// builds this does nothing.
    ///
    /// `Drop::drop` runs before the struct's fields are dropped, so the check happens while
    /// the inner lock guard is still held.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 775
 776fn broadcast<F>(
 777    sender_id: Option<ConnectionId>,
 778    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 779    mut f: F,
 780) where
 781    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 782{
 783    for receiver_id in receiver_ids {
 784        if Some(receiver_id) != sender_id {
 785            if let Err(error) = f(receiver_id) {
 786                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 787            }
 788        }
 789    }
 790}
 791
 792lazy_static! {
 793    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 794}
 795
 796pub struct ProtocolVersion(u32);
 797
 798impl Header for ProtocolVersion {
 799    fn name() -> &'static HeaderName {
 800        &ZED_PROTOCOL_VERSION
 801    }
 802
 803    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 804    where
 805        Self: Sized,
 806        I: Iterator<Item = &'i axum::http::HeaderValue>,
 807    {
 808        let version = values
 809            .next()
 810            .ok_or_else(axum::headers::Error::invalid)?
 811            .to_str()
 812            .map_err(|_| axum::headers::Error::invalid())?
 813            .parse()
 814            .map_err(|_| axum::headers::Error::invalid())?;
 815        Ok(Self(version))
 816    }
 817
 818    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 819        values.extend([self.0.to_string().parse().unwrap()]);
 820    }
 821}
 822
 823pub fn routes(server: Arc<Server>) -> Router<Body> {
 824    Router::new()
 825        .route("/rpc", get(handle_websocket_request))
 826        .layer(
 827            ServiceBuilder::new()
 828                .layer(Extension(server.app_state.clone()))
 829                .layer(middleware::from_fn(auth::validate_header)),
 830        )
 831        .route("/metrics", get(handle_metrics))
 832        .layer(Extension(server))
 833}
 834
 835pub async fn handle_websocket_request(
 836    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 837    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 838    Extension(server): Extension<Arc<Server>>,
 839    Extension(user): Extension<User>,
 840    ws: WebSocketUpgrade,
 841) -> axum::response::Response {
 842    if protocol_version != rpc::PROTOCOL_VERSION {
 843        return (
 844            StatusCode::UPGRADE_REQUIRED,
 845            "client must be upgraded".to_string(),
 846        )
 847            .into_response();
 848    }
 849    let socket_address = socket_address.to_string();
 850    ws.on_upgrade(move |socket| {
 851        use util::ResultExt;
 852        let socket = socket
 853            .map_ok(to_tungstenite_message)
 854            .err_into()
 855            .with(|message| async move { Ok(to_axum_message(message)) });
 856        let connection = Connection::new(Box::pin(socket));
 857        async move {
 858            server
 859                .handle_connection(connection, socket_address, user, None, Executor::Production)
 860                .await
 861                .log_err();
 862        }
 863    })
 864}
 865
 866pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 867    let connections = server
 868        .connection_pool
 869        .lock()
 870        .connections()
 871        .filter(|connection| !connection.admin)
 872        .count();
 873
 874    METRIC_CONNECTIONS.set(connections as _);
 875
 876    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 877    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 878
 879    let encoder = prometheus::TextEncoder::new();
 880    let metric_families = prometheus::gather();
 881    let encoded_metrics = encoder
 882        .encode_to_string(&metric_families)
 883        .map_err(|err| anyhow!("{}", err))?;
 884    Ok(encoded_metrics)
 885}
 886
/// Cleans up after a client connection has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool, and calls the
/// database's `connection_lost` (a database error there is only traced). It then waits for
/// whichever of the following comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers (errors
///   traced). If the user has no other connection online, `decline_call(None, ..)` is called
///   and any returned room is broadcast. Finally the user's contacts are updated.
/// - `teardown` changes: the wait is abandoned and none of the above cleanup runs.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails, or if updating the
/// user's contacts fails after the timeout.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order, so the timeout branch is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only touch pending calls when the user has no other live connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 932
 933async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 934    response.send(proto::Ack {})?;
 935    Ok(())
 936}
 937
 938async fn create_room(
 939    _request: proto::CreateRoom,
 940    response: Response<proto::CreateRoom>,
 941    session: Session,
 942) -> Result<()> {
 943    let live_kit_room = nanoid::nanoid!(30);
 944
 945    let live_kit_connection_info = {
 946        let live_kit_room = live_kit_room.clone();
 947        let live_kit = session.live_kit_client.as_ref();
 948
 949        util::async_maybe!({
 950            let live_kit = live_kit?;
 951
 952            let token = live_kit
 953                .room_token(&live_kit_room, &session.user_id.to_string())
 954                .trace_err()?;
 955
 956            Some(proto::LiveKitConnectionInfo {
 957                server_url: live_kit.url().into(),
 958                token,
 959                can_publish: true,
 960            })
 961        })
 962    }
 963    .await;
 964
 965    let room = session
 966        .db()
 967        .await
 968        .create_room(
 969            session.user_id,
 970            session.connection_id,
 971            &live_kit_room,
 972            RELEASE_CHANNEL_NAME.as_str(),
 973        )
 974        .await?;
 975
 976    response.send(proto::CreateRoomResponse {
 977        room: Some(room.clone()),
 978        live_kit_connection_info,
 979    })?;
 980
 981    update_user_contacts(session.user_id, &session).await?;
 982    Ok(())
 983}
 984
 985async fn join_room(
 986    request: proto::JoinRoom,
 987    response: Response<proto::JoinRoom>,
 988    session: Session,
 989) -> Result<()> {
 990    let room_id = RoomId::from_proto(request.id);
 991
 992    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 993
 994    if let Some(channel_id) = channel_id {
 995        return join_channel_internal(channel_id, Box::new(response), session).await;
 996    }
 997
 998    let joined_room = {
 999        let room = session
1000            .db()
1001            .await
1002            .join_room(
1003                room_id,
1004                session.user_id,
1005                session.connection_id,
1006                RELEASE_CHANNEL_NAME.as_str(),
1007            )
1008            .await?;
1009        room_updated(&room.room, &session.peer);
1010        room.into_inner()
1011    };
1012
1013    for connection_id in session
1014        .connection_pool()
1015        .await
1016        .user_connection_ids(session.user_id)
1017    {
1018        session
1019            .peer
1020            .send(
1021                connection_id,
1022                proto::CallCanceled {
1023                    room_id: room_id.to_proto(),
1024                },
1025            )
1026            .trace_err();
1027    }
1028
1029    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1030        if let Some(token) = live_kit
1031            .room_token(
1032                &joined_room.room.live_kit_room,
1033                &session.user_id.to_string(),
1034            )
1035            .trace_err()
1036        {
1037            Some(proto::LiveKitConnectionInfo {
1038                server_url: live_kit.url().into(),
1039                token,
1040                can_publish: true,
1041            })
1042        } else {
1043            None
1044        }
1045    } else {
1046        None
1047    };
1048
1049    response.send(proto::JoinRoomResponse {
1050        room: Some(joined_room.room),
1051        channel_id: None,
1052        live_kit_connection_info,
1053    })?;
1054
1055    update_user_contacts(session.user_id, &session).await?;
1056    Ok(())
1057}
1058
/// Rejoins the caller to a room, moving its project participation from its previous
/// connection (`old_connection_id`) to `session.connection_id`.
///
/// Inside the inner scope, while the database result is alive, this:
/// - replies with the room, the reshared projects and the rejoined projects;
/// - broadcasts the room update;
/// - for each reshared project, tells every collaborator that the caller's peer id changed
///   and forwards the project's worktree list (`UpdateProject`) from the caller to them;
/// - for each rejoined project, tells every collaborator that the caller's peer id changed;
/// - streams each rejoined project's worktree entries (chunked), diagnostic summaries,
///   settings files and a `DiskBasedDiagnosticsUpdated` status per language server back to
///   the caller.
/// After the scope ends, the channel (if the room has one) is updated and the caller's
/// contacts are updated.
///
/// # Errors
/// Fails if the database rejoin fails, or if the reply or any message streamed to the caller
/// fails to send. Sends to collaborators only trace their errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel_id;
    let channel_members;
    // Inner scope: the value returned by `rejoin_room` is consumed by `into_inner` at the end
    // of this scope, before `connection_pool()` is awaited below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: re-point collaborators at the caller's new connection and
        // forward the project's worktree list to them (the caller itself is excluded).
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: re-point collaborators at the caller's new connection.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream rejoined project state back to the caller. `mem::take` moves the worktrees
        // out so their fields can be moved into outgoing messages without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in test builds (presumably to exercise chunking); 256 otherwise.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1249
1250async fn leave_room(
1251    _: proto::LeaveRoom,
1252    response: Response<proto::LeaveRoom>,
1253    session: Session,
1254) -> Result<()> {
1255    leave_room_for_session(&session).await?;
1256    response.send(proto::Ack {})?;
1257    Ok(())
1258}
1259
1260async fn call(
1261    request: proto::Call,
1262    response: Response<proto::Call>,
1263    session: Session,
1264) -> Result<()> {
1265    let room_id = RoomId::from_proto(request.room_id);
1266    let calling_user_id = session.user_id;
1267    let calling_connection_id = session.connection_id;
1268    let called_user_id = UserId::from_proto(request.called_user_id);
1269    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1270    if !session
1271        .db()
1272        .await
1273        .has_contact(calling_user_id, called_user_id)
1274        .await?
1275    {
1276        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1277    }
1278
1279    let incoming_call = {
1280        let (room, incoming_call) = &mut *session
1281            .db()
1282            .await
1283            .call(
1284                room_id,
1285                calling_user_id,
1286                calling_connection_id,
1287                called_user_id,
1288                initial_project_id,
1289            )
1290            .await?;
1291        room_updated(&room, &session.peer);
1292        mem::take(incoming_call)
1293    };
1294    update_user_contacts(called_user_id, &session).await?;
1295
1296    let mut calls = session
1297        .connection_pool()
1298        .await
1299        .user_connection_ids(called_user_id)
1300        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1301        .collect::<FuturesUnordered<_>>();
1302
1303    while let Some(call_response) = calls.next().await {
1304        match call_response.as_ref() {
1305            Ok(_) => {
1306                response.send(proto::Ack {})?;
1307                return Ok(());
1308            }
1309            Err(_) => {
1310                call_response.trace_err();
1311            }
1312        }
1313    }
1314
1315    {
1316        let room = session
1317            .db()
1318            .await
1319            .call_failed(room_id, called_user_id)
1320            .await?;
1321        room_updated(&room, &session.peer);
1322    }
1323    update_user_contacts(called_user_id, &session).await?;
1324
1325    Err(anyhow!("failed to ring user"))?
1326}
1327
1328async fn cancel_call(
1329    request: proto::CancelCall,
1330    response: Response<proto::CancelCall>,
1331    session: Session,
1332) -> Result<()> {
1333    let called_user_id = UserId::from_proto(request.called_user_id);
1334    let room_id = RoomId::from_proto(request.room_id);
1335    {
1336        let room = session
1337            .db()
1338            .await
1339            .cancel_call(room_id, session.connection_id, called_user_id)
1340            .await?;
1341        room_updated(&room, &session.peer);
1342    }
1343
1344    for connection_id in session
1345        .connection_pool()
1346        .await
1347        .user_connection_ids(called_user_id)
1348    {
1349        session
1350            .peer
1351            .send(
1352                connection_id,
1353                proto::CallCanceled {
1354                    room_id: room_id.to_proto(),
1355                },
1356            )
1357            .trace_err();
1358    }
1359    response.send(proto::Ack {})?;
1360
1361    update_user_contacts(called_user_id, &session).await?;
1362    Ok(())
1363}
1364
1365async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1366    let room_id = RoomId::from_proto(message.room_id);
1367    {
1368        let room = session
1369            .db()
1370            .await
1371            .decline_call(Some(room_id), session.user_id)
1372            .await?
1373            .ok_or_else(|| anyhow!("failed to decline call"))?;
1374        room_updated(&room, &session.peer);
1375    }
1376
1377    for connection_id in session
1378        .connection_pool()
1379        .await
1380        .user_connection_ids(session.user_id)
1381    {
1382        session
1383            .peer
1384            .send(
1385                connection_id,
1386                proto::CallCanceled {
1387                    room_id: room_id.to_proto(),
1388                },
1389            )
1390            .trace_err();
1391    }
1392    update_user_contacts(session.user_id, &session).await?;
1393    Ok(())
1394}
1395
1396async fn update_participant_location(
1397    request: proto::UpdateParticipantLocation,
1398    response: Response<proto::UpdateParticipantLocation>,
1399    session: Session,
1400) -> Result<()> {
1401    let room_id = RoomId::from_proto(request.room_id);
1402    let location = request
1403        .location
1404        .ok_or_else(|| anyhow!("invalid location"))?;
1405
1406    let db = session.db().await;
1407    let room = db
1408        .update_room_participant_location(room_id, session.connection_id, location)
1409        .await?;
1410
1411    room_updated(&room, &session.peer);
1412    response.send(proto::Ack {})?;
1413    Ok(())
1414}
1415
1416async fn share_project(
1417    request: proto::ShareProject,
1418    response: Response<proto::ShareProject>,
1419    session: Session,
1420) -> Result<()> {
1421    let (project_id, room) = &*session
1422        .db()
1423        .await
1424        .share_project(
1425            RoomId::from_proto(request.room_id),
1426            session.connection_id,
1427            &request.worktrees,
1428        )
1429        .await?;
1430    response.send(proto::ShareProjectResponse {
1431        project_id: project_id.to_proto(),
1432    })?;
1433    room_updated(&room, &session.peer);
1434
1435    Ok(())
1436}
1437
1438async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1439    let project_id = ProjectId::from_proto(message.project_id);
1440
1441    let (room, guest_connection_ids) = &*session
1442        .db()
1443        .await
1444        .unshare_project(project_id, session.connection_id)
1445        .await?;
1446
1447    broadcast(
1448        Some(session.connection_id),
1449        guest_connection_ids.iter().copied(),
1450        |conn_id| session.peer.send(conn_id, message.clone()),
1451    );
1452    room_updated(&room, &session.peer);
1453
1454    Ok(())
1455}
1456
1457async fn join_project(
1458    request: proto::JoinProject,
1459    response: Response<proto::JoinProject>,
1460    session: Session,
1461) -> Result<()> {
1462    let project_id = ProjectId::from_proto(request.project_id);
1463    let guest_user_id = session.user_id;
1464
1465    tracing::info!(%project_id, "join project");
1466
1467    let (project, replica_id) = &mut *session
1468        .db()
1469        .await
1470        .join_project(project_id, session.connection_id)
1471        .await?;
1472
1473    let collaborators = project
1474        .collaborators
1475        .iter()
1476        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1477        .map(|collaborator| collaborator.to_proto())
1478        .collect::<Vec<_>>();
1479
1480    let worktrees = project
1481        .worktrees
1482        .iter()
1483        .map(|(id, worktree)| proto::WorktreeMetadata {
1484            id: *id,
1485            root_name: worktree.root_name.clone(),
1486            visible: worktree.visible,
1487            abs_path: worktree.abs_path.clone(),
1488        })
1489        .collect::<Vec<_>>();
1490
1491    for collaborator in &collaborators {
1492        session
1493            .peer
1494            .send(
1495                collaborator.peer_id.unwrap().into(),
1496                proto::AddProjectCollaborator {
1497                    project_id: project_id.to_proto(),
1498                    collaborator: Some(proto::Collaborator {
1499                        peer_id: Some(session.connection_id.into()),
1500                        replica_id: replica_id.0 as u32,
1501                        user_id: guest_user_id.to_proto(),
1502                    }),
1503                },
1504            )
1505            .trace_err();
1506    }
1507
1508    // First, we send the metadata associated with each worktree.
1509    response.send(proto::JoinProjectResponse {
1510        worktrees: worktrees.clone(),
1511        replica_id: replica_id.0 as u32,
1512        collaborators: collaborators.clone(),
1513        language_servers: project.language_servers.clone(),
1514    })?;
1515
1516    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1517        #[cfg(any(test, feature = "test-support"))]
1518        const MAX_CHUNK_SIZE: usize = 2;
1519        #[cfg(not(any(test, feature = "test-support")))]
1520        const MAX_CHUNK_SIZE: usize = 256;
1521
1522        // Stream this worktree's entries.
1523        let message = proto::UpdateWorktree {
1524            project_id: project_id.to_proto(),
1525            worktree_id,
1526            abs_path: worktree.abs_path.clone(),
1527            root_name: worktree.root_name,
1528            updated_entries: worktree.entries,
1529            removed_entries: Default::default(),
1530            scan_id: worktree.scan_id,
1531            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1532            updated_repositories: worktree.repository_entries.into_values().collect(),
1533            removed_repositories: Default::default(),
1534        };
1535        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1536            session.peer.send(session.connection_id, update.clone())?;
1537        }
1538
1539        // Stream this worktree's diagnostics.
1540        for summary in worktree.diagnostic_summaries {
1541            session.peer.send(
1542                session.connection_id,
1543                proto::UpdateDiagnosticSummary {
1544                    project_id: project_id.to_proto(),
1545                    worktree_id: worktree.id,
1546                    summary: Some(summary),
1547                },
1548            )?;
1549        }
1550
1551        for settings_file in worktree.settings_files {
1552            session.peer.send(
1553                session.connection_id,
1554                proto::UpdateWorktreeSettings {
1555                    project_id: project_id.to_proto(),
1556                    worktree_id: worktree.id,
1557                    path: settings_file.path,
1558                    content: Some(settings_file.content),
1559                },
1560            )?;
1561        }
1562    }
1563
1564    for language_server in &project.language_servers {
1565        session.peer.send(
1566            session.connection_id,
1567            proto::UpdateLanguageServer {
1568                project_id: project_id.to_proto(),
1569                language_server_id: language_server.id,
1570                variant: Some(
1571                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1572                        proto::LspDiskBasedDiagnosticsUpdated {},
1573                    ),
1574                ),
1575            },
1576        )?;
1577    }
1578
1579    Ok(())
1580}
1581
1582async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1583    let sender_id = session.connection_id;
1584    let project_id = ProjectId::from_proto(request.project_id);
1585
1586    let (room, project) = &*session
1587        .db()
1588        .await
1589        .leave_project(project_id, sender_id)
1590        .await?;
1591    tracing::info!(
1592        %project_id,
1593        host_user_id = %project.host_user_id,
1594        host_connection_id = %project.host_connection_id,
1595        "leave project"
1596    );
1597
1598    project_left(&project, &session);
1599    room_updated(&room, &session.peer);
1600
1601    Ok(())
1602}
1603
1604async fn update_project(
1605    request: proto::UpdateProject,
1606    response: Response<proto::UpdateProject>,
1607    session: Session,
1608) -> Result<()> {
1609    let project_id = ProjectId::from_proto(request.project_id);
1610    let (room, guest_connection_ids) = &*session
1611        .db()
1612        .await
1613        .update_project(project_id, session.connection_id, &request.worktrees)
1614        .await?;
1615    broadcast(
1616        Some(session.connection_id),
1617        guest_connection_ids.iter().copied(),
1618        |connection_id| {
1619            session
1620                .peer
1621                .forward_send(session.connection_id, connection_id, request.clone())
1622        },
1623    );
1624    room_updated(&room, &session.peer);
1625    response.send(proto::Ack {})?;
1626
1627    Ok(())
1628}
1629
1630async fn update_worktree(
1631    request: proto::UpdateWorktree,
1632    response: Response<proto::UpdateWorktree>,
1633    session: Session,
1634) -> Result<()> {
1635    let guest_connection_ids = session
1636        .db()
1637        .await
1638        .update_worktree(&request, session.connection_id)
1639        .await?;
1640
1641    broadcast(
1642        Some(session.connection_id),
1643        guest_connection_ids.iter().copied(),
1644        |connection_id| {
1645            session
1646                .peer
1647                .forward_send(session.connection_id, connection_id, request.clone())
1648        },
1649    );
1650    response.send(proto::Ack {})?;
1651    Ok(())
1652}
1653
1654async fn update_diagnostic_summary(
1655    message: proto::UpdateDiagnosticSummary,
1656    session: Session,
1657) -> Result<()> {
1658    let guest_connection_ids = session
1659        .db()
1660        .await
1661        .update_diagnostic_summary(&message, session.connection_id)
1662        .await?;
1663
1664    broadcast(
1665        Some(session.connection_id),
1666        guest_connection_ids.iter().copied(),
1667        |connection_id| {
1668            session
1669                .peer
1670                .forward_send(session.connection_id, connection_id, message.clone())
1671        },
1672    );
1673
1674    Ok(())
1675}
1676
1677async fn update_worktree_settings(
1678    message: proto::UpdateWorktreeSettings,
1679    session: Session,
1680) -> Result<()> {
1681    let guest_connection_ids = session
1682        .db()
1683        .await
1684        .update_worktree_settings(&message, session.connection_id)
1685        .await?;
1686
1687    broadcast(
1688        Some(session.connection_id),
1689        guest_connection_ids.iter().copied(),
1690        |connection_id| {
1691            session
1692                .peer
1693                .forward_send(session.connection_id, connection_id, message.clone())
1694        },
1695    );
1696
1697    Ok(())
1698}
1699
1700async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1701    broadcast_project_message(request.project_id, request, session).await
1702}
1703
1704async fn start_language_server(
1705    request: proto::StartLanguageServer,
1706    session: Session,
1707) -> Result<()> {
1708    let guest_connection_ids = session
1709        .db()
1710        .await
1711        .start_language_server(&request, session.connection_id)
1712        .await?;
1713
1714    broadcast(
1715        Some(session.connection_id),
1716        guest_connection_ids.iter().copied(),
1717        |connection_id| {
1718            session
1719                .peer
1720                .forward_send(session.connection_id, connection_id, request.clone())
1721        },
1722    );
1723    Ok(())
1724}
1725
1726async fn update_language_server(
1727    request: proto::UpdateLanguageServer,
1728    session: Session,
1729) -> Result<()> {
1730    let project_id = ProjectId::from_proto(request.project_id);
1731    let project_connection_ids = session
1732        .db()
1733        .await
1734        .project_connection_ids(project_id, session.connection_id)
1735        .await?;
1736    broadcast(
1737        Some(session.connection_id),
1738        project_connection_ids.iter().copied(),
1739        |connection_id| {
1740            session
1741                .peer
1742                .forward_send(session.connection_id, connection_id, request.clone())
1743        },
1744    );
1745    Ok(())
1746}
1747
1748async fn forward_read_only_project_request<T>(
1749    request: T,
1750    response: Response<T>,
1751    session: Session,
1752) -> Result<()>
1753where
1754    T: EntityMessage + RequestMessage,
1755{
1756    let project_id = ProjectId::from_proto(request.remote_entity_id());
1757    let host_connection_id = {
1758        let collaborators = session
1759            .db()
1760            .await
1761            .project_collaborators(project_id, session.connection_id)
1762            .await?;
1763        collaborators
1764            .iter()
1765            .find(|collaborator| collaborator.is_host)
1766            .ok_or_else(|| anyhow!("host not found"))?
1767            .connection_id
1768    };
1769
1770    let payload = session
1771        .peer
1772        .forward_request(session.connection_id, host_connection_id, request)
1773        .await?;
1774
1775    response.send(payload)?;
1776    Ok(())
1777}
1778
1779async fn forward_mutating_project_request<T>(
1780    request: T,
1781    response: Response<T>,
1782    session: Session,
1783) -> Result<()>
1784where
1785    T: EntityMessage + RequestMessage,
1786{
1787    let project_id = ProjectId::from_proto(request.remote_entity_id());
1788    let host_connection_id = session
1789        .db()
1790        .await
1791        .host_for_mutating_project_request(project_id, session.connection_id)
1792        .await?;
1793
1794    let payload = session
1795        .peer
1796        .forward_request(session.connection_id, host_connection_id, request)
1797        .await?;
1798
1799    response.send(payload)?;
1800    Ok(())
1801}
1802
1803async fn create_buffer_for_peer(
1804    request: proto::CreateBufferForPeer,
1805    session: Session,
1806) -> Result<()> {
1807    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1808    session
1809        .peer
1810        .forward_send(session.connection_id, peer_id.into(), request)?;
1811    Ok(())
1812}
1813
1814async fn update_buffer(
1815    request: proto::UpdateBuffer,
1816    response: Response<proto::UpdateBuffer>,
1817    session: Session,
1818) -> Result<()> {
1819    let project_id = ProjectId::from_proto(request.project_id);
1820    let mut guest_connection_ids;
1821    let mut host_connection_id = None;
1822    {
1823        let collaborators = session
1824            .db()
1825            .await
1826            .project_collaborators(project_id, session.connection_id)
1827            .await?;
1828        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1829        for collaborator in collaborators.iter() {
1830            if collaborator.is_host {
1831                host_connection_id = Some(collaborator.connection_id);
1832            } else {
1833                guest_connection_ids.push(collaborator.connection_id);
1834            }
1835        }
1836    }
1837    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1838
1839    broadcast(
1840        Some(session.connection_id),
1841        guest_connection_ids,
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    if host_connection_id != session.connection_id {
1849        session
1850            .peer
1851            .forward_request(session.connection_id, host_connection_id, request.clone())
1852            .await?;
1853    }
1854
1855    response.send(proto::Ack {})?;
1856    Ok(())
1857}
1858
1859async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1860    let project_id = ProjectId::from_proto(request.project_id);
1861    let project_connection_ids = session
1862        .db()
1863        .await
1864        .project_connection_ids(project_id, session.connection_id)
1865        .await?;
1866
1867    broadcast(
1868        Some(session.connection_id),
1869        project_connection_ids.iter().copied(),
1870        |connection_id| {
1871            session
1872                .peer
1873                .forward_send(session.connection_id, connection_id, request.clone())
1874        },
1875    );
1876    Ok(())
1877}
1878
1879async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1880    let project_id = ProjectId::from_proto(request.project_id);
1881    let project_connection_ids = session
1882        .db()
1883        .await
1884        .project_connection_ids(project_id, session.connection_id)
1885        .await?;
1886    broadcast(
1887        Some(session.connection_id),
1888        project_connection_ids.iter().copied(),
1889        |connection_id| {
1890            session
1891                .peer
1892                .forward_send(session.connection_id, connection_id, request.clone())
1893        },
1894    );
1895    Ok(())
1896}
1897
1898async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1899    broadcast_project_message(request.project_id, request, session).await
1900}
1901
1902async fn broadcast_project_message<T: EnvelopedMessage>(
1903    project_id: u64,
1904    request: T,
1905    session: Session,
1906) -> Result<()> {
1907    let project_id = ProjectId::from_proto(project_id);
1908    let project_connection_ids = session
1909        .db()
1910        .await
1911        .project_connection_ids(project_id, session.connection_id)
1912        .await?;
1913    broadcast(
1914        Some(session.connection_id),
1915        project_connection_ids.iter().copied(),
1916        |connection_id| {
1917            session
1918                .peer
1919                .forward_send(session.connection_id, connection_id, request.clone())
1920        },
1921    );
1922    Ok(())
1923}
1924
1925async fn follow(
1926    request: proto::Follow,
1927    response: Response<proto::Follow>,
1928    session: Session,
1929) -> Result<()> {
1930    let room_id = RoomId::from_proto(request.room_id);
1931    let project_id = request.project_id.map(ProjectId::from_proto);
1932    let leader_id = request
1933        .leader_id
1934        .ok_or_else(|| anyhow!("invalid leader id"))?
1935        .into();
1936    let follower_id = session.connection_id;
1937
1938    session
1939        .db()
1940        .await
1941        .check_room_participants(room_id, leader_id, session.connection_id)
1942        .await?;
1943
1944    let response_payload = session
1945        .peer
1946        .forward_request(session.connection_id, leader_id, request)
1947        .await?;
1948    response.send(response_payload)?;
1949
1950    if let Some(project_id) = project_id {
1951        let room = session
1952            .db()
1953            .await
1954            .follow(room_id, project_id, leader_id, follower_id)
1955            .await?;
1956        room_updated(&room, &session.peer);
1957    }
1958
1959    Ok(())
1960}
1961
1962async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1963    let room_id = RoomId::from_proto(request.room_id);
1964    let project_id = request.project_id.map(ProjectId::from_proto);
1965    let leader_id = request
1966        .leader_id
1967        .ok_or_else(|| anyhow!("invalid leader id"))?
1968        .into();
1969    let follower_id = session.connection_id;
1970
1971    session
1972        .db()
1973        .await
1974        .check_room_participants(room_id, leader_id, session.connection_id)
1975        .await?;
1976
1977    session
1978        .peer
1979        .forward_send(session.connection_id, leader_id, request)?;
1980
1981    if let Some(project_id) = project_id {
1982        let room = session
1983            .db()
1984            .await
1985            .unfollow(room_id, project_id, leader_id, follower_id)
1986            .await?;
1987        room_updated(&room, &session.peer);
1988    }
1989
1990    Ok(())
1991}
1992
1993async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1994    let room_id = RoomId::from_proto(request.room_id);
1995    let database = session.db.lock().await;
1996
1997    let connection_ids = if let Some(project_id) = request.project_id {
1998        let project_id = ProjectId::from_proto(project_id);
1999        database
2000            .project_connection_ids(project_id, session.connection_id)
2001            .await?
2002    } else {
2003        database
2004            .room_connection_ids(room_id, session.connection_id)
2005            .await?
2006    };
2007
2008    // For now, don't send view update messages back to that view's current leader.
2009    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2010        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2011        _ => None,
2012    });
2013
2014    for follower_peer_id in request.follower_ids.iter().copied() {
2015        let follower_connection_id = follower_peer_id.into();
2016        if Some(follower_peer_id) != connection_id_to_omit
2017            && connection_ids.contains(&follower_connection_id)
2018        {
2019            session.peer.forward_send(
2020                session.connection_id,
2021                follower_connection_id,
2022                request.clone(),
2023            )?;
2024        }
2025    }
2026    Ok(())
2027}
2028
2029async fn get_users(
2030    request: proto::GetUsers,
2031    response: Response<proto::GetUsers>,
2032    session: Session,
2033) -> Result<()> {
2034    let user_ids = request
2035        .user_ids
2036        .into_iter()
2037        .map(UserId::from_proto)
2038        .collect();
2039    let users = session
2040        .db()
2041        .await
2042        .get_users_by_ids(user_ids)
2043        .await?
2044        .into_iter()
2045        .map(|user| proto::User {
2046            id: user.id.to_proto(),
2047            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2048            github_login: user.github_login,
2049        })
2050        .collect();
2051    response.send(proto::UsersResponse { users })?;
2052    Ok(())
2053}
2054
2055async fn fuzzy_search_users(
2056    request: proto::FuzzySearchUsers,
2057    response: Response<proto::FuzzySearchUsers>,
2058    session: Session,
2059) -> Result<()> {
2060    let query = request.query;
2061    let users = match query.len() {
2062        0 => vec![],
2063        1 | 2 => session
2064            .db()
2065            .await
2066            .get_user_by_github_login(&query)
2067            .await?
2068            .into_iter()
2069            .collect(),
2070        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2071    };
2072    let users = users
2073        .into_iter()
2074        .filter(|user| user.id != session.user_id)
2075        .map(|user| proto::User {
2076            id: user.id.to_proto(),
2077            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2078            github_login: user.github_login,
2079        })
2080        .collect();
2081    response.send(proto::UsersResponse { users })?;
2082    Ok(())
2083}
2084
2085async fn request_contact(
2086    request: proto::RequestContact,
2087    response: Response<proto::RequestContact>,
2088    session: Session,
2089) -> Result<()> {
2090    let requester_id = session.user_id;
2091    let responder_id = UserId::from_proto(request.responder_id);
2092    if requester_id == responder_id {
2093        return Err(anyhow!("cannot add yourself as a contact"))?;
2094    }
2095
2096    let notifications = session
2097        .db()
2098        .await
2099        .send_contact_request(requester_id, responder_id)
2100        .await?;
2101
2102    // Update outgoing contact requests of requester
2103    let mut update = proto::UpdateContacts::default();
2104    update.outgoing_requests.push(responder_id.to_proto());
2105    for connection_id in session
2106        .connection_pool()
2107        .await
2108        .user_connection_ids(requester_id)
2109    {
2110        session.peer.send(connection_id, update.clone())?;
2111    }
2112
2113    // Update incoming contact requests of responder
2114    let mut update = proto::UpdateContacts::default();
2115    update
2116        .incoming_requests
2117        .push(proto::IncomingContactRequest {
2118            requester_id: requester_id.to_proto(),
2119        });
2120    let connection_pool = session.connection_pool().await;
2121    for connection_id in connection_pool.user_connection_ids(responder_id) {
2122        session.peer.send(connection_id, update.clone())?;
2123    }
2124
2125    send_notifications(&*connection_pool, &session.peer, notifications);
2126
2127    response.send(proto::Ack {})?;
2128    Ok(())
2129}
2130
2131async fn respond_to_contact_request(
2132    request: proto::RespondToContactRequest,
2133    response: Response<proto::RespondToContactRequest>,
2134    session: Session,
2135) -> Result<()> {
2136    let responder_id = session.user_id;
2137    let requester_id = UserId::from_proto(request.requester_id);
2138    let db = session.db().await;
2139    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2140        db.dismiss_contact_notification(responder_id, requester_id)
2141            .await?;
2142    } else {
2143        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2144
2145        let notifications = db
2146            .respond_to_contact_request(responder_id, requester_id, accept)
2147            .await?;
2148        let requester_busy = db.is_user_busy(requester_id).await?;
2149        let responder_busy = db.is_user_busy(responder_id).await?;
2150
2151        let pool = session.connection_pool().await;
2152        // Update responder with new contact
2153        let mut update = proto::UpdateContacts::default();
2154        if accept {
2155            update
2156                .contacts
2157                .push(contact_for_user(requester_id, requester_busy, &pool));
2158        }
2159        update
2160            .remove_incoming_requests
2161            .push(requester_id.to_proto());
2162        for connection_id in pool.user_connection_ids(responder_id) {
2163            session.peer.send(connection_id, update.clone())?;
2164        }
2165
2166        // Update requester with new contact
2167        let mut update = proto::UpdateContacts::default();
2168        if accept {
2169            update
2170                .contacts
2171                .push(contact_for_user(responder_id, responder_busy, &pool));
2172        }
2173        update
2174            .remove_outgoing_requests
2175            .push(responder_id.to_proto());
2176
2177        for connection_id in pool.user_connection_ids(requester_id) {
2178            session.peer.send(connection_id, update.clone())?;
2179        }
2180
2181        send_notifications(&*pool, &session.peer, notifications);
2182    }
2183
2184    response.send(proto::Ack {})?;
2185    Ok(())
2186}
2187
2188async fn remove_contact(
2189    request: proto::RemoveContact,
2190    response: Response<proto::RemoveContact>,
2191    session: Session,
2192) -> Result<()> {
2193    let requester_id = session.user_id;
2194    let responder_id = UserId::from_proto(request.user_id);
2195    let db = session.db().await;
2196    let (contact_accepted, deleted_notification_id) =
2197        db.remove_contact(requester_id, responder_id).await?;
2198
2199    let pool = session.connection_pool().await;
2200    // Update outgoing contact requests of requester
2201    let mut update = proto::UpdateContacts::default();
2202    if contact_accepted {
2203        update.remove_contacts.push(responder_id.to_proto());
2204    } else {
2205        update
2206            .remove_outgoing_requests
2207            .push(responder_id.to_proto());
2208    }
2209    for connection_id in pool.user_connection_ids(requester_id) {
2210        session.peer.send(connection_id, update.clone())?;
2211    }
2212
2213    // Update incoming contact requests of responder
2214    let mut update = proto::UpdateContacts::default();
2215    if contact_accepted {
2216        update.remove_contacts.push(requester_id.to_proto());
2217    } else {
2218        update
2219            .remove_incoming_requests
2220            .push(requester_id.to_proto());
2221    }
2222    for connection_id in pool.user_connection_ids(responder_id) {
2223        session.peer.send(connection_id, update.clone())?;
2224        if let Some(notification_id) = deleted_notification_id {
2225            session.peer.send(
2226                connection_id,
2227                proto::DeleteNotification {
2228                    notification_id: notification_id.to_proto(),
2229                },
2230            )?;
2231        }
2232    }
2233
2234    response.send(proto::Ack {})?;
2235    Ok(())
2236}
2237
2238async fn create_channel(
2239    request: proto::CreateChannel,
2240    response: Response<proto::CreateChannel>,
2241    session: Session,
2242) -> Result<()> {
2243    let db = session.db().await;
2244
2245    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2246    let CreateChannelResult {
2247        channel,
2248        participants_to_update,
2249    } = db
2250        .create_channel(&request.name, parent_id, session.user_id)
2251        .await?;
2252
2253    response.send(proto::CreateChannelResponse {
2254        channel: Some(channel.to_proto()),
2255        parent_id: request.parent_id,
2256    })?;
2257
2258    let connection_pool = session.connection_pool().await;
2259    for (user_id, channels) in participants_to_update {
2260        let update = build_channels_update(channels, vec![]);
2261        for connection_id in connection_pool.user_connection_ids(user_id) {
2262            if user_id == session.user_id {
2263                continue;
2264            }
2265            session.peer.send(connection_id, update.clone())?;
2266        }
2267    }
2268
2269    Ok(())
2270}
2271
2272async fn delete_channel(
2273    request: proto::DeleteChannel,
2274    response: Response<proto::DeleteChannel>,
2275    session: Session,
2276) -> Result<()> {
2277    let db = session.db().await;
2278
2279    let channel_id = request.channel_id;
2280    let (removed_channels, member_ids) = db
2281        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2282        .await?;
2283    response.send(proto::Ack {})?;
2284
2285    // Notify members of removed channels
2286    let mut update = proto::UpdateChannels::default();
2287    update
2288        .delete_channels
2289        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2290
2291    let connection_pool = session.connection_pool().await;
2292    for member_id in member_ids {
2293        for connection_id in connection_pool.user_connection_ids(member_id) {
2294            session.peer.send(connection_id, update.clone())?;
2295        }
2296    }
2297
2298    Ok(())
2299}
2300
2301async fn invite_channel_member(
2302    request: proto::InviteChannelMember,
2303    response: Response<proto::InviteChannelMember>,
2304    session: Session,
2305) -> Result<()> {
2306    let db = session.db().await;
2307    let channel_id = ChannelId::from_proto(request.channel_id);
2308    let invitee_id = UserId::from_proto(request.user_id);
2309    let InviteMemberResult {
2310        channel,
2311        notifications,
2312    } = db
2313        .invite_channel_member(
2314            channel_id,
2315            invitee_id,
2316            session.user_id,
2317            request.role().into(),
2318        )
2319        .await?;
2320
2321    let update = proto::UpdateChannels {
2322        channel_invitations: vec![channel.to_proto()],
2323        ..Default::default()
2324    };
2325
2326    let connection_pool = session.connection_pool().await;
2327    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2328        session.peer.send(connection_id, update.clone())?;
2329    }
2330
2331    send_notifications(&*connection_pool, &session.peer, notifications);
2332
2333    response.send(proto::Ack {})?;
2334    Ok(())
2335}
2336
2337async fn remove_channel_member(
2338    request: proto::RemoveChannelMember,
2339    response: Response<proto::RemoveChannelMember>,
2340    session: Session,
2341) -> Result<()> {
2342    let db = session.db().await;
2343    let channel_id = ChannelId::from_proto(request.channel_id);
2344    let member_id = UserId::from_proto(request.user_id);
2345
2346    let RemoveChannelMemberResult {
2347        membership_update,
2348        notification_id,
2349    } = db
2350        .remove_channel_member(channel_id, member_id, session.user_id)
2351        .await?;
2352
2353    let connection_pool = &session.connection_pool().await;
2354    notify_membership_updated(
2355        &connection_pool,
2356        membership_update,
2357        member_id,
2358        &session.peer,
2359    );
2360    for connection_id in connection_pool.user_connection_ids(member_id) {
2361        if let Some(notification_id) = notification_id {
2362            session
2363                .peer
2364                .send(
2365                    connection_id,
2366                    proto::DeleteNotification {
2367                        notification_id: notification_id.to_proto(),
2368                    },
2369                )
2370                .trace_err();
2371        }
2372    }
2373
2374    response.send(proto::Ack {})?;
2375    Ok(())
2376}
2377
2378async fn set_channel_visibility(
2379    request: proto::SetChannelVisibility,
2380    response: Response<proto::SetChannelVisibility>,
2381    session: Session,
2382) -> Result<()> {
2383    let db = session.db().await;
2384    let channel_id = ChannelId::from_proto(request.channel_id);
2385    let visibility = request.visibility().into();
2386
2387    let SetChannelVisibilityResult {
2388        participants_to_update,
2389        participants_to_remove,
2390        channels_to_remove,
2391    } = db
2392        .set_channel_visibility(channel_id, visibility, session.user_id)
2393        .await?;
2394
2395    let connection_pool = session.connection_pool().await;
2396    for (user_id, channels) in participants_to_update {
2397        let update = build_channels_update(channels, vec![]);
2398        for connection_id in connection_pool.user_connection_ids(user_id) {
2399            session.peer.send(connection_id, update.clone())?;
2400        }
2401    }
2402    for user_id in participants_to_remove {
2403        let update = proto::UpdateChannels {
2404            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2405            ..Default::default()
2406        };
2407        for connection_id in connection_pool.user_connection_ids(user_id) {
2408            session.peer.send(connection_id, update.clone())?;
2409        }
2410    }
2411
2412    response.send(proto::Ack {})?;
2413    Ok(())
2414}
2415
2416async fn set_channel_member_role(
2417    request: proto::SetChannelMemberRole,
2418    response: Response<proto::SetChannelMemberRole>,
2419    session: Session,
2420) -> Result<()> {
2421    let db = session.db().await;
2422    let channel_id = ChannelId::from_proto(request.channel_id);
2423    let member_id = UserId::from_proto(request.user_id);
2424    let result = db
2425        .set_channel_member_role(
2426            channel_id,
2427            session.user_id,
2428            member_id,
2429            request.role().into(),
2430        )
2431        .await?;
2432
2433    match result {
2434        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2435            let connection_pool = session.connection_pool().await;
2436            notify_membership_updated(
2437                &connection_pool,
2438                membership_update,
2439                member_id,
2440                &session.peer,
2441            )
2442        }
2443        db::SetMemberRoleResult::InviteUpdated(channel) => {
2444            let update = proto::UpdateChannels {
2445                channel_invitations: vec![channel.to_proto()],
2446                ..Default::default()
2447            };
2448
2449            for connection_id in session
2450                .connection_pool()
2451                .await
2452                .user_connection_ids(member_id)
2453            {
2454                session.peer.send(connection_id, update.clone())?;
2455            }
2456        }
2457    }
2458
2459    response.send(proto::Ack {})?;
2460    Ok(())
2461}
2462
2463async fn rename_channel(
2464    request: proto::RenameChannel,
2465    response: Response<proto::RenameChannel>,
2466    session: Session,
2467) -> Result<()> {
2468    let db = session.db().await;
2469    let channel_id = ChannelId::from_proto(request.channel_id);
2470    let RenameChannelResult {
2471        channel,
2472        participants_to_update,
2473    } = db
2474        .rename_channel(channel_id, session.user_id, &request.name)
2475        .await?;
2476
2477    response.send(proto::RenameChannelResponse {
2478        channel: Some(channel.to_proto()),
2479    })?;
2480
2481    let connection_pool = session.connection_pool().await;
2482    for (user_id, channel) in participants_to_update {
2483        for connection_id in connection_pool.user_connection_ids(user_id) {
2484            let update = proto::UpdateChannels {
2485                channels: vec![channel.to_proto()],
2486                ..Default::default()
2487            };
2488
2489            session.peer.send(connection_id, update.clone())?;
2490        }
2491    }
2492
2493    Ok(())
2494}
2495
2496async fn move_channel(
2497    request: proto::MoveChannel,
2498    response: Response<proto::MoveChannel>,
2499    session: Session,
2500) -> Result<()> {
2501    let channel_id = ChannelId::from_proto(request.channel_id);
2502    let to = request.to.map(ChannelId::from_proto);
2503
2504    let result = session
2505        .db()
2506        .await
2507        .move_channel(channel_id, to, session.user_id)
2508        .await?;
2509
2510    notify_channel_moved(result, session).await?;
2511
2512    response.send(Ack {})?;
2513    Ok(())
2514}
2515
2516async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2517    let Some(MoveChannelResult {
2518        participants_to_remove,
2519        participants_to_update,
2520        moved_channels,
2521    }) = result
2522    else {
2523        return Ok(());
2524    };
2525    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2526
2527    let connection_pool = session.connection_pool().await;
2528    for (user_id, channels) in participants_to_update {
2529        let mut update = build_channels_update(channels, vec![]);
2530        update.delete_channels = moved_channels.clone();
2531        for connection_id in connection_pool.user_connection_ids(user_id) {
2532            session.peer.send(connection_id, update.clone())?;
2533        }
2534    }
2535
2536    for user_id in participants_to_remove {
2537        let update = proto::UpdateChannels {
2538            delete_channels: moved_channels.clone(),
2539            ..Default::default()
2540        };
2541        for connection_id in connection_pool.user_connection_ids(user_id) {
2542            session.peer.send(connection_id, update.clone())?;
2543        }
2544    }
2545    Ok(())
2546}
2547
2548async fn get_channel_members(
2549    request: proto::GetChannelMembers,
2550    response: Response<proto::GetChannelMembers>,
2551    session: Session,
2552) -> Result<()> {
2553    let db = session.db().await;
2554    let channel_id = ChannelId::from_proto(request.channel_id);
2555    let members = db
2556        .get_channel_participant_details(channel_id, session.user_id)
2557        .await?;
2558    response.send(proto::GetChannelMembersResponse { members })?;
2559    Ok(())
2560}
2561
2562async fn respond_to_channel_invite(
2563    request: proto::RespondToChannelInvite,
2564    response: Response<proto::RespondToChannelInvite>,
2565    session: Session,
2566) -> Result<()> {
2567    let db = session.db().await;
2568    let channel_id = ChannelId::from_proto(request.channel_id);
2569    let RespondToChannelInvite {
2570        membership_update,
2571        notifications,
2572    } = db
2573        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2574        .await?;
2575
2576    let connection_pool = session.connection_pool().await;
2577    if let Some(membership_update) = membership_update {
2578        notify_membership_updated(
2579            &connection_pool,
2580            membership_update,
2581            session.user_id,
2582            &session.peer,
2583        );
2584    } else {
2585        let update = proto::UpdateChannels {
2586            remove_channel_invitations: vec![channel_id.to_proto()],
2587            ..Default::default()
2588        };
2589
2590        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2591            session.peer.send(connection_id, update.clone())?;
2592        }
2593    };
2594
2595    send_notifications(&*connection_pool, &session.peer, notifications);
2596
2597    response.send(proto::Ack {})?;
2598
2599    Ok(())
2600}
2601
2602async fn join_channel(
2603    request: proto::JoinChannel,
2604    response: Response<proto::JoinChannel>,
2605    session: Session,
2606) -> Result<()> {
2607    let channel_id = ChannelId::from_proto(request.channel_id);
2608    join_channel_internal(channel_id, Box::new(response), session).await
2609}
2610
2611trait JoinChannelInternalResponse {
2612    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2613}
2614impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2615    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2616        Response::<proto::JoinChannel>::send(self, result)
2617    }
2618}
2619impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2620    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2621        Response::<proto::JoinRoom>::send(self, result)
2622    }
2623}
2624
/// Moves the session into the room attached to `channel_id`, replying through `response`.
///
/// Steps, in order:
/// 1. Leave whatever room the session is currently in (`leave_room_for_session`).
/// 2. Join the channel's room in the database. This also yields an optional
///    membership update and the user's `ChannelRole`.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest token
///    with `can_publish: false`; everyone else gets a room token with
///    `can_publish: true`. If token creation fails, `trace_err` logs the error
///    and no LiveKit connection info is included in the reply.
/// 4. Send the reply, push any membership update to the user's connections, and
///    broadcast the updated room to its participants.
/// 5. Broadcast the channel's participant list to channel members and refresh
///    the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block bounds the lifetime of the database guard (`db`) and the
    // connection-pool guard acquired inside it. Both are released before
    // `channel_updated` locks the pool again and `update_user_contacts` locks the
    // database again.
    // NOTE(review): presumably these locks are not re-entrant; confirm before
    // flattening this scope.
    let joined_room = {
        // Must run before `db` is acquired below, because it locks the database itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `None` when there is no LiveKit client or when minting the token failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard is a temporary here and is dropped at the end of this statement.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2705
2706async fn join_channel_buffer(
2707    request: proto::JoinChannelBuffer,
2708    response: Response<proto::JoinChannelBuffer>,
2709    session: Session,
2710) -> Result<()> {
2711    let db = session.db().await;
2712    let channel_id = ChannelId::from_proto(request.channel_id);
2713
2714    let open_response = db
2715        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2716        .await?;
2717
2718    let collaborators = open_response.collaborators.clone();
2719    response.send(open_response)?;
2720
2721    let update = UpdateChannelBufferCollaborators {
2722        channel_id: channel_id.to_proto(),
2723        collaborators: collaborators.clone(),
2724    };
2725    channel_buffer_updated(
2726        session.connection_id,
2727        collaborators
2728            .iter()
2729            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2730        &update,
2731        &session.peer,
2732    );
2733
2734    Ok(())
2735}
2736
2737async fn update_channel_buffer(
2738    request: proto::UpdateChannelBuffer,
2739    session: Session,
2740) -> Result<()> {
2741    let db = session.db().await;
2742    let channel_id = ChannelId::from_proto(request.channel_id);
2743
2744    let (collaborators, non_collaborators, epoch, version) = db
2745        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2746        .await?;
2747
2748    channel_buffer_updated(
2749        session.connection_id,
2750        collaborators,
2751        &proto::UpdateChannelBuffer {
2752            channel_id: channel_id.to_proto(),
2753            operations: request.operations,
2754        },
2755        &session.peer,
2756    );
2757
2758    let pool = &*session.connection_pool().await;
2759
2760    broadcast(
2761        None,
2762        non_collaborators
2763            .iter()
2764            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2765        |peer_id| {
2766            session.peer.send(
2767                peer_id.into(),
2768                proto::UpdateChannels {
2769                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2770                        channel_id: channel_id.to_proto(),
2771                        epoch: epoch as u64,
2772                        version: version.clone(),
2773                    }],
2774                    ..Default::default()
2775                },
2776            )
2777        },
2778    );
2779
2780    Ok(())
2781}
2782
2783async fn rejoin_channel_buffers(
2784    request: proto::RejoinChannelBuffers,
2785    response: Response<proto::RejoinChannelBuffers>,
2786    session: Session,
2787) -> Result<()> {
2788    let db = session.db().await;
2789    let buffers = db
2790        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2791        .await?;
2792
2793    for rejoined_buffer in &buffers {
2794        let collaborators_to_notify = rejoined_buffer
2795            .buffer
2796            .collaborators
2797            .iter()
2798            .filter_map(|c| Some(c.peer_id?.into()));
2799        channel_buffer_updated(
2800            session.connection_id,
2801            collaborators_to_notify,
2802            &proto::UpdateChannelBufferCollaborators {
2803                channel_id: rejoined_buffer.buffer.channel_id,
2804                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2805            },
2806            &session.peer,
2807        );
2808    }
2809
2810    response.send(proto::RejoinChannelBuffersResponse {
2811        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2812    })?;
2813
2814    Ok(())
2815}
2816
2817async fn leave_channel_buffer(
2818    request: proto::LeaveChannelBuffer,
2819    response: Response<proto::LeaveChannelBuffer>,
2820    session: Session,
2821) -> Result<()> {
2822    let db = session.db().await;
2823    let channel_id = ChannelId::from_proto(request.channel_id);
2824
2825    let left_buffer = db
2826        .leave_channel_buffer(channel_id, session.connection_id)
2827        .await?;
2828
2829    response.send(Ack {})?;
2830
2831    channel_buffer_updated(
2832        session.connection_id,
2833        left_buffer.connections,
2834        &proto::UpdateChannelBufferCollaborators {
2835            channel_id: channel_id.to_proto(),
2836            collaborators: left_buffer.collaborators,
2837        },
2838        &session.peer,
2839    );
2840
2841    Ok(())
2842}
2843
2844fn channel_buffer_updated<T: EnvelopedMessage>(
2845    sender_id: ConnectionId,
2846    collaborators: impl IntoIterator<Item = ConnectionId>,
2847    message: &T,
2848    peer: &Peer,
2849) {
2850    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2851        peer.send(peer_id.into(), message.clone())
2852    });
2853}
2854
2855fn send_notifications(
2856    connection_pool: &ConnectionPool,
2857    peer: &Peer,
2858    notifications: db::NotificationBatch,
2859) {
2860    for (user_id, notification) in notifications {
2861        for connection_id in connection_pool.user_connection_ids(user_id) {
2862            if let Err(error) = peer.send(
2863                connection_id,
2864                proto::AddNotification {
2865                    notification: Some(notification.clone()),
2866                },
2867            ) {
2868                tracing::error!(
2869                    "failed to send notification to {:?} {}",
2870                    connection_id,
2871                    error
2872                );
2873            }
2874        }
2875    }
2876}
2877
2878async fn send_channel_message(
2879    request: proto::SendChannelMessage,
2880    response: Response<proto::SendChannelMessage>,
2881    session: Session,
2882) -> Result<()> {
2883    // Validate the message body.
2884    let body = request.body.trim().to_string();
2885    if body.len() > MAX_MESSAGE_LEN {
2886        return Err(anyhow!("message is too long"))?;
2887    }
2888    if body.is_empty() {
2889        return Err(anyhow!("message can't be blank"))?;
2890    }
2891
2892    // TODO: adjust mentions if body is trimmed
2893
2894    let timestamp = OffsetDateTime::now_utc();
2895    let nonce = request
2896        .nonce
2897        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2898
2899    let channel_id = ChannelId::from_proto(request.channel_id);
2900    let CreatedChannelMessage {
2901        message_id,
2902        participant_connection_ids,
2903        channel_members,
2904        notifications,
2905    } = session
2906        .db()
2907        .await
2908        .create_channel_message(
2909            channel_id,
2910            session.user_id,
2911            &body,
2912            &request.mentions,
2913            timestamp,
2914            nonce.clone().into(),
2915        )
2916        .await?;
2917    let message = proto::ChannelMessage {
2918        sender_id: session.user_id.to_proto(),
2919        id: message_id.to_proto(),
2920        body,
2921        mentions: request.mentions,
2922        timestamp: timestamp.unix_timestamp() as u64,
2923        nonce: Some(nonce),
2924    };
2925    broadcast(
2926        Some(session.connection_id),
2927        participant_connection_ids,
2928        |connection| {
2929            session.peer.send(
2930                connection,
2931                proto::ChannelMessageSent {
2932                    channel_id: channel_id.to_proto(),
2933                    message: Some(message.clone()),
2934                },
2935            )
2936        },
2937    );
2938    response.send(proto::SendChannelMessageResponse {
2939        message: Some(message),
2940    })?;
2941
2942    let pool = &*session.connection_pool().await;
2943    broadcast(
2944        None,
2945        channel_members
2946            .iter()
2947            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2948        |peer_id| {
2949            session.peer.send(
2950                peer_id.into(),
2951                proto::UpdateChannels {
2952                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2953                        channel_id: channel_id.to_proto(),
2954                        message_id: message_id.to_proto(),
2955                    }],
2956                    ..Default::default()
2957                },
2958            )
2959        },
2960    );
2961    send_notifications(pool, &session.peer, notifications);
2962
2963    Ok(())
2964}
2965
2966async fn remove_channel_message(
2967    request: proto::RemoveChannelMessage,
2968    response: Response<proto::RemoveChannelMessage>,
2969    session: Session,
2970) -> Result<()> {
2971    let channel_id = ChannelId::from_proto(request.channel_id);
2972    let message_id = MessageId::from_proto(request.message_id);
2973    let connection_ids = session
2974        .db()
2975        .await
2976        .remove_channel_message(channel_id, message_id, session.user_id)
2977        .await?;
2978    broadcast(Some(session.connection_id), connection_ids, |connection| {
2979        session.peer.send(connection, request.clone())
2980    });
2981    response.send(proto::Ack {})?;
2982    Ok(())
2983}
2984
2985async fn acknowledge_channel_message(
2986    request: proto::AckChannelMessage,
2987    session: Session,
2988) -> Result<()> {
2989    let channel_id = ChannelId::from_proto(request.channel_id);
2990    let message_id = MessageId::from_proto(request.message_id);
2991    let notifications = session
2992        .db()
2993        .await
2994        .observe_channel_message(channel_id, session.user_id, message_id)
2995        .await?;
2996    send_notifications(
2997        &*session.connection_pool().await,
2998        &session.peer,
2999        notifications,
3000    );
3001    Ok(())
3002}
3003
3004async fn acknowledge_buffer_version(
3005    request: proto::AckBufferOperation,
3006    session: Session,
3007) -> Result<()> {
3008    let buffer_id = BufferId::from_proto(request.buffer_id);
3009    session
3010        .db()
3011        .await
3012        .observe_buffer_version(
3013            buffer_id,
3014            session.user_id,
3015            request.epoch as i32,
3016            &request.version,
3017        )
3018        .await?;
3019    Ok(())
3020}
3021
3022async fn join_channel_chat(
3023    request: proto::JoinChannelChat,
3024    response: Response<proto::JoinChannelChat>,
3025    session: Session,
3026) -> Result<()> {
3027    let channel_id = ChannelId::from_proto(request.channel_id);
3028
3029    let db = session.db().await;
3030    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3031        .await?;
3032    let messages = db
3033        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3034        .await?;
3035    response.send(proto::JoinChannelChatResponse {
3036        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3037        messages,
3038    })?;
3039    Ok(())
3040}
3041
3042async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3043    let channel_id = ChannelId::from_proto(request.channel_id);
3044    session
3045        .db()
3046        .await
3047        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3048        .await?;
3049    Ok(())
3050}
3051
3052async fn get_channel_messages(
3053    request: proto::GetChannelMessages,
3054    response: Response<proto::GetChannelMessages>,
3055    session: Session,
3056) -> Result<()> {
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    let messages = session
3059        .db()
3060        .await
3061        .get_channel_messages(
3062            channel_id,
3063            session.user_id,
3064            MESSAGE_COUNT_PER_PAGE,
3065            Some(MessageId::from_proto(request.before_message_id)),
3066        )
3067        .await?;
3068    response.send(proto::GetChannelMessagesResponse {
3069        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3070        messages,
3071    })?;
3072    Ok(())
3073}
3074
3075async fn get_channel_messages_by_id(
3076    request: proto::GetChannelMessagesById,
3077    response: Response<proto::GetChannelMessagesById>,
3078    session: Session,
3079) -> Result<()> {
3080    let message_ids = request
3081        .message_ids
3082        .iter()
3083        .map(|id| MessageId::from_proto(*id))
3084        .collect::<Vec<_>>();
3085    let messages = session
3086        .db()
3087        .await
3088        .get_channel_messages_by_id(session.user_id, &message_ids)
3089        .await?;
3090    response.send(proto::GetChannelMessagesResponse {
3091        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3092        messages,
3093    })?;
3094    Ok(())
3095}
3096
3097async fn get_notifications(
3098    request: proto::GetNotifications,
3099    response: Response<proto::GetNotifications>,
3100    session: Session,
3101) -> Result<()> {
3102    let notifications = session
3103        .db()
3104        .await
3105        .get_notifications(
3106            session.user_id,
3107            NOTIFICATION_COUNT_PER_PAGE,
3108            request
3109                .before_id
3110                .map(|id| db::NotificationId::from_proto(id)),
3111        )
3112        .await?;
3113    response.send(proto::GetNotificationsResponse {
3114        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3115        notifications,
3116    })?;
3117    Ok(())
3118}
3119
3120async fn mark_notification_as_read(
3121    request: proto::MarkNotificationRead,
3122    response: Response<proto::MarkNotificationRead>,
3123    session: Session,
3124) -> Result<()> {
3125    let database = &session.db().await;
3126    let notifications = database
3127        .mark_notification_as_read_by_id(
3128            session.user_id,
3129            NotificationId::from_proto(request.notification_id),
3130        )
3131        .await?;
3132    send_notifications(
3133        &*session.connection_pool().await,
3134        &session.peer,
3135        notifications,
3136    );
3137    response.send(proto::Ack {})?;
3138    Ok(())
3139}
3140
3141async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3142    let project_id = ProjectId::from_proto(request.project_id);
3143    let project_connection_ids = session
3144        .db()
3145        .await
3146        .project_connection_ids(project_id, session.connection_id)
3147        .await?;
3148    broadcast(
3149        Some(session.connection_id),
3150        project_connection_ids.iter().copied(),
3151        |connection_id| {
3152            session
3153                .peer
3154                .forward_send(session.connection_id, connection_id, request.clone())
3155        },
3156    );
3157    Ok(())
3158}
3159
3160async fn get_private_user_info(
3161    _request: proto::GetPrivateUserInfo,
3162    response: Response<proto::GetPrivateUserInfo>,
3163    session: Session,
3164) -> Result<()> {
3165    let db = session.db().await;
3166
3167    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3168    let user = db
3169        .get_user_by_id(session.user_id)
3170        .await?
3171        .ok_or_else(|| anyhow!("user not found"))?;
3172    let flags = db.get_user_flags(session.user_id).await?;
3173
3174    response.send(proto::GetPrivateUserInfoResponse {
3175        metrics_id,
3176        staff: user.admin,
3177        flags,
3178    })?;
3179    Ok(())
3180}
3181
3182fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3183    match message {
3184        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3185        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3186        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3187        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3188        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3189            code: frame.code.into(),
3190            reason: frame.reason,
3191        })),
3192    }
3193}
3194
3195fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3196    match message {
3197        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3198        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3199        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3200        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3201        AxumMessage::Close(frame) => {
3202            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3203                code: frame.code.into(),
3204                reason: frame.reason,
3205            }))
3206        }
3207    }
3208}
3209
3210fn notify_membership_updated(
3211    connection_pool: &ConnectionPool,
3212    result: MembershipUpdated,
3213    user_id: UserId,
3214    peer: &Peer,
3215) {
3216    let mut update = build_channels_update(result.new_channels, vec![]);
3217    update.delete_channels = result
3218        .removed_channels
3219        .into_iter()
3220        .map(|id| id.to_proto())
3221        .collect();
3222    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3223
3224    for connection_id in connection_pool.user_connection_ids(user_id) {
3225        peer.send(connection_id, update.clone()).trace_err();
3226    }
3227}
3228
3229fn build_channels_update(
3230    channels: ChannelsForUser,
3231    channel_invites: Vec<db::Channel>,
3232) -> proto::UpdateChannels {
3233    let mut update = proto::UpdateChannels::default();
3234
3235    for channel in channels.channels {
3236        update.channels.push(channel.to_proto());
3237    }
3238
3239    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3240    update.unseen_channel_messages = channels.channel_messages;
3241
3242    for (channel_id, participants) in channels.channel_participants {
3243        update
3244            .channel_participants
3245            .push(proto::ChannelParticipants {
3246                channel_id: channel_id.to_proto(),
3247                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3248            });
3249    }
3250
3251    for channel in channel_invites {
3252        update.channel_invitations.push(channel.to_proto());
3253    }
3254
3255    update
3256}
3257
3258fn build_initial_contacts_update(
3259    contacts: Vec<db::Contact>,
3260    pool: &ConnectionPool,
3261) -> proto::UpdateContacts {
3262    let mut update = proto::UpdateContacts::default();
3263
3264    for contact in contacts {
3265        match contact {
3266            db::Contact::Accepted { user_id, busy } => {
3267                update.contacts.push(contact_for_user(user_id, busy, &pool));
3268            }
3269            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3270            db::Contact::Incoming { user_id } => {
3271                update
3272                    .incoming_requests
3273                    .push(proto::IncomingContactRequest {
3274                        requester_id: user_id.to_proto(),
3275                    })
3276            }
3277        }
3278    }
3279
3280    update
3281}
3282
3283fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3284    proto::Contact {
3285        user_id: user_id.to_proto(),
3286        online: pool.is_user_online(user_id),
3287        busy,
3288    }
3289}
3290
3291fn room_updated(room: &proto::Room, peer: &Peer) {
3292    broadcast(
3293        None,
3294        room.participants
3295            .iter()
3296            .filter_map(|participant| Some(participant.peer_id?.into())),
3297        |peer_id| {
3298            peer.send(
3299                peer_id.into(),
3300                proto::RoomUpdated {
3301                    room: Some(room.clone()),
3302                },
3303            )
3304        },
3305    );
3306}
3307
3308fn channel_updated(
3309    channel_id: ChannelId,
3310    room: &proto::Room,
3311    channel_members: &[UserId],
3312    peer: &Peer,
3313    pool: &ConnectionPool,
3314) {
3315    let participants = room
3316        .participants
3317        .iter()
3318        .map(|p| p.user_id)
3319        .collect::<Vec<_>>();
3320
3321    broadcast(
3322        None,
3323        channel_members
3324            .iter()
3325            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3326        |peer_id| {
3327            peer.send(
3328                peer_id.into(),
3329                proto::UpdateChannels {
3330                    channel_participants: vec![proto::ChannelParticipants {
3331                        channel_id: channel_id.to_proto(),
3332                        participant_user_ids: participants.clone(),
3333                    }],
3334                    ..Default::default()
3335                },
3336            )
3337        },
3338    );
3339}
3340
3341async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3342    let db = session.db().await;
3343
3344    let contacts = db.get_contacts(user_id).await?;
3345    let busy = db.is_user_busy(user_id).await?;
3346
3347    let pool = session.connection_pool().await;
3348    let updated_contact = contact_for_user(user_id, busy, &pool);
3349    for contact in contacts {
3350        if let db::Contact::Accepted {
3351            user_id: contact_user_id,
3352            ..
3353        } = contact
3354        {
3355            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3356                session
3357                    .peer
3358                    .send(
3359                        contact_conn_id,
3360                        proto::UpdateContacts {
3361                            contacts: vec![updated_contact.clone()],
3362                            remove_contacts: Default::default(),
3363                            incoming_requests: Default::default(),
3364                            remove_incoming_requests: Default::default(),
3365                            outgoing_requests: Default::default(),
3366                            remove_outgoing_requests: Default::default(),
3367                        },
3368                    )
3369                    .trace_err();
3370            }
3371        }
3372    }
3373    Ok(())
3374}
3375
/// Removes this session's connection from its current room, if any, and
/// notifies everyone affected.
///
/// Returns `Ok(())` immediately when the database reports no room was left.
/// Otherwise, in order:
/// 1. Runs `project_left` for each project the connection left.
/// 2. Broadcasts the updated room.
/// 3. Broadcasts the channel's participant list, if the room belongs to a channel.
/// 4. Sends `CallCanceled` to users whose calls were canceled.
/// 5. Refreshes contacts for the leaving user and those users.
/// 6. Removes the user from the LiveKit room, deleting the room if the database
///    marked it as deleted.
///
/// Send and LiveKit failures are logged via `trace_err`. Database errors are returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so they outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // The database guard is a temporary of this `if let` scrutinee. It stays
    // held through the body below and is released before the statements that
    // follow this `if`/`else`, including `update_user_contacts`, which locks the
    // database again.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below
        // carries `live_kit_room` at its default value.
        // NOTE(review): confirm clients don't rely on that field in `RoomUpdated`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts` runs,
    // because that function locks the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3452
3453async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3454    let left_channel_buffers = session
3455        .db()
3456        .await
3457        .leave_channel_buffers(session.connection_id)
3458        .await?;
3459
3460    for left_buffer in left_channel_buffers {
3461        channel_buffer_updated(
3462            session.connection_id,
3463            left_buffer.connections,
3464            &proto::UpdateChannelBufferCollaborators {
3465                channel_id: left_buffer.channel_id.to_proto(),
3466                collaborators: left_buffer.collaborators,
3467            },
3468            &session.peer,
3469        );
3470    }
3471
3472    Ok(())
3473}
3474
3475fn project_left(project: &db::LeftProject, session: &Session) {
3476    for connection_id in &project.connection_ids {
3477        if project.host_user_id == session.user_id {
3478            session
3479                .peer
3480                .send(
3481                    *connection_id,
3482                    proto::UnshareProject {
3483                        project_id: project.id.to_proto(),
3484                    },
3485                )
3486                .trace_err();
3487        } else {
3488            session
3489                .peer
3490                .send(
3491                    *connection_id,
3492                    proto::RemoveProjectCollaborator {
3493                        project_id: project.id.to_proto(),
3494                        peer_id: Some(session.connection_id.into()),
3495                    },
3496                )
3497                .trace_err();
3498        }
3499    }
3500}
3501
3502pub trait ResultExt {
3503    type Ok;
3504
3505    fn trace_err(self) -> Option<Self::Ok>;
3506}
3507
3508impl<T, E> ResultExt for Result<T, E>
3509where
3510    E: std::fmt::Debug,
3511{
3512    type Ok = T;
3513
3514    fn trace_err(self) -> Option<T> {
3515        match self {
3516            Ok(value) => Some(value),
3517            Err(error) => {
3518                tracing::error!("{:?}", error);
3519                None
3520            }
3521        }
3522    }
3523}