rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69use util::channel::RELEASE_CHANNEL_NAME;
  70
  71pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  72pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  73
  74const MESSAGE_COUNT_PER_PAGE: usize = 100;
  75const MAX_MESSAGE_LEN: usize = 1024;
  76const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  77
  78lazy_static! {
  79    static ref METRIC_CONNECTIONS: IntGauge =
  80        register_int_gauge!("connections", "number of connections").unwrap();
  81    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  82        "shared_projects",
  83        "number of open projects with one or more guests"
  84    )
  85    .unwrap();
  86}
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone)]
 106struct Session {
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
 188    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 189        let mut server = Self {
 190            id: parking_lot::Mutex::new(id),
 191            peer: Peer::new(id.0 as u32),
 192            app_state,
 193            executor,
 194            connection_pool: Default::default(),
 195            handlers: Default::default(),
 196            teardown: watch::channel(()).0,
 197        };
 198
 199        server
 200            .add_request_handler(ping)
 201            .add_request_handler(create_room)
 202            .add_request_handler(join_room)
 203            .add_request_handler(rejoin_room)
 204            .add_request_handler(leave_room)
 205            .add_request_handler(call)
 206            .add_request_handler(cancel_call)
 207            .add_message_handler(decline_call)
 208            .add_request_handler(update_participant_location)
 209            .add_request_handler(share_project)
 210            .add_message_handler(unshare_project)
 211            .add_request_handler(join_project)
 212            .add_message_handler(leave_project)
 213            .add_request_handler(update_project)
 214            .add_request_handler(update_worktree)
 215            .add_message_handler(start_language_server)
 216            .add_message_handler(update_language_server)
 217            .add_message_handler(update_diagnostic_summary)
 218            .add_message_handler(update_worktree_settings)
 219            .add_message_handler(refresh_inlay_hints)
 220            .add_request_handler(forward_project_request::<proto::GetHover>)
 221            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 222            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 223            .add_request_handler(forward_project_request::<proto::GetReferences>)
 224            .add_request_handler(forward_project_request::<proto::SearchProject>)
 225            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 226            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 227            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 228            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 229            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 230            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 231            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 232            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 233            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 234            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 235            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 236            .add_request_handler(forward_project_request::<proto::PerformRename>)
 237            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 238            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 239            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 240            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 242            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 243            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 244            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 245            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 246            .add_request_handler(forward_project_request::<proto::InlayHints>)
 247            .add_message_handler(create_buffer_for_peer)
 248            .add_request_handler(update_buffer)
 249            .add_message_handler(update_buffer_file)
 250            .add_message_handler(buffer_reloaded)
 251            .add_message_handler(buffer_saved)
 252            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 253            .add_request_handler(get_users)
 254            .add_request_handler(fuzzy_search_users)
 255            .add_request_handler(request_contact)
 256            .add_request_handler(remove_contact)
 257            .add_request_handler(respond_to_contact_request)
 258            .add_request_handler(create_channel)
 259            .add_request_handler(delete_channel)
 260            .add_request_handler(invite_channel_member)
 261            .add_request_handler(remove_channel_member)
 262            .add_request_handler(set_channel_member_role)
 263            .add_request_handler(set_channel_visibility)
 264            .add_request_handler(rename_channel)
 265            .add_request_handler(join_channel_buffer)
 266            .add_request_handler(leave_channel_buffer)
 267            .add_message_handler(update_channel_buffer)
 268            .add_request_handler(rejoin_channel_buffers)
 269            .add_request_handler(get_channel_members)
 270            .add_request_handler(respond_to_channel_invite)
 271            .add_request_handler(join_channel)
 272            .add_request_handler(join_channel_chat)
 273            .add_message_handler(leave_channel_chat)
 274            .add_request_handler(send_channel_message)
 275            .add_request_handler(remove_channel_message)
 276            .add_request_handler(get_channel_messages)
 277            .add_request_handler(get_channel_messages_by_id)
 278            .add_request_handler(get_notifications)
 279            .add_request_handler(mark_notification_as_read)
 280            .add_request_handler(link_channel)
 281            .add_request_handler(unlink_channel)
 282            .add_request_handler(move_channel)
 283            .add_request_handler(follow)
 284            .add_message_handler(unfollow)
 285            .add_message_handler(update_followers)
 286            .add_message_handler(update_diff_base)
 287            .add_request_handler(get_private_user_info)
 288            .add_message_handler(acknowledge_channel_message)
 289            .add_message_handler(acknowledge_buffer_version);
 290
 291        Arc::new(server)
 292    }
 293
    /// Starts background cleanup of resources left behind by stale servers.
    ///
    /// This returns right after spawning a detached task. That task:
    /// 1. waits `CLEANUP_TIMEOUT`;
    /// 2. fetches the stale room and channel-buffer ids for this environment and `server_id`;
    /// 3. clears stale channel-buffer collaborators and sends the refreshed collaborator list
    ///    to each remaining connection;
    /// 4. for each stale room:
    ///    - clears stale participants and broadcasts the room (and channel, if any) update;
    ///    - sends `CallCanceled` for canceled calls;
    ///    - pushes an `UpdateContacts` to the contacts of affected users;
    ///    - deletes the LiveKit room if a LiveKit client is configured and the room has no
    ///      participants left;
    /// 5. deletes stale server records.
    ///
    /// Each fallible step goes through `trace_err` and is skipped on failure instead of
    /// aborting the task. This function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Copied out: the spawned task keeps using this id even if `reset` changes it later.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to the connections still attached to it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Stale participants and users with canceled calls; their contacts get
                        // an `UpdateContacts` below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        // Stay empty/false unless the room is refreshed successfully below.
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the
                        // `.await`s that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only proceed when both queries succeeded. The pool is locked after
                            // both awaits, so the lock is never held across an `.await`.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Remove the LiveKit room only when the refreshed room ended up empty.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 445
 446    pub fn teardown(&self) {
 447        self.peer.teardown();
 448        self.connection_pool.lock().reset();
 449        let _ = self.teardown.send(());
 450    }
 451
 452    #[cfg(test)]
 453    pub fn reset(&self, id: ServerId) {
 454        self.teardown();
 455        *self.id.lock() = id;
 456        self.peer.reset(id.0 as u32);
 457    }
 458
 459    #[cfg(test)]
 460    pub fn id(&self) -> ServerId {
 461        *self.id.lock()
 462    }
 463
 464    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 465    where
 466        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 467        Fut: 'static + Send + Future<Output = Result<()>>,
 468        M: EnvelopedMessage,
 469    {
 470        let prev_handler = self.handlers.insert(
 471            TypeId::of::<M>(),
 472            Box::new(move |envelope, session| {
 473                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 474                let span = info_span!(
 475                    "handle message",
 476                    payload_type = envelope.payload_type_name()
 477                );
 478                span.in_scope(|| {
 479                    tracing::info!(
 480                        payload_type = envelope.payload_type_name(),
 481                        "message received"
 482                    );
 483                });
 484                let start_time = Instant::now();
 485                let future = (handler)(*envelope, session);
 486                async move {
 487                    let result = future.await;
 488                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 489                    match result {
 490                        Err(error) => {
 491                            tracing::error!(%error, ?duration_ms, "error handling message")
 492                        }
 493                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 494                    }
 495                }
 496                .instrument(span)
 497                .boxed()
 498            }),
 499        );
 500        if prev_handler.is_some() {
 501            panic!("registered a handler for the same message twice");
 502        }
 503        self
 504    }
 505
 506    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 507    where
 508        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 509        Fut: 'static + Send + Future<Output = Result<()>>,
 510        M: EnvelopedMessage,
 511    {
 512        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 513        self
 514    }
 515
 516    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 517    where
 518        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 519        Fut: Send + Future<Output = Result<()>>,
 520        M: RequestMessage,
 521    {
 522        let handler = Arc::new(handler);
 523        self.add_handler(move |envelope, session| {
 524            let receipt = envelope.receipt();
 525            let handler = handler.clone();
 526            async move {
 527                let peer = session.peer.clone();
 528                let responded = Arc::new(AtomicBool::default());
 529                let response = Response {
 530                    peer: peer.clone(),
 531                    responded: responded.clone(),
 532                    receipt,
 533                };
 534                match (handler)(envelope.payload, response, session).await {
 535                    Ok(()) => {
 536                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 537                            Ok(())
 538                        } else {
 539                            Err(anyhow!("handler did not send a response"))?
 540                        }
 541                    }
 542                    Err(error) => {
 543                        peer.respond_with_error(
 544                            receipt,
 545                            proto::Error {
 546                                message: error.to_string(),
 547                            },
 548                        )?;
 549                        Err(error)
 550                    }
 551                }
 552            }
 553        })
 554    }
 555
    /// Returns a future that serves a single client connection for `user`.
    ///
    /// Setup steps (errors here are propagated with `?`):
    /// - register the connection with the peer and send `Hello`;
    /// - report the new `ConnectionId` through `send_connection_id`, if given;
    /// - on the user's first connection, send `ShowContacts` and record `connected_once`;
    /// - add the connection to the pool and send the initial contacts and channels updates;
    /// - forward any pending incoming call, then update the user's contacts.
    ///
    /// It then dispatches incoming messages to the registered handlers until one of:
    /// - server teardown: returns `Ok(())` immediately;
    /// - an I/O error, or the connection closing: calls `connection_lost` (logging, not
    ///   returning, its error) and returns `Ok(())`.
    ///
    /// NOTE(review): an error after `pool.add_connection` (e.g. from `incoming_call_for_user` or
    /// `update_user_contacts`) returns early without calling `connection_lost`. That looks like
    /// it could leave the connection registered in the pool — confirm whether the caller
    /// cleans up.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future starts so a later `Server::teardown` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // `handle_io` drives the connection; `incoming_rx` yields incoming message envelopes.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three initial queries run concurrently; the first failure aborts setup.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Scoped so the pool lock is released before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection. A permit is taken
            // before reading each message and released once that message's handler finishes
            // (or immediately, if no handler is registered).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O, progress on foreground
                // handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown skips `connection_lost` entirely.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span here; the handler future re-enters it through
                                // `.instrument(span)` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run on their own task; foreground ones
                                // are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 690
 691    pub async fn invite_code_redeemed(
 692        self: &Arc<Self>,
 693        inviter_id: UserId,
 694        invitee_id: UserId,
 695    ) -> Result<()> {
 696        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 697            if let Some(code) = &user.invite_code {
 698                let pool = self.connection_pool.lock();
 699                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 700                for connection_id in pool.user_connection_ids(inviter_id) {
 701                    self.peer.send(
 702                        connection_id,
 703                        proto::UpdateContacts {
 704                            contacts: vec![invitee_contact.clone()],
 705                            ..Default::default()
 706                        },
 707                    )?;
 708                    self.peer.send(
 709                        connection_id,
 710                        proto::UpdateInviteInfo {
 711                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 712                            count: user.invite_count as u32,
 713                        },
 714                    )?;
 715                }
 716            }
 717        }
 718        Ok(())
 719    }
 720
 721    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 722        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 723            if let Some(invite_code) = &user.invite_code {
 724                let pool = self.connection_pool.lock();
 725                for connection_id in pool.user_connection_ids(user_id) {
 726                    self.peer.send(
 727                        connection_id,
 728                        proto::UpdateInviteInfo {
 729                            url: format!(
 730                                "{}{}",
 731                                self.app_state.config.invite_link_prefix, invite_code
 732                            ),
 733                            count: user.invite_count as u32,
 734                        },
 735                    )?;
 736                }
 737            }
 738        }
 739        Ok(())
 740    }
 741
 742    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 743        ServerSnapshot {
 744            connection_pool: ConnectionPoolGuard {
 745                guard: self.connection_pool.lock(),
 746                _not_send: PhantomData,
 747            },
 748            peer: &self.peer,
 749        }
 750    }
 751}
 752
 753impl<'a> Deref for ConnectionPoolGuard<'a> {
 754    type Target = ConnectionPool;
 755
 756    fn deref(&self) -> &Self::Target {
 757        &*self.guard
 758    }
 759}
 760
 761impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 762    fn deref_mut(&mut self) -> &mut Self::Target {
 763        &mut *self.guard
 764    }
 765}
 766
impl<'a> Drop for ConnectionPoolGuard<'a> {
    // In test builds, validates the pool's invariants whenever this guard is
    // released, i.e. after any mutation made through it. The wrapped `guard` field
    // is dropped (releasing the lock) only after this body runs, so the check still
    // observes the pool under the lock.
    fn drop(&mut self) {
        // Gated at compile time on purpose: `check_invariants` may only be defined
        // for test builds (NOTE(review): confirm), so a runtime `cfg!(test)` check
        // would not be equivalent.
        #[cfg(test)]
        self.check_invariants();
    }
}
 773
 774fn broadcast<F>(
 775    sender_id: Option<ConnectionId>,
 776    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 777    mut f: F,
 778) where
 779    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 780{
 781    for receiver_id in receiver_ids {
 782        if Some(receiver_id) != sender_id {
 783            if let Err(error) = f(receiver_id) {
 784                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 785            }
 786        }
 787    }
 788}
 789
lazy_static! {
    // Name of the request header through which clients announce their RPC protocol
    // version; it is decoded by the `ProtocolVersion` typed header.
    // `HeaderName::from_static` panics on an invalid name (including uppercase
    // letters), so this literal must stay a valid lowercase header name.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 793
 794pub struct ProtocolVersion(u32);
 795
 796impl Header for ProtocolVersion {
 797    fn name() -> &'static HeaderName {
 798        &ZED_PROTOCOL_VERSION
 799    }
 800
 801    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 802    where
 803        Self: Sized,
 804        I: Iterator<Item = &'i axum::http::HeaderValue>,
 805    {
 806        let version = values
 807            .next()
 808            .ok_or_else(axum::headers::Error::invalid)?
 809            .to_str()
 810            .map_err(|_| axum::headers::Error::invalid())?
 811            .parse()
 812            .map_err(|_| axum::headers::Error::invalid())?;
 813        Ok(Self(version))
 814    }
 815
 816    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 817        values.extend([self.0.to_string().parse().unwrap()]);
 818    }
 819}
 820
/// Builds the collaboration server's HTTP router.
///
/// - `GET /rpc` upgrades to the RPC websocket. It sits behind the `ServiceBuilder`
///   layer that injects the `AppState` extension and runs `auth::validate_header`.
/// - `GET /metrics` serves Prometheus metrics. It is registered *after* that layer,
///   and axum's `Router::layer` only wraps routes added before it, so `/metrics` is
///   not behind `validate_header`.
///
/// The final layer provides the `Arc<Server>` extension to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Order matters: this layer applies only to the routes registered above.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 832
 833pub async fn handle_websocket_request(
 834    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 835    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 836    Extension(server): Extension<Arc<Server>>,
 837    Extension(user): Extension<User>,
 838    ws: WebSocketUpgrade,
 839) -> axum::response::Response {
 840    if protocol_version != rpc::PROTOCOL_VERSION {
 841        return (
 842            StatusCode::UPGRADE_REQUIRED,
 843            "client must be upgraded".to_string(),
 844        )
 845            .into_response();
 846    }
 847    let socket_address = socket_address.to_string();
 848    ws.on_upgrade(move |socket| {
 849        use util::ResultExt;
 850        let socket = socket
 851            .map_ok(to_tungstenite_message)
 852            .err_into()
 853            .with(|message| async move { Ok(to_axum_message(message)) });
 854        let connection = Connection::new(Box::pin(socket));
 855        async move {
 856            server
 857                .handle_connection(connection, socket_address, user, None, Executor::Production)
 858                .await
 859                .log_err();
 860        }
 861    })
 862}
 863
 864pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 865    let connections = server
 866        .connection_pool
 867        .lock()
 868        .connections()
 869        .filter(|connection| !connection.admin)
 870        .count();
 871
 872    METRIC_CONNECTIONS.set(connections as _);
 873
 874    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 875    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 876
 877    let encoder = prometheus::TextEncoder::new();
 878    let metric_families = prometheus::gather();
 879    let encoded_metrics = encoder
 880        .encode_to_string(&metric_families)
 881        .map_err(|err| anyhow!("{}", err))?;
 882    Ok(encoded_metrics)
 883}
 884
/// Tears down a connection once its message loop has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the loss in the database (a database failure here is traced, not
/// returned). It then waits for whichever comes first: `RECONNECT_TIMEOUT`
/// elapsing or the `teardown` receiver changing. `select_biased!` polls the
/// timeout branch first when both are ready.
///
/// - On timeout, the session leaves its room and channel buffers, declines any
///   pending call if the user has no other online connection, and refreshes the
///   user's contacts.
/// - On teardown, nothing further is done (presumably signalled on server
///   shutdown — confirm with the sender side).
///
/// # Errors
/// Fails if removing the connection from the pool fails, or if updating the user's
/// contacts fails after the timeout.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure to record the lost connection must not prevent the
    // remaining cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time: release everything it held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call if none of the user's other
            // connections is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 930
 931async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 932    response.send(proto::Ack {})?;
 933    Ok(())
 934}
 935
 936async fn create_room(
 937    _request: proto::CreateRoom,
 938    response: Response<proto::CreateRoom>,
 939    session: Session,
 940) -> Result<()> {
 941    let live_kit_room = nanoid::nanoid!(30);
 942
 943    let live_kit_connection_info = {
 944        let live_kit_room = live_kit_room.clone();
 945        let live_kit = session.live_kit_client.as_ref();
 946
 947        util::async_iife!({
 948            let live_kit = live_kit?;
 949
 950            let token = live_kit
 951                .room_token(&live_kit_room, &session.user_id.to_string())
 952                .trace_err()?;
 953
 954            Some(proto::LiveKitConnectionInfo {
 955                server_url: live_kit.url().into(),
 956                token,
 957                can_publish: true,
 958            })
 959        })
 960    }
 961    .await;
 962
 963    let room = session
 964        .db()
 965        .await
 966        .create_room(
 967            session.user_id,
 968            session.connection_id,
 969            &live_kit_room,
 970            RELEASE_CHANNEL_NAME.as_str(),
 971        )
 972        .await?;
 973
 974    response.send(proto::CreateRoomResponse {
 975        room: Some(room.clone()),
 976        live_kit_connection_info,
 977    })?;
 978
 979    update_user_contacts(session.user_id, &session).await?;
 980    Ok(())
 981}
 982
 983async fn join_room(
 984    request: proto::JoinRoom,
 985    response: Response<proto::JoinRoom>,
 986    session: Session,
 987) -> Result<()> {
 988    let room_id = RoomId::from_proto(request.id);
 989
 990    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 991
 992    if let Some(channel_id) = channel_id {
 993        return join_channel_internal(channel_id, Box::new(response), session).await;
 994    }
 995
 996    let joined_room = {
 997        let room = session
 998            .db()
 999            .await
1000            .join_room(
1001                room_id,
1002                session.user_id,
1003                session.connection_id,
1004                RELEASE_CHANNEL_NAME.as_str(),
1005            )
1006            .await?;
1007        room_updated(&room.room, &session.peer);
1008        room.into_inner()
1009    };
1010
1011    for connection_id in session
1012        .connection_pool()
1013        .await
1014        .user_connection_ids(session.user_id)
1015    {
1016        session
1017            .peer
1018            .send(
1019                connection_id,
1020                proto::CallCanceled {
1021                    room_id: room_id.to_proto(),
1022                },
1023            )
1024            .trace_err();
1025    }
1026
1027    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1028        if let Some(token) = live_kit
1029            .room_token(
1030                &joined_room.room.live_kit_room,
1031                &session.user_id.to_string(),
1032            )
1033            .trace_err()
1034        {
1035            Some(proto::LiveKitConnectionInfo {
1036                server_url: live_kit.url().into(),
1037                token,
1038                can_publish: true,
1039            })
1040        } else {
1041            None
1042        }
1043    } else {
1044        None
1045    };
1046
1047    response.send(proto::JoinRoomResponse {
1048        room: Some(joined_room.room),
1049        channel_id: None,
1050        live_kit_connection_info,
1051    })?;
1052
1053    update_user_contacts(session.user_id, &session).await?;
1054    Ok(())
1055}
1056
1057async fn rejoin_room(
1058    request: proto::RejoinRoom,
1059    response: Response<proto::RejoinRoom>,
1060    session: Session,
1061) -> Result<()> {
1062    let room;
1063    let channel_id;
1064    let channel_members;
1065    {
1066        let mut rejoined_room = session
1067            .db()
1068            .await
1069            .rejoin_room(request, session.user_id, session.connection_id)
1070            .await?;
1071
1072        response.send(proto::RejoinRoomResponse {
1073            room: Some(rejoined_room.room.clone()),
1074            reshared_projects: rejoined_room
1075                .reshared_projects
1076                .iter()
1077                .map(|project| proto::ResharedProject {
1078                    id: project.id.to_proto(),
1079                    collaborators: project
1080                        .collaborators
1081                        .iter()
1082                        .map(|collaborator| collaborator.to_proto())
1083                        .collect(),
1084                })
1085                .collect(),
1086            rejoined_projects: rejoined_room
1087                .rejoined_projects
1088                .iter()
1089                .map(|rejoined_project| proto::RejoinedProject {
1090                    id: rejoined_project.id.to_proto(),
1091                    worktrees: rejoined_project
1092                        .worktrees
1093                        .iter()
1094                        .map(|worktree| proto::WorktreeMetadata {
1095                            id: worktree.id,
1096                            root_name: worktree.root_name.clone(),
1097                            visible: worktree.visible,
1098                            abs_path: worktree.abs_path.clone(),
1099                        })
1100                        .collect(),
1101                    collaborators: rejoined_project
1102                        .collaborators
1103                        .iter()
1104                        .map(|collaborator| collaborator.to_proto())
1105                        .collect(),
1106                    language_servers: rejoined_project.language_servers.clone(),
1107                })
1108                .collect(),
1109        })?;
1110        room_updated(&rejoined_room.room, &session.peer);
1111
1112        for project in &rejoined_room.reshared_projects {
1113            for collaborator in &project.collaborators {
1114                session
1115                    .peer
1116                    .send(
1117                        collaborator.connection_id,
1118                        proto::UpdateProjectCollaborator {
1119                            project_id: project.id.to_proto(),
1120                            old_peer_id: Some(project.old_connection_id.into()),
1121                            new_peer_id: Some(session.connection_id.into()),
1122                        },
1123                    )
1124                    .trace_err();
1125            }
1126
1127            broadcast(
1128                Some(session.connection_id),
1129                project
1130                    .collaborators
1131                    .iter()
1132                    .map(|collaborator| collaborator.connection_id),
1133                |connection_id| {
1134                    session.peer.forward_send(
1135                        session.connection_id,
1136                        connection_id,
1137                        proto::UpdateProject {
1138                            project_id: project.id.to_proto(),
1139                            worktrees: project.worktrees.clone(),
1140                        },
1141                    )
1142                },
1143            );
1144        }
1145
1146        for project in &rejoined_room.rejoined_projects {
1147            for collaborator in &project.collaborators {
1148                session
1149                    .peer
1150                    .send(
1151                        collaborator.connection_id,
1152                        proto::UpdateProjectCollaborator {
1153                            project_id: project.id.to_proto(),
1154                            old_peer_id: Some(project.old_connection_id.into()),
1155                            new_peer_id: Some(session.connection_id.into()),
1156                        },
1157                    )
1158                    .trace_err();
1159            }
1160        }
1161
1162        for project in &mut rejoined_room.rejoined_projects {
1163            for worktree in mem::take(&mut project.worktrees) {
1164                #[cfg(any(test, feature = "test-support"))]
1165                const MAX_CHUNK_SIZE: usize = 2;
1166                #[cfg(not(any(test, feature = "test-support")))]
1167                const MAX_CHUNK_SIZE: usize = 256;
1168
1169                // Stream this worktree's entries.
1170                let message = proto::UpdateWorktree {
1171                    project_id: project.id.to_proto(),
1172                    worktree_id: worktree.id,
1173                    abs_path: worktree.abs_path.clone(),
1174                    root_name: worktree.root_name,
1175                    updated_entries: worktree.updated_entries,
1176                    removed_entries: worktree.removed_entries,
1177                    scan_id: worktree.scan_id,
1178                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1179                    updated_repositories: worktree.updated_repositories,
1180                    removed_repositories: worktree.removed_repositories,
1181                };
1182                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1183                    session.peer.send(session.connection_id, update.clone())?;
1184                }
1185
1186                // Stream this worktree's diagnostics.
1187                for summary in worktree.diagnostic_summaries {
1188                    session.peer.send(
1189                        session.connection_id,
1190                        proto::UpdateDiagnosticSummary {
1191                            project_id: project.id.to_proto(),
1192                            worktree_id: worktree.id,
1193                            summary: Some(summary),
1194                        },
1195                    )?;
1196                }
1197
1198                for settings_file in worktree.settings_files {
1199                    session.peer.send(
1200                        session.connection_id,
1201                        proto::UpdateWorktreeSettings {
1202                            project_id: project.id.to_proto(),
1203                            worktree_id: worktree.id,
1204                            path: settings_file.path,
1205                            content: Some(settings_file.content),
1206                        },
1207                    )?;
1208                }
1209            }
1210
1211            for language_server in &project.language_servers {
1212                session.peer.send(
1213                    session.connection_id,
1214                    proto::UpdateLanguageServer {
1215                        project_id: project.id.to_proto(),
1216                        language_server_id: language_server.id,
1217                        variant: Some(
1218                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1219                                proto::LspDiskBasedDiagnosticsUpdated {},
1220                            ),
1221                        ),
1222                    },
1223                )?;
1224            }
1225        }
1226
1227        let rejoined_room = rejoined_room.into_inner();
1228
1229        room = rejoined_room.room;
1230        channel_id = rejoined_room.channel_id;
1231        channel_members = rejoined_room.channel_members;
1232    }
1233
1234    if let Some(channel_id) = channel_id {
1235        channel_updated(
1236            channel_id,
1237            &room,
1238            &channel_members,
1239            &session.peer,
1240            &*session.connection_pool().await,
1241        );
1242    }
1243
1244    update_user_contacts(session.user_id, &session).await?;
1245    Ok(())
1246}
1247
1248async fn leave_room(
1249    _: proto::LeaveRoom,
1250    response: Response<proto::LeaveRoom>,
1251    session: Session,
1252) -> Result<()> {
1253    leave_room_for_session(&session).await?;
1254    response.send(proto::Ack {})?;
1255    Ok(())
1256}
1257
/// Rings `request.called_user_id` on behalf of the calling session.
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database, the room is broadcast, and the callee's contacts are refreshed. The
/// incoming-call request returned by the database is then sent concurrently to
/// every connection of the callee. The first connection that answers successfully
/// makes this handler ack the caller and return.
///
/// If no connection succeeds (including when the callee has none), the call is
/// marked failed, the room is broadcast again, the callee's contacts are
/// refreshed, and an error is returned.
///
/// # Errors
/// Fails if the users aren't contacts, on database errors, or with
/// "failed to ring user" when no callee connection answered successfully.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        // `&mut *` extends the returned guard to the end of this block, so the
        // room is broadcast while it is still held; `mem::take` moves the
        // incoming call out before the guard is dropped.
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Handle responses in completion order: the first success wins, and the
    // remaining in-flight requests are dropped along with `calls` on return.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and notify everyone.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1325
1326async fn cancel_call(
1327    request: proto::CancelCall,
1328    response: Response<proto::CancelCall>,
1329    session: Session,
1330) -> Result<()> {
1331    let called_user_id = UserId::from_proto(request.called_user_id);
1332    let room_id = RoomId::from_proto(request.room_id);
1333    {
1334        let room = session
1335            .db()
1336            .await
1337            .cancel_call(room_id, session.connection_id, called_user_id)
1338            .await?;
1339        room_updated(&room, &session.peer);
1340    }
1341
1342    for connection_id in session
1343        .connection_pool()
1344        .await
1345        .user_connection_ids(called_user_id)
1346    {
1347        session
1348            .peer
1349            .send(
1350                connection_id,
1351                proto::CallCanceled {
1352                    room_id: room_id.to_proto(),
1353                },
1354            )
1355            .trace_err();
1356    }
1357    response.send(proto::Ack {})?;
1358
1359    update_user_contacts(called_user_id, &session).await?;
1360    Ok(())
1361}
1362
1363async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1364    let room_id = RoomId::from_proto(message.room_id);
1365    {
1366        let room = session
1367            .db()
1368            .await
1369            .decline_call(Some(room_id), session.user_id)
1370            .await?
1371            .ok_or_else(|| anyhow!("failed to decline call"))?;
1372        room_updated(&room, &session.peer);
1373    }
1374
1375    for connection_id in session
1376        .connection_pool()
1377        .await
1378        .user_connection_ids(session.user_id)
1379    {
1380        session
1381            .peer
1382            .send(
1383                connection_id,
1384                proto::CallCanceled {
1385                    room_id: room_id.to_proto(),
1386                },
1387            )
1388            .trace_err();
1389    }
1390    update_user_contacts(session.user_id, &session).await?;
1391    Ok(())
1392}
1393
1394async fn update_participant_location(
1395    request: proto::UpdateParticipantLocation,
1396    response: Response<proto::UpdateParticipantLocation>,
1397    session: Session,
1398) -> Result<()> {
1399    let room_id = RoomId::from_proto(request.room_id);
1400    let location = request
1401        .location
1402        .ok_or_else(|| anyhow!("invalid location"))?;
1403
1404    let db = session.db().await;
1405    let room = db
1406        .update_room_participant_location(room_id, session.connection_id, location)
1407        .await?;
1408
1409    room_updated(&room, &session.peer);
1410    response.send(proto::Ack {})?;
1411    Ok(())
1412}
1413
1414async fn share_project(
1415    request: proto::ShareProject,
1416    response: Response<proto::ShareProject>,
1417    session: Session,
1418) -> Result<()> {
1419    let (project_id, room) = &*session
1420        .db()
1421        .await
1422        .share_project(
1423            RoomId::from_proto(request.room_id),
1424            session.connection_id,
1425            &request.worktrees,
1426        )
1427        .await?;
1428    response.send(proto::ShareProjectResponse {
1429        project_id: project_id.to_proto(),
1430    })?;
1431    room_updated(&room, &session.peer);
1432
1433    Ok(())
1434}
1435
1436async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1437    let project_id = ProjectId::from_proto(message.project_id);
1438
1439    let (room, guest_connection_ids) = &*session
1440        .db()
1441        .await
1442        .unshare_project(project_id, session.connection_id)
1443        .await?;
1444
1445    broadcast(
1446        Some(session.connection_id),
1447        guest_connection_ids.iter().copied(),
1448        |conn_id| session.peer.send(conn_id, message.clone()),
1449    );
1450    room_updated(&room, &session.peer);
1451
1452    Ok(())
1453}
1454
1455async fn join_project(
1456    request: proto::JoinProject,
1457    response: Response<proto::JoinProject>,
1458    session: Session,
1459) -> Result<()> {
1460    let project_id = ProjectId::from_proto(request.project_id);
1461    let guest_user_id = session.user_id;
1462
1463    tracing::info!(%project_id, "join project");
1464
1465    let (project, replica_id) = &mut *session
1466        .db()
1467        .await
1468        .join_project(project_id, session.connection_id)
1469        .await?;
1470
1471    let collaborators = project
1472        .collaborators
1473        .iter()
1474        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1475        .map(|collaborator| collaborator.to_proto())
1476        .collect::<Vec<_>>();
1477
1478    let worktrees = project
1479        .worktrees
1480        .iter()
1481        .map(|(id, worktree)| proto::WorktreeMetadata {
1482            id: *id,
1483            root_name: worktree.root_name.clone(),
1484            visible: worktree.visible,
1485            abs_path: worktree.abs_path.clone(),
1486        })
1487        .collect::<Vec<_>>();
1488
1489    for collaborator in &collaborators {
1490        session
1491            .peer
1492            .send(
1493                collaborator.peer_id.unwrap().into(),
1494                proto::AddProjectCollaborator {
1495                    project_id: project_id.to_proto(),
1496                    collaborator: Some(proto::Collaborator {
1497                        peer_id: Some(session.connection_id.into()),
1498                        replica_id: replica_id.0 as u32,
1499                        user_id: guest_user_id.to_proto(),
1500                    }),
1501                },
1502            )
1503            .trace_err();
1504    }
1505
1506    // First, we send the metadata associated with each worktree.
1507    response.send(proto::JoinProjectResponse {
1508        worktrees: worktrees.clone(),
1509        replica_id: replica_id.0 as u32,
1510        collaborators: collaborators.clone(),
1511        language_servers: project.language_servers.clone(),
1512    })?;
1513
1514    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1515        #[cfg(any(test, feature = "test-support"))]
1516        const MAX_CHUNK_SIZE: usize = 2;
1517        #[cfg(not(any(test, feature = "test-support")))]
1518        const MAX_CHUNK_SIZE: usize = 256;
1519
1520        // Stream this worktree's entries.
1521        let message = proto::UpdateWorktree {
1522            project_id: project_id.to_proto(),
1523            worktree_id,
1524            abs_path: worktree.abs_path.clone(),
1525            root_name: worktree.root_name,
1526            updated_entries: worktree.entries,
1527            removed_entries: Default::default(),
1528            scan_id: worktree.scan_id,
1529            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1530            updated_repositories: worktree.repository_entries.into_values().collect(),
1531            removed_repositories: Default::default(),
1532        };
1533        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1534            session.peer.send(session.connection_id, update.clone())?;
1535        }
1536
1537        // Stream this worktree's diagnostics.
1538        for summary in worktree.diagnostic_summaries {
1539            session.peer.send(
1540                session.connection_id,
1541                proto::UpdateDiagnosticSummary {
1542                    project_id: project_id.to_proto(),
1543                    worktree_id: worktree.id,
1544                    summary: Some(summary),
1545                },
1546            )?;
1547        }
1548
1549        for settings_file in worktree.settings_files {
1550            session.peer.send(
1551                session.connection_id,
1552                proto::UpdateWorktreeSettings {
1553                    project_id: project_id.to_proto(),
1554                    worktree_id: worktree.id,
1555                    path: settings_file.path,
1556                    content: Some(settings_file.content),
1557                },
1558            )?;
1559        }
1560    }
1561
1562    for language_server in &project.language_servers {
1563        session.peer.send(
1564            session.connection_id,
1565            proto::UpdateLanguageServer {
1566                project_id: project_id.to_proto(),
1567                language_server_id: language_server.id,
1568                variant: Some(
1569                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1570                        proto::LspDiskBasedDiagnosticsUpdated {},
1571                    ),
1572                ),
1573            },
1574        )?;
1575    }
1576
1577    Ok(())
1578}
1579
1580async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1581    let sender_id = session.connection_id;
1582    let project_id = ProjectId::from_proto(request.project_id);
1583
1584    let (room, project) = &*session
1585        .db()
1586        .await
1587        .leave_project(project_id, sender_id)
1588        .await?;
1589    tracing::info!(
1590        %project_id,
1591        host_user_id = %project.host_user_id,
1592        host_connection_id = %project.host_connection_id,
1593        "leave project"
1594    );
1595
1596    project_left(&project, &session);
1597    room_updated(&room, &session.peer);
1598
1599    Ok(())
1600}
1601
1602async fn update_project(
1603    request: proto::UpdateProject,
1604    response: Response<proto::UpdateProject>,
1605    session: Session,
1606) -> Result<()> {
1607    let project_id = ProjectId::from_proto(request.project_id);
1608    let (room, guest_connection_ids) = &*session
1609        .db()
1610        .await
1611        .update_project(project_id, session.connection_id, &request.worktrees)
1612        .await?;
1613    broadcast(
1614        Some(session.connection_id),
1615        guest_connection_ids.iter().copied(),
1616        |connection_id| {
1617            session
1618                .peer
1619                .forward_send(session.connection_id, connection_id, request.clone())
1620        },
1621    );
1622    room_updated(&room, &session.peer);
1623    response.send(proto::Ack {})?;
1624
1625    Ok(())
1626}
1627
1628async fn update_worktree(
1629    request: proto::UpdateWorktree,
1630    response: Response<proto::UpdateWorktree>,
1631    session: Session,
1632) -> Result<()> {
1633    let guest_connection_ids = session
1634        .db()
1635        .await
1636        .update_worktree(&request, session.connection_id)
1637        .await?;
1638
1639    broadcast(
1640        Some(session.connection_id),
1641        guest_connection_ids.iter().copied(),
1642        |connection_id| {
1643            session
1644                .peer
1645                .forward_send(session.connection_id, connection_id, request.clone())
1646        },
1647    );
1648    response.send(proto::Ack {})?;
1649    Ok(())
1650}
1651
1652async fn update_diagnostic_summary(
1653    message: proto::UpdateDiagnosticSummary,
1654    session: Session,
1655) -> Result<()> {
1656    let guest_connection_ids = session
1657        .db()
1658        .await
1659        .update_diagnostic_summary(&message, session.connection_id)
1660        .await?;
1661
1662    broadcast(
1663        Some(session.connection_id),
1664        guest_connection_ids.iter().copied(),
1665        |connection_id| {
1666            session
1667                .peer
1668                .forward_send(session.connection_id, connection_id, message.clone())
1669        },
1670    );
1671
1672    Ok(())
1673}
1674
1675async fn update_worktree_settings(
1676    message: proto::UpdateWorktreeSettings,
1677    session: Session,
1678) -> Result<()> {
1679    let guest_connection_ids = session
1680        .db()
1681        .await
1682        .update_worktree_settings(&message, session.connection_id)
1683        .await?;
1684
1685    broadcast(
1686        Some(session.connection_id),
1687        guest_connection_ids.iter().copied(),
1688        |connection_id| {
1689            session
1690                .peer
1691                .forward_send(session.connection_id, connection_id, message.clone())
1692        },
1693    );
1694
1695    Ok(())
1696}
1697
1698async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1699    broadcast_project_message(request.project_id, request, session).await
1700}
1701
1702async fn start_language_server(
1703    request: proto::StartLanguageServer,
1704    session: Session,
1705) -> Result<()> {
1706    let guest_connection_ids = session
1707        .db()
1708        .await
1709        .start_language_server(&request, session.connection_id)
1710        .await?;
1711
1712    broadcast(
1713        Some(session.connection_id),
1714        guest_connection_ids.iter().copied(),
1715        |connection_id| {
1716            session
1717                .peer
1718                .forward_send(session.connection_id, connection_id, request.clone())
1719        },
1720    );
1721    Ok(())
1722}
1723
1724async fn update_language_server(
1725    request: proto::UpdateLanguageServer,
1726    session: Session,
1727) -> Result<()> {
1728    session.executor.record_backtrace();
1729    let project_id = ProjectId::from_proto(request.project_id);
1730    let project_connection_ids = session
1731        .db()
1732        .await
1733        .project_connection_ids(project_id, session.connection_id)
1734        .await?;
1735    broadcast(
1736        Some(session.connection_id),
1737        project_connection_ids.iter().copied(),
1738        |connection_id| {
1739            session
1740                .peer
1741                .forward_send(session.connection_id, connection_id, request.clone())
1742        },
1743    );
1744    Ok(())
1745}
1746
1747async fn forward_project_request<T>(
1748    request: T,
1749    response: Response<T>,
1750    session: Session,
1751) -> Result<()>
1752where
1753    T: EntityMessage + RequestMessage,
1754{
1755    session.executor.record_backtrace();
1756    let project_id = ProjectId::from_proto(request.remote_entity_id());
1757    let host_connection_id = {
1758        let collaborators = session
1759            .db()
1760            .await
1761            .project_collaborators(project_id, session.connection_id)
1762            .await?;
1763        collaborators
1764            .iter()
1765            .find(|collaborator| collaborator.is_host)
1766            .ok_or_else(|| anyhow!("host not found"))?
1767            .connection_id
1768    };
1769
1770    let payload = session
1771        .peer
1772        .forward_request(session.connection_id, host_connection_id, request)
1773        .await?;
1774
1775    response.send(payload)?;
1776    Ok(())
1777}
1778
1779async fn create_buffer_for_peer(
1780    request: proto::CreateBufferForPeer,
1781    session: Session,
1782) -> Result<()> {
1783    session.executor.record_backtrace();
1784    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1785    session
1786        .peer
1787        .forward_send(session.connection_id, peer_id.into(), request)?;
1788    Ok(())
1789}
1790
1791async fn update_buffer(
1792    request: proto::UpdateBuffer,
1793    response: Response<proto::UpdateBuffer>,
1794    session: Session,
1795) -> Result<()> {
1796    session.executor.record_backtrace();
1797    let project_id = ProjectId::from_proto(request.project_id);
1798    let mut guest_connection_ids;
1799    let mut host_connection_id = None;
1800    {
1801        let collaborators = session
1802            .db()
1803            .await
1804            .project_collaborators(project_id, session.connection_id)
1805            .await?;
1806        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1807        for collaborator in collaborators.iter() {
1808            if collaborator.is_host {
1809                host_connection_id = Some(collaborator.connection_id);
1810            } else {
1811                guest_connection_ids.push(collaborator.connection_id);
1812            }
1813        }
1814    }
1815    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1816
1817    session.executor.record_backtrace();
1818    broadcast(
1819        Some(session.connection_id),
1820        guest_connection_ids,
1821        |connection_id| {
1822            session
1823                .peer
1824                .forward_send(session.connection_id, connection_id, request.clone())
1825        },
1826    );
1827    if host_connection_id != session.connection_id {
1828        session
1829            .peer
1830            .forward_request(session.connection_id, host_connection_id, request.clone())
1831            .await?;
1832    }
1833
1834    response.send(proto::Ack {})?;
1835    Ok(())
1836}
1837
1838async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1839    let project_id = ProjectId::from_proto(request.project_id);
1840    let project_connection_ids = session
1841        .db()
1842        .await
1843        .project_connection_ids(project_id, session.connection_id)
1844        .await?;
1845
1846    broadcast(
1847        Some(session.connection_id),
1848        project_connection_ids.iter().copied(),
1849        |connection_id| {
1850            session
1851                .peer
1852                .forward_send(session.connection_id, connection_id, request.clone())
1853        },
1854    );
1855    Ok(())
1856}
1857
1858async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1859    let project_id = ProjectId::from_proto(request.project_id);
1860    let project_connection_ids = session
1861        .db()
1862        .await
1863        .project_connection_ids(project_id, session.connection_id)
1864        .await?;
1865    broadcast(
1866        Some(session.connection_id),
1867        project_connection_ids.iter().copied(),
1868        |connection_id| {
1869            session
1870                .peer
1871                .forward_send(session.connection_id, connection_id, request.clone())
1872        },
1873    );
1874    Ok(())
1875}
1876
1877async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1878    broadcast_project_message(request.project_id, request, session).await
1879}
1880
1881async fn broadcast_project_message<T: EnvelopedMessage>(
1882    project_id: u64,
1883    request: T,
1884    session: Session,
1885) -> Result<()> {
1886    let project_id = ProjectId::from_proto(project_id);
1887    let project_connection_ids = session
1888        .db()
1889        .await
1890        .project_connection_ids(project_id, session.connection_id)
1891        .await?;
1892    broadcast(
1893        Some(session.connection_id),
1894        project_connection_ids.iter().copied(),
1895        |connection_id| {
1896            session
1897                .peer
1898                .forward_send(session.connection_id, connection_id, request.clone())
1899        },
1900    );
1901    Ok(())
1902}
1903
1904async fn follow(
1905    request: proto::Follow,
1906    response: Response<proto::Follow>,
1907    session: Session,
1908) -> Result<()> {
1909    let room_id = RoomId::from_proto(request.room_id);
1910    let project_id = request.project_id.map(ProjectId::from_proto);
1911    let leader_id = request
1912        .leader_id
1913        .ok_or_else(|| anyhow!("invalid leader id"))?
1914        .into();
1915    let follower_id = session.connection_id;
1916
1917    session
1918        .db()
1919        .await
1920        .check_room_participants(room_id, leader_id, session.connection_id)
1921        .await?;
1922
1923    let response_payload = session
1924        .peer
1925        .forward_request(session.connection_id, leader_id, request)
1926        .await?;
1927    response.send(response_payload)?;
1928
1929    if let Some(project_id) = project_id {
1930        let room = session
1931            .db()
1932            .await
1933            .follow(room_id, project_id, leader_id, follower_id)
1934            .await?;
1935        room_updated(&room, &session.peer);
1936    }
1937
1938    Ok(())
1939}
1940
1941async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1942    let room_id = RoomId::from_proto(request.room_id);
1943    let project_id = request.project_id.map(ProjectId::from_proto);
1944    let leader_id = request
1945        .leader_id
1946        .ok_or_else(|| anyhow!("invalid leader id"))?
1947        .into();
1948    let follower_id = session.connection_id;
1949
1950    session
1951        .db()
1952        .await
1953        .check_room_participants(room_id, leader_id, session.connection_id)
1954        .await?;
1955
1956    session
1957        .peer
1958        .forward_send(session.connection_id, leader_id, request)?;
1959
1960    if let Some(project_id) = project_id {
1961        let room = session
1962            .db()
1963            .await
1964            .unfollow(room_id, project_id, leader_id, follower_id)
1965            .await?;
1966        room_updated(&room, &session.peer);
1967    }
1968
1969    Ok(())
1970}
1971
1972async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1973    let room_id = RoomId::from_proto(request.room_id);
1974    let database = session.db.lock().await;
1975
1976    let connection_ids = if let Some(project_id) = request.project_id {
1977        let project_id = ProjectId::from_proto(project_id);
1978        database
1979            .project_connection_ids(project_id, session.connection_id)
1980            .await?
1981    } else {
1982        database
1983            .room_connection_ids(room_id, session.connection_id)
1984            .await?
1985    };
1986
1987    // For now, don't send view update messages back to that view's current leader.
1988    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1989        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1990        _ => None,
1991    });
1992
1993    for follower_peer_id in request.follower_ids.iter().copied() {
1994        let follower_connection_id = follower_peer_id.into();
1995        if Some(follower_peer_id) != connection_id_to_omit
1996            && connection_ids.contains(&follower_connection_id)
1997        {
1998            session.peer.forward_send(
1999                session.connection_id,
2000                follower_connection_id,
2001                request.clone(),
2002            )?;
2003        }
2004    }
2005    Ok(())
2006}
2007
2008async fn get_users(
2009    request: proto::GetUsers,
2010    response: Response<proto::GetUsers>,
2011    session: Session,
2012) -> Result<()> {
2013    let user_ids = request
2014        .user_ids
2015        .into_iter()
2016        .map(UserId::from_proto)
2017        .collect();
2018    let users = session
2019        .db()
2020        .await
2021        .get_users_by_ids(user_ids)
2022        .await?
2023        .into_iter()
2024        .map(|user| proto::User {
2025            id: user.id.to_proto(),
2026            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2027            github_login: user.github_login,
2028        })
2029        .collect();
2030    response.send(proto::UsersResponse { users })?;
2031    Ok(())
2032}
2033
2034async fn fuzzy_search_users(
2035    request: proto::FuzzySearchUsers,
2036    response: Response<proto::FuzzySearchUsers>,
2037    session: Session,
2038) -> Result<()> {
2039    let query = request.query;
2040    let users = match query.len() {
2041        0 => vec![],
2042        1 | 2 => session
2043            .db()
2044            .await
2045            .get_user_by_github_login(&query)
2046            .await?
2047            .into_iter()
2048            .collect(),
2049        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2050    };
2051    let users = users
2052        .into_iter()
2053        .filter(|user| user.id != session.user_id)
2054        .map(|user| proto::User {
2055            id: user.id.to_proto(),
2056            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2057            github_login: user.github_login,
2058        })
2059        .collect();
2060    response.send(proto::UsersResponse { users })?;
2061    Ok(())
2062}
2063
2064async fn request_contact(
2065    request: proto::RequestContact,
2066    response: Response<proto::RequestContact>,
2067    session: Session,
2068) -> Result<()> {
2069    let requester_id = session.user_id;
2070    let responder_id = UserId::from_proto(request.responder_id);
2071    if requester_id == responder_id {
2072        return Err(anyhow!("cannot add yourself as a contact"))?;
2073    }
2074
2075    let notifications = session
2076        .db()
2077        .await
2078        .send_contact_request(requester_id, responder_id)
2079        .await?;
2080
2081    // Update outgoing contact requests of requester
2082    let mut update = proto::UpdateContacts::default();
2083    update.outgoing_requests.push(responder_id.to_proto());
2084    for connection_id in session
2085        .connection_pool()
2086        .await
2087        .user_connection_ids(requester_id)
2088    {
2089        session.peer.send(connection_id, update.clone())?;
2090    }
2091
2092    // Update incoming contact requests of responder
2093    let mut update = proto::UpdateContacts::default();
2094    update
2095        .incoming_requests
2096        .push(proto::IncomingContactRequest {
2097            requester_id: requester_id.to_proto(),
2098        });
2099    let connection_pool = session.connection_pool().await;
2100    for connection_id in connection_pool.user_connection_ids(responder_id) {
2101        session.peer.send(connection_id, update.clone())?;
2102    }
2103
2104    send_notifications(&*connection_pool, &session.peer, notifications);
2105
2106    response.send(proto::Ack {})?;
2107    Ok(())
2108}
2109
2110async fn respond_to_contact_request(
2111    request: proto::RespondToContactRequest,
2112    response: Response<proto::RespondToContactRequest>,
2113    session: Session,
2114) -> Result<()> {
2115    let responder_id = session.user_id;
2116    let requester_id = UserId::from_proto(request.requester_id);
2117    let db = session.db().await;
2118    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2119        db.dismiss_contact_notification(responder_id, requester_id)
2120            .await?;
2121    } else {
2122        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2123
2124        let notifications = db
2125            .respond_to_contact_request(responder_id, requester_id, accept)
2126            .await?;
2127        let requester_busy = db.is_user_busy(requester_id).await?;
2128        let responder_busy = db.is_user_busy(responder_id).await?;
2129
2130        let pool = session.connection_pool().await;
2131        // Update responder with new contact
2132        let mut update = proto::UpdateContacts::default();
2133        if accept {
2134            update
2135                .contacts
2136                .push(contact_for_user(requester_id, requester_busy, &pool));
2137        }
2138        update
2139            .remove_incoming_requests
2140            .push(requester_id.to_proto());
2141        for connection_id in pool.user_connection_ids(responder_id) {
2142            session.peer.send(connection_id, update.clone())?;
2143        }
2144
2145        // Update requester with new contact
2146        let mut update = proto::UpdateContacts::default();
2147        if accept {
2148            update
2149                .contacts
2150                .push(contact_for_user(responder_id, responder_busy, &pool));
2151        }
2152        update
2153            .remove_outgoing_requests
2154            .push(responder_id.to_proto());
2155
2156        for connection_id in pool.user_connection_ids(requester_id) {
2157            session.peer.send(connection_id, update.clone())?;
2158        }
2159
2160        send_notifications(&*pool, &session.peer, notifications);
2161    }
2162
2163    response.send(proto::Ack {})?;
2164    Ok(())
2165}
2166
2167async fn remove_contact(
2168    request: proto::RemoveContact,
2169    response: Response<proto::RemoveContact>,
2170    session: Session,
2171) -> Result<()> {
2172    let requester_id = session.user_id;
2173    let responder_id = UserId::from_proto(request.user_id);
2174    let db = session.db().await;
2175    let (contact_accepted, deleted_notification_id) =
2176        db.remove_contact(requester_id, responder_id).await?;
2177
2178    let pool = session.connection_pool().await;
2179    // Update outgoing contact requests of requester
2180    let mut update = proto::UpdateContacts::default();
2181    if contact_accepted {
2182        update.remove_contacts.push(responder_id.to_proto());
2183    } else {
2184        update
2185            .remove_outgoing_requests
2186            .push(responder_id.to_proto());
2187    }
2188    for connection_id in pool.user_connection_ids(requester_id) {
2189        session.peer.send(connection_id, update.clone())?;
2190    }
2191
2192    // Update incoming contact requests of responder
2193    let mut update = proto::UpdateContacts::default();
2194    if contact_accepted {
2195        update.remove_contacts.push(requester_id.to_proto());
2196    } else {
2197        update
2198            .remove_incoming_requests
2199            .push(requester_id.to_proto());
2200    }
2201    for connection_id in pool.user_connection_ids(responder_id) {
2202        session.peer.send(connection_id, update.clone())?;
2203        if let Some(notification_id) = deleted_notification_id {
2204            session.peer.send(
2205                connection_id,
2206                proto::DeleteNotification {
2207                    notification_id: notification_id.to_proto(),
2208                },
2209            )?;
2210        }
2211    }
2212
2213    response.send(proto::Ack {})?;
2214    Ok(())
2215}
2216
2217async fn create_channel(
2218    request: proto::CreateChannel,
2219    response: Response<proto::CreateChannel>,
2220    session: Session,
2221) -> Result<()> {
2222    let db = session.db().await;
2223
2224    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2225    let CreateChannelResult {
2226        channel,
2227        participants_to_update,
2228    } = db
2229        .create_channel(&request.name, parent_id, session.user_id)
2230        .await?;
2231
2232    response.send(proto::CreateChannelResponse {
2233        channel: Some(channel.to_proto()),
2234        parent_id: request.parent_id,
2235    })?;
2236
2237    let connection_pool = session.connection_pool().await;
2238    for (user_id, channels) in participants_to_update {
2239        let update = build_channels_update(channels, vec![]);
2240        for connection_id in connection_pool.user_connection_ids(user_id) {
2241            if user_id == session.user_id {
2242                continue;
2243            }
2244            session.peer.send(connection_id, update.clone())?;
2245        }
2246    }
2247
2248    Ok(())
2249}
2250
2251async fn delete_channel(
2252    request: proto::DeleteChannel,
2253    response: Response<proto::DeleteChannel>,
2254    session: Session,
2255) -> Result<()> {
2256    let db = session.db().await;
2257
2258    let channel_id = request.channel_id;
2259    let (removed_channels, member_ids) = db
2260        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2261        .await?;
2262    response.send(proto::Ack {})?;
2263
2264    // Notify members of removed channels
2265    let mut update = proto::UpdateChannels::default();
2266    update
2267        .delete_channels
2268        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2269
2270    let connection_pool = session.connection_pool().await;
2271    for member_id in member_ids {
2272        for connection_id in connection_pool.user_connection_ids(member_id) {
2273            session.peer.send(connection_id, update.clone())?;
2274        }
2275    }
2276
2277    Ok(())
2278}
2279
2280async fn invite_channel_member(
2281    request: proto::InviteChannelMember,
2282    response: Response<proto::InviteChannelMember>,
2283    session: Session,
2284) -> Result<()> {
2285    let db = session.db().await;
2286    let channel_id = ChannelId::from_proto(request.channel_id);
2287    let invitee_id = UserId::from_proto(request.user_id);
2288    let InviteMemberResult {
2289        channel,
2290        notifications,
2291    } = db
2292        .invite_channel_member(
2293            channel_id,
2294            invitee_id,
2295            session.user_id,
2296            request.role().into(),
2297        )
2298        .await?;
2299
2300    let update = proto::UpdateChannels {
2301        channel_invitations: vec![channel.to_proto()],
2302        ..Default::default()
2303    };
2304
2305    let connection_pool = session.connection_pool().await;
2306    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2307        session.peer.send(connection_id, update.clone())?;
2308    }
2309
2310    send_notifications(&*connection_pool, &session.peer, notifications);
2311
2312    response.send(proto::Ack {})?;
2313    Ok(())
2314}
2315
2316async fn remove_channel_member(
2317    request: proto::RemoveChannelMember,
2318    response: Response<proto::RemoveChannelMember>,
2319    session: Session,
2320) -> Result<()> {
2321    let db = session.db().await;
2322    let channel_id = ChannelId::from_proto(request.channel_id);
2323    let member_id = UserId::from_proto(request.user_id);
2324
2325    let RemoveChannelMemberResult {
2326        membership_update,
2327        notification_id,
2328    } = db
2329        .remove_channel_member(channel_id, member_id, session.user_id)
2330        .await?;
2331
2332    let connection_pool = &session.connection_pool().await;
2333    notify_membership_updated(
2334        &connection_pool,
2335        membership_update,
2336        member_id,
2337        &session.peer,
2338    );
2339    for connection_id in connection_pool.user_connection_ids(member_id) {
2340        if let Some(notification_id) = notification_id {
2341            session
2342                .peer
2343                .send(
2344                    connection_id,
2345                    proto::DeleteNotification {
2346                        notification_id: notification_id.to_proto(),
2347                    },
2348                )
2349                .trace_err();
2350        }
2351    }
2352
2353    response.send(proto::Ack {})?;
2354    Ok(())
2355}
2356
2357async fn set_channel_visibility(
2358    request: proto::SetChannelVisibility,
2359    response: Response<proto::SetChannelVisibility>,
2360    session: Session,
2361) -> Result<()> {
2362    let db = session.db().await;
2363    let channel_id = ChannelId::from_proto(request.channel_id);
2364    let visibility = request.visibility().into();
2365
2366    let SetChannelVisibilityResult {
2367        participants_to_update,
2368        participants_to_remove,
2369    } = db
2370        .set_channel_visibility(channel_id, visibility, session.user_id)
2371        .await?;
2372
2373    let connection_pool = session.connection_pool().await;
2374    for (user_id, channels) in participants_to_update {
2375        let update = build_channels_update(channels, vec![]);
2376        for connection_id in connection_pool.user_connection_ids(user_id) {
2377            session.peer.send(connection_id, update.clone())?;
2378        }
2379    }
2380    for user_id in participants_to_remove {
2381        let update = proto::UpdateChannels {
2382            // for public participants  we only need to remove the current channel
2383            // (not descendants)
2384            // because they can still see any public descendants
2385            delete_channels: vec![channel_id.to_proto()],
2386            ..Default::default()
2387        };
2388        for connection_id in connection_pool.user_connection_ids(user_id) {
2389            session.peer.send(connection_id, update.clone())?;
2390        }
2391    }
2392
2393    response.send(proto::Ack {})?;
2394    Ok(())
2395}
2396
2397async fn set_channel_member_role(
2398    request: proto::SetChannelMemberRole,
2399    response: Response<proto::SetChannelMemberRole>,
2400    session: Session,
2401) -> Result<()> {
2402    let db = session.db().await;
2403    let channel_id = ChannelId::from_proto(request.channel_id);
2404    let member_id = UserId::from_proto(request.user_id);
2405    let result = db
2406        .set_channel_member_role(
2407            channel_id,
2408            session.user_id,
2409            member_id,
2410            request.role().into(),
2411        )
2412        .await?;
2413
2414    match result {
2415        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2416            let connection_pool = session.connection_pool().await;
2417            notify_membership_updated(
2418                &connection_pool,
2419                membership_update,
2420                member_id,
2421                &session.peer,
2422            )
2423        }
2424        db::SetMemberRoleResult::InviteUpdated(channel) => {
2425            let update = proto::UpdateChannels {
2426                channel_invitations: vec![channel.to_proto()],
2427                ..Default::default()
2428            };
2429
2430            for connection_id in session
2431                .connection_pool()
2432                .await
2433                .user_connection_ids(member_id)
2434            {
2435                session.peer.send(connection_id, update.clone())?;
2436            }
2437        }
2438    }
2439
2440    response.send(proto::Ack {})?;
2441    Ok(())
2442}
2443
2444async fn rename_channel(
2445    request: proto::RenameChannel,
2446    response: Response<proto::RenameChannel>,
2447    session: Session,
2448) -> Result<()> {
2449    let db = session.db().await;
2450    let channel_id = ChannelId::from_proto(request.channel_id);
2451    let RenameChannelResult {
2452        channel,
2453        participants_to_update,
2454    } = db
2455        .rename_channel(channel_id, session.user_id, &request.name)
2456        .await?;
2457
2458    response.send(proto::RenameChannelResponse {
2459        channel: Some(channel.to_proto()),
2460    })?;
2461
2462    let connection_pool = session.connection_pool().await;
2463    for (user_id, channel) in participants_to_update {
2464        for connection_id in connection_pool.user_connection_ids(user_id) {
2465            let update = proto::UpdateChannels {
2466                channels: vec![channel.to_proto()],
2467                ..Default::default()
2468            };
2469
2470            session.peer.send(connection_id, update.clone())?;
2471        }
2472    }
2473
2474    Ok(())
2475}
2476
2477// TODO: Implement in terms of symlinks
2478// Current behavior of this is more like 'Move root channel'
2479async fn link_channel(
2480    request: proto::LinkChannel,
2481    response: Response<proto::LinkChannel>,
2482    session: Session,
2483) -> Result<()> {
2484    let db = session.db().await;
2485    let channel_id = ChannelId::from_proto(request.channel_id);
2486    let to = ChannelId::from_proto(request.to);
2487
2488    let result = db
2489        .move_channel(channel_id, None, to, session.user_id)
2490        .await?;
2491    drop(db);
2492
2493    notify_channel_moved(result, session).await?;
2494
2495    response.send(Ack {})?;
2496
2497    Ok(())
2498}
2499
2500// TODO: Implement in terms of symlinks
2501async fn unlink_channel(
2502    _request: proto::UnlinkChannel,
2503    _response: Response<proto::UnlinkChannel>,
2504    _session: Session,
2505) -> Result<()> {
2506    Err(anyhow!("unimplemented").into())
2507}
2508
2509async fn move_channel(
2510    request: proto::MoveChannel,
2511    response: Response<proto::MoveChannel>,
2512    session: Session,
2513) -> Result<()> {
2514    let db = session.db().await;
2515    let channel_id = ChannelId::from_proto(request.channel_id);
2516    let from_parent = ChannelId::from_proto(request.from);
2517    let to = ChannelId::from_proto(request.to);
2518
2519    let result = db
2520        .move_channel(channel_id, Some(from_parent), to, session.user_id)
2521        .await?;
2522    drop(db);
2523
2524    notify_channel_moved(result, session).await?;
2525
2526    response.send(Ack {})?;
2527    Ok(())
2528}
2529
2530async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2531    let Some(MoveChannelResult {
2532        participants_to_remove,
2533        participants_to_update,
2534        moved_channels,
2535    }) = result
2536    else {
2537        return Ok(());
2538    };
2539    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2540
2541    let connection_pool = session.connection_pool().await;
2542    for (user_id, channels) in participants_to_update {
2543        let mut update = build_channels_update(channels, vec![]);
2544        update.delete_channels = moved_channels.clone();
2545        for connection_id in connection_pool.user_connection_ids(user_id) {
2546            session.peer.send(connection_id, update.clone())?;
2547        }
2548    }
2549
2550    for user_id in participants_to_remove {
2551        let update = proto::UpdateChannels {
2552            delete_channels: moved_channels.clone(),
2553            ..Default::default()
2554        };
2555        for connection_id in connection_pool.user_connection_ids(user_id) {
2556            session.peer.send(connection_id, update.clone())?;
2557        }
2558    }
2559    Ok(())
2560}
2561
2562async fn get_channel_members(
2563    request: proto::GetChannelMembers,
2564    response: Response<proto::GetChannelMembers>,
2565    session: Session,
2566) -> Result<()> {
2567    let db = session.db().await;
2568    let channel_id = ChannelId::from_proto(request.channel_id);
2569    let members = db
2570        .get_channel_participant_details(channel_id, session.user_id)
2571        .await?;
2572    response.send(proto::GetChannelMembersResponse { members })?;
2573    Ok(())
2574}
2575
2576async fn respond_to_channel_invite(
2577    request: proto::RespondToChannelInvite,
2578    response: Response<proto::RespondToChannelInvite>,
2579    session: Session,
2580) -> Result<()> {
2581    let db = session.db().await;
2582    let channel_id = ChannelId::from_proto(request.channel_id);
2583    let RespondToChannelInvite {
2584        membership_update,
2585        notifications,
2586    } = db
2587        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2588        .await?;
2589
2590    let connection_pool = session.connection_pool().await;
2591    if let Some(membership_update) = membership_update {
2592        notify_membership_updated(
2593            &connection_pool,
2594            membership_update,
2595            session.user_id,
2596            &session.peer,
2597        );
2598    } else {
2599        let update = proto::UpdateChannels {
2600            remove_channel_invitations: vec![channel_id.to_proto()],
2601            ..Default::default()
2602        };
2603
2604        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2605            session.peer.send(connection_id, update.clone())?;
2606        }
2607    };
2608
2609    send_notifications(&*connection_pool, &session.peer, notifications);
2610
2611    response.send(proto::Ack {})?;
2612
2613    Ok(())
2614}
2615
2616async fn join_channel(
2617    request: proto::JoinChannel,
2618    response: Response<proto::JoinChannel>,
2619    session: Session,
2620) -> Result<()> {
2621    let channel_id = ChannelId::from_proto(request.channel_id);
2622    join_channel_internal(channel_id, Box::new(response), session).await
2623}
2624
/// Abstraction over response handles that can answer a channel join with a
/// `JoinRoomResponse`. This lets `join_channel_internal` serve both
/// `JoinChannel` and `JoinRoom` requests (see the two impls that follow).
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response be answered through `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path syntax resolves to the inherent `Response::send` (inherent items
        // win over trait items), so this does not recurse into itself.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response be answered through `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path syntax resolves to the inherent `Response::send` (inherent items
        // win over trait items), so this does not recurse into itself.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2638
/// Moves this session into the room of `channel_id` and answers `response`
/// with the joined room.
///
/// Order of operations:
/// 1. leave whatever room the session is currently in;
/// 2. join the channel's room in the database. This also yields an optional
///    membership update (`accept_invite_result`) and the user's `ChannelRole`;
/// 3. if a LiveKit client is configured, mint a token and send the
///    `JoinRoomResponse`. `ChannelRole::Guest` gets a guest token with
///    `can_publish: false`; any other role gets a room token with
///    `can_publish: true`;
/// 4. push any membership update to the user's connections, broadcast the
///    room to its participants, broadcast the participant list to channel
///    members, and refresh the user's contacts.
///
/// A failed token request yields `live_kit_connection_info: None` (via
/// `trace_err`) rather than an error.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the database and connection-pool guards taken in here are
    // dropped before `channel_updated` acquires the pool again below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, accept_invite_result, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `None` if there is no LiveKit client or if token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // The reply goes out before any of the broadcasts below.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(accept_invite_result) = accept_invite_result {
            notify_membership_updated(
                &connection_pool,
                accept_invite_result,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2719
2720async fn join_channel_buffer(
2721    request: proto::JoinChannelBuffer,
2722    response: Response<proto::JoinChannelBuffer>,
2723    session: Session,
2724) -> Result<()> {
2725    let db = session.db().await;
2726    let channel_id = ChannelId::from_proto(request.channel_id);
2727
2728    let open_response = db
2729        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2730        .await?;
2731
2732    let collaborators = open_response.collaborators.clone();
2733    response.send(open_response)?;
2734
2735    let update = UpdateChannelBufferCollaborators {
2736        channel_id: channel_id.to_proto(),
2737        collaborators: collaborators.clone(),
2738    };
2739    channel_buffer_updated(
2740        session.connection_id,
2741        collaborators
2742            .iter()
2743            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2744        &update,
2745        &session.peer,
2746    );
2747
2748    Ok(())
2749}
2750
2751async fn update_channel_buffer(
2752    request: proto::UpdateChannelBuffer,
2753    session: Session,
2754) -> Result<()> {
2755    let db = session.db().await;
2756    let channel_id = ChannelId::from_proto(request.channel_id);
2757
2758    let (collaborators, non_collaborators, epoch, version) = db
2759        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2760        .await?;
2761
2762    channel_buffer_updated(
2763        session.connection_id,
2764        collaborators,
2765        &proto::UpdateChannelBuffer {
2766            channel_id: channel_id.to_proto(),
2767            operations: request.operations,
2768        },
2769        &session.peer,
2770    );
2771
2772    let pool = &*session.connection_pool().await;
2773
2774    broadcast(
2775        None,
2776        non_collaborators
2777            .iter()
2778            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2779        |peer_id| {
2780            session.peer.send(
2781                peer_id.into(),
2782                proto::UpdateChannels {
2783                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2784                        channel_id: channel_id.to_proto(),
2785                        epoch: epoch as u64,
2786                        version: version.clone(),
2787                    }],
2788                    ..Default::default()
2789                },
2790            )
2791        },
2792    );
2793
2794    Ok(())
2795}
2796
2797async fn rejoin_channel_buffers(
2798    request: proto::RejoinChannelBuffers,
2799    response: Response<proto::RejoinChannelBuffers>,
2800    session: Session,
2801) -> Result<()> {
2802    let db = session.db().await;
2803    let buffers = db
2804        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2805        .await?;
2806
2807    for rejoined_buffer in &buffers {
2808        let collaborators_to_notify = rejoined_buffer
2809            .buffer
2810            .collaborators
2811            .iter()
2812            .filter_map(|c| Some(c.peer_id?.into()));
2813        channel_buffer_updated(
2814            session.connection_id,
2815            collaborators_to_notify,
2816            &proto::UpdateChannelBufferCollaborators {
2817                channel_id: rejoined_buffer.buffer.channel_id,
2818                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2819            },
2820            &session.peer,
2821        );
2822    }
2823
2824    response.send(proto::RejoinChannelBuffersResponse {
2825        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2826    })?;
2827
2828    Ok(())
2829}
2830
2831async fn leave_channel_buffer(
2832    request: proto::LeaveChannelBuffer,
2833    response: Response<proto::LeaveChannelBuffer>,
2834    session: Session,
2835) -> Result<()> {
2836    let db = session.db().await;
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838
2839    let left_buffer = db
2840        .leave_channel_buffer(channel_id, session.connection_id)
2841        .await?;
2842
2843    response.send(Ack {})?;
2844
2845    channel_buffer_updated(
2846        session.connection_id,
2847        left_buffer.connections,
2848        &proto::UpdateChannelBufferCollaborators {
2849            channel_id: channel_id.to_proto(),
2850            collaborators: left_buffer.collaborators,
2851        },
2852        &session.peer,
2853    );
2854
2855    Ok(())
2856}
2857
2858fn channel_buffer_updated<T: EnvelopedMessage>(
2859    sender_id: ConnectionId,
2860    collaborators: impl IntoIterator<Item = ConnectionId>,
2861    message: &T,
2862    peer: &Peer,
2863) {
2864    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2865        peer.send(peer_id.into(), message.clone())
2866    });
2867}
2868
2869fn send_notifications(
2870    connection_pool: &ConnectionPool,
2871    peer: &Peer,
2872    notifications: db::NotificationBatch,
2873) {
2874    for (user_id, notification) in notifications {
2875        for connection_id in connection_pool.user_connection_ids(user_id) {
2876            if let Err(error) = peer.send(
2877                connection_id,
2878                proto::AddNotification {
2879                    notification: Some(notification.clone()),
2880                },
2881            ) {
2882                tracing::error!(
2883                    "failed to send notification to {:?} {}",
2884                    connection_id,
2885                    error
2886                );
2887            }
2888        }
2889    }
2890}
2891
2892async fn send_channel_message(
2893    request: proto::SendChannelMessage,
2894    response: Response<proto::SendChannelMessage>,
2895    session: Session,
2896) -> Result<()> {
2897    // Validate the message body.
2898    let body = request.body.trim().to_string();
2899    if body.len() > MAX_MESSAGE_LEN {
2900        return Err(anyhow!("message is too long"))?;
2901    }
2902    if body.is_empty() {
2903        return Err(anyhow!("message can't be blank"))?;
2904    }
2905
2906    // TODO: adjust mentions if body is trimmed
2907
2908    let timestamp = OffsetDateTime::now_utc();
2909    let nonce = request
2910        .nonce
2911        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2912
2913    let channel_id = ChannelId::from_proto(request.channel_id);
2914    let CreatedChannelMessage {
2915        message_id,
2916        participant_connection_ids,
2917        channel_members,
2918        notifications,
2919    } = session
2920        .db()
2921        .await
2922        .create_channel_message(
2923            channel_id,
2924            session.user_id,
2925            &body,
2926            &request.mentions,
2927            timestamp,
2928            nonce.clone().into(),
2929        )
2930        .await?;
2931    let message = proto::ChannelMessage {
2932        sender_id: session.user_id.to_proto(),
2933        id: message_id.to_proto(),
2934        body,
2935        mentions: request.mentions,
2936        timestamp: timestamp.unix_timestamp() as u64,
2937        nonce: Some(nonce),
2938    };
2939    broadcast(
2940        Some(session.connection_id),
2941        participant_connection_ids,
2942        |connection| {
2943            session.peer.send(
2944                connection,
2945                proto::ChannelMessageSent {
2946                    channel_id: channel_id.to_proto(),
2947                    message: Some(message.clone()),
2948                },
2949            )
2950        },
2951    );
2952    response.send(proto::SendChannelMessageResponse {
2953        message: Some(message),
2954    })?;
2955
2956    let pool = &*session.connection_pool().await;
2957    broadcast(
2958        None,
2959        channel_members
2960            .iter()
2961            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2962        |peer_id| {
2963            session.peer.send(
2964                peer_id.into(),
2965                proto::UpdateChannels {
2966                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2967                        channel_id: channel_id.to_proto(),
2968                        message_id: message_id.to_proto(),
2969                    }],
2970                    ..Default::default()
2971                },
2972            )
2973        },
2974    );
2975    send_notifications(pool, &session.peer, notifications);
2976
2977    Ok(())
2978}
2979
2980async fn remove_channel_message(
2981    request: proto::RemoveChannelMessage,
2982    response: Response<proto::RemoveChannelMessage>,
2983    session: Session,
2984) -> Result<()> {
2985    let channel_id = ChannelId::from_proto(request.channel_id);
2986    let message_id = MessageId::from_proto(request.message_id);
2987    let connection_ids = session
2988        .db()
2989        .await
2990        .remove_channel_message(channel_id, message_id, session.user_id)
2991        .await?;
2992    broadcast(Some(session.connection_id), connection_ids, |connection| {
2993        session.peer.send(connection, request.clone())
2994    });
2995    response.send(proto::Ack {})?;
2996    Ok(())
2997}
2998
2999async fn acknowledge_channel_message(
3000    request: proto::AckChannelMessage,
3001    session: Session,
3002) -> Result<()> {
3003    let channel_id = ChannelId::from_proto(request.channel_id);
3004    let message_id = MessageId::from_proto(request.message_id);
3005    let notifications = session
3006        .db()
3007        .await
3008        .observe_channel_message(channel_id, session.user_id, message_id)
3009        .await?;
3010    send_notifications(
3011        &*session.connection_pool().await,
3012        &session.peer,
3013        notifications,
3014    );
3015    Ok(())
3016}
3017
3018async fn acknowledge_buffer_version(
3019    request: proto::AckBufferOperation,
3020    session: Session,
3021) -> Result<()> {
3022    let buffer_id = BufferId::from_proto(request.buffer_id);
3023    session
3024        .db()
3025        .await
3026        .observe_buffer_version(
3027            buffer_id,
3028            session.user_id,
3029            request.epoch as i32,
3030            &request.version,
3031        )
3032        .await?;
3033    Ok(())
3034}
3035
3036async fn join_channel_chat(
3037    request: proto::JoinChannelChat,
3038    response: Response<proto::JoinChannelChat>,
3039    session: Session,
3040) -> Result<()> {
3041    let channel_id = ChannelId::from_proto(request.channel_id);
3042
3043    let db = session.db().await;
3044    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3045        .await?;
3046    let messages = db
3047        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3048        .await?;
3049    response.send(proto::JoinChannelChatResponse {
3050        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3051        messages,
3052    })?;
3053    Ok(())
3054}
3055
3056async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    session
3059        .db()
3060        .await
3061        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3062        .await?;
3063    Ok(())
3064}
3065
3066async fn get_channel_messages(
3067    request: proto::GetChannelMessages,
3068    response: Response<proto::GetChannelMessages>,
3069    session: Session,
3070) -> Result<()> {
3071    let channel_id = ChannelId::from_proto(request.channel_id);
3072    let messages = session
3073        .db()
3074        .await
3075        .get_channel_messages(
3076            channel_id,
3077            session.user_id,
3078            MESSAGE_COUNT_PER_PAGE,
3079            Some(MessageId::from_proto(request.before_message_id)),
3080        )
3081        .await?;
3082    response.send(proto::GetChannelMessagesResponse {
3083        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3084        messages,
3085    })?;
3086    Ok(())
3087}
3088
3089async fn get_channel_messages_by_id(
3090    request: proto::GetChannelMessagesById,
3091    response: Response<proto::GetChannelMessagesById>,
3092    session: Session,
3093) -> Result<()> {
3094    let message_ids = request
3095        .message_ids
3096        .iter()
3097        .map(|id| MessageId::from_proto(*id))
3098        .collect::<Vec<_>>();
3099    let messages = session
3100        .db()
3101        .await
3102        .get_channel_messages_by_id(session.user_id, &message_ids)
3103        .await?;
3104    response.send(proto::GetChannelMessagesResponse {
3105        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3106        messages,
3107    })?;
3108    Ok(())
3109}
3110
3111async fn get_notifications(
3112    request: proto::GetNotifications,
3113    response: Response<proto::GetNotifications>,
3114    session: Session,
3115) -> Result<()> {
3116    let notifications = session
3117        .db()
3118        .await
3119        .get_notifications(
3120            session.user_id,
3121            NOTIFICATION_COUNT_PER_PAGE,
3122            request
3123                .before_id
3124                .map(|id| db::NotificationId::from_proto(id)),
3125        )
3126        .await?;
3127    response.send(proto::GetNotificationsResponse {
3128        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3129        notifications,
3130    })?;
3131    Ok(())
3132}
3133
3134async fn mark_notification_as_read(
3135    request: proto::MarkNotificationRead,
3136    response: Response<proto::MarkNotificationRead>,
3137    session: Session,
3138) -> Result<()> {
3139    let database = &session.db().await;
3140    let notifications = database
3141        .mark_notification_as_read_by_id(
3142            session.user_id,
3143            NotificationId::from_proto(request.notification_id),
3144        )
3145        .await?;
3146    send_notifications(
3147        &*session.connection_pool().await,
3148        &session.peer,
3149        notifications,
3150    );
3151    response.send(proto::Ack {})?;
3152    Ok(())
3153}
3154
3155async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3156    let project_id = ProjectId::from_proto(request.project_id);
3157    let project_connection_ids = session
3158        .db()
3159        .await
3160        .project_connection_ids(project_id, session.connection_id)
3161        .await?;
3162    broadcast(
3163        Some(session.connection_id),
3164        project_connection_ids.iter().copied(),
3165        |connection_id| {
3166            session
3167                .peer
3168                .forward_send(session.connection_id, connection_id, request.clone())
3169        },
3170    );
3171    Ok(())
3172}
3173
3174async fn get_private_user_info(
3175    _request: proto::GetPrivateUserInfo,
3176    response: Response<proto::GetPrivateUserInfo>,
3177    session: Session,
3178) -> Result<()> {
3179    let db = session.db().await;
3180
3181    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3182    let user = db
3183        .get_user_by_id(session.user_id)
3184        .await?
3185        .ok_or_else(|| anyhow!("user not found"))?;
3186    let flags = db.get_user_flags(session.user_id).await?;
3187
3188    response.send(proto::GetPrivateUserInfoResponse {
3189        metrics_id,
3190        staff: user.admin,
3191        flags,
3192    })?;
3193    Ok(())
3194}
3195
3196fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3197    match message {
3198        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3199        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3200        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3201        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3202        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3203            code: frame.code.into(),
3204            reason: frame.reason,
3205        })),
3206    }
3207}
3208
3209fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3210    match message {
3211        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3212        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3213        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3214        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3215        AxumMessage::Close(frame) => {
3216            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3217                code: frame.code.into(),
3218                reason: frame.reason,
3219            }))
3220        }
3221    }
3222}
3223
3224fn notify_membership_updated(
3225    connection_pool: &ConnectionPool,
3226    result: MembershipUpdated,
3227    user_id: UserId,
3228    peer: &Peer,
3229) {
3230    let mut update = build_channels_update(result.new_channels, vec![]);
3231    update.delete_channels = result
3232        .removed_channels
3233        .into_iter()
3234        .map(|id| id.to_proto())
3235        .collect();
3236    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3237
3238    for connection_id in connection_pool.user_connection_ids(user_id) {
3239        peer.send(connection_id, update.clone()).trace_err();
3240    }
3241}
3242
3243fn build_channels_update(
3244    channels: ChannelsForUser,
3245    channel_invites: Vec<db::Channel>,
3246) -> proto::UpdateChannels {
3247    let mut update = proto::UpdateChannels::default();
3248
3249    for channel in channels.channels.channels {
3250        update.channels.push(proto::Channel {
3251            id: channel.id.to_proto(),
3252            name: channel.name,
3253            visibility: channel.visibility.into(),
3254            role: channel.role.into(),
3255        });
3256    }
3257
3258    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3259    update.unseen_channel_messages = channels.channel_messages;
3260    update.insert_edge = channels.channels.edges;
3261
3262    for (channel_id, participants) in channels.channel_participants {
3263        update
3264            .channel_participants
3265            .push(proto::ChannelParticipants {
3266                channel_id: channel_id.to_proto(),
3267                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3268            });
3269    }
3270
3271    for channel in channel_invites {
3272        update.channel_invitations.push(proto::Channel {
3273            id: channel.id.to_proto(),
3274            name: channel.name,
3275            visibility: channel.visibility.into(),
3276            role: channel.role.into(),
3277        });
3278    }
3279
3280    update
3281}
3282
3283fn build_initial_contacts_update(
3284    contacts: Vec<db::Contact>,
3285    pool: &ConnectionPool,
3286) -> proto::UpdateContacts {
3287    let mut update = proto::UpdateContacts::default();
3288
3289    for contact in contacts {
3290        match contact {
3291            db::Contact::Accepted { user_id, busy } => {
3292                update.contacts.push(contact_for_user(user_id, busy, &pool));
3293            }
3294            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3295            db::Contact::Incoming { user_id } => {
3296                update
3297                    .incoming_requests
3298                    .push(proto::IncomingContactRequest {
3299                        requester_id: user_id.to_proto(),
3300                    })
3301            }
3302        }
3303    }
3304
3305    update
3306}
3307
3308fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3309    proto::Contact {
3310        user_id: user_id.to_proto(),
3311        online: pool.is_user_online(user_id),
3312        busy,
3313    }
3314}
3315
3316fn room_updated(room: &proto::Room, peer: &Peer) {
3317    broadcast(
3318        None,
3319        room.participants
3320            .iter()
3321            .filter_map(|participant| Some(participant.peer_id?.into())),
3322        |peer_id| {
3323            peer.send(
3324                peer_id.into(),
3325                proto::RoomUpdated {
3326                    room: Some(room.clone()),
3327                },
3328            )
3329        },
3330    );
3331}
3332
3333fn channel_updated(
3334    channel_id: ChannelId,
3335    room: &proto::Room,
3336    channel_members: &[UserId],
3337    peer: &Peer,
3338    pool: &ConnectionPool,
3339) {
3340    let participants = room
3341        .participants
3342        .iter()
3343        .map(|p| p.user_id)
3344        .collect::<Vec<_>>();
3345
3346    broadcast(
3347        None,
3348        channel_members
3349            .iter()
3350            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3351        |peer_id| {
3352            peer.send(
3353                peer_id.into(),
3354                proto::UpdateChannels {
3355                    channel_participants: vec![proto::ChannelParticipants {
3356                        channel_id: channel_id.to_proto(),
3357                        participant_user_ids: participants.clone(),
3358                    }],
3359                    ..Default::default()
3360                },
3361            )
3362        },
3363    );
3364}
3365
3366async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3367    let db = session.db().await;
3368
3369    let contacts = db.get_contacts(user_id).await?;
3370    let busy = db.is_user_busy(user_id).await?;
3371
3372    let pool = session.connection_pool().await;
3373    let updated_contact = contact_for_user(user_id, busy, &pool);
3374    for contact in contacts {
3375        if let db::Contact::Accepted {
3376            user_id: contact_user_id,
3377            ..
3378        } = contact
3379        {
3380            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3381                session
3382                    .peer
3383                    .send(
3384                        contact_conn_id,
3385                        proto::UpdateContacts {
3386                            contacts: vec![updated_contact.clone()],
3387                            remove_contacts: Default::default(),
3388                            incoming_requests: Default::default(),
3389                            remove_incoming_requests: Default::default(),
3390                            outgoing_requests: Default::default(),
3391                            remove_outgoing_requests: Default::default(),
3392                        },
3393                    )
3394                    .trace_err();
3395            }
3396        }
3397    }
3398    Ok(())
3399}
3400
/// Removes this session's connection from its current room, if it is in one,
/// and propagates the consequences.
///
/// Returns `Ok(())` immediately when the connection was not in a room.
/// Otherwise it:
/// * notifies the connections of every project the user left;
/// * broadcasts the updated room to its participants and, for channel rooms,
///   the participant list to channel members;
/// * sends `CallCanceled` to users whose pending calls were canceled;
/// * refreshes the contacts of this user and of those callees;
/// * removes the user from the LiveKit room, and deletes that room when the
///   database reported it as deleted.
///
/// LiveKit failures go through `trace_err` and do not fail the call.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries must be refreshed at the end.
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned only in the `Some` branch below. The
    // `else` branch returns, so all of them are initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE: the guard returned by `session.db()` is a temporary of the `if let`
    // scrutinee. Before the 2024 edition it stays held for the whole `Some`
    // branch, so restructuring this (e.g. into `let ... else`) changes lock
    // scope.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below has
        // an empty `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` acquires the
    // connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3477
3478async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3479    let left_channel_buffers = session
3480        .db()
3481        .await
3482        .leave_channel_buffers(session.connection_id)
3483        .await?;
3484
3485    for left_buffer in left_channel_buffers {
3486        channel_buffer_updated(
3487            session.connection_id,
3488            left_buffer.connections,
3489            &proto::UpdateChannelBufferCollaborators {
3490                channel_id: left_buffer.channel_id.to_proto(),
3491                collaborators: left_buffer.collaborators,
3492            },
3493            &session.peer,
3494        );
3495    }
3496
3497    Ok(())
3498}
3499
3500fn project_left(project: &db::LeftProject, session: &Session) {
3501    for connection_id in &project.connection_ids {
3502        if project.host_user_id == session.user_id {
3503            session
3504                .peer
3505                .send(
3506                    *connection_id,
3507                    proto::UnshareProject {
3508                        project_id: project.id.to_proto(),
3509                    },
3510                )
3511                .trace_err();
3512        } else {
3513            session
3514                .peer
3515                .send(
3516                    *connection_id,
3517                    proto::RemoveProjectCollaborator {
3518                        project_id: project.id.to_proto(),
3519                        peer_id: Some(session.connection_id.into()),
3520                    },
3521                )
3522                .trace_err();
3523        }
3524    }
3525}
3526
3527pub trait ResultExt {
3528    type Ok;
3529
3530    fn trace_err(self) -> Option<Self::Ok>;
3531}
3532
3533impl<T, E> ResultExt for Result<T, E>
3534where
3535    E: std::fmt::Debug,
3536{
3537    type Ok = T;
3538
3539    fn trace_err(self) -> Option<T> {
3540        match self {
3541            Ok(value) => Some(value),
3542            Err(error) => {
3543                tracing::error!("{:?}", error);
3544                None
3545            }
3546        }
3547    }
3548}