rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
   7        ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
  75lazy_static! {
  76    static ref METRIC_CONNECTIONS: IntGauge =
  77        register_int_gauge!("connections", "number of connections").unwrap();
  78    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  79        "shared_projects",
  80        "number of open projects with one or more guests"
  81    )
  82    .unwrap();
  83}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
/// Handle a request handler uses to send its reply; `send` consumes it, so it replies at most once.
///
/// `responded` is shared with the dispatcher built in `add_request_handler`. The dispatcher
/// inspects the flag after the handler finishes.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`.
    responded: Arc<AtomicBool>,
}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
/// Per-connection context that is cloned into every message handler.
///
/// `handle_connection` builds one `Session` per connection, with a fresh `db` mutex for each.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Connection the handled message arrived on.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and replies.
    peer: Arc<Peer>,
    /// Registry of live connections shared with the `Server`; lock it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when one is configured in `AppState`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor the connection was started on.
    executor: Executor,
}
 112
 113impl Session {
 114    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.db.lock().await;
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        guard
 121    }
 122
 123    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 124        #[cfg(test)]
 125        tokio::task::yield_now().await;
 126        let guard = self.connection_pool.lock();
 127        ConnectionPoolGuard {
 128            guard,
 129            _not_send: PhantomData,
 130        }
 131    }
 132}
 133
 134impl fmt::Debug for Session {
 135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 136        f.debug_struct("Session")
 137            .field("user_id", &self.user_id)
 138            .field("connection_id", &self.connection_id)
 139            .finish()
 140    }
 141}
 142
/// Owner of a shared `Database` handle.
///
/// It dereferences to `Database`, so callers holding the `Session::db` guard can call database
/// methods on it directly.
struct DbHandle(Arc<Database>);
 144
 145impl Deref for DbHandle {
 146    type Target = Database;
 147
 148    fn deref(&self) -> &Self::Target {
 149        self.0.as_ref()
 150    }
 151}
 152
/// RPC server: owns the peer, the registry of live connections and the message-handler table.
pub struct Server {
    /// This server instance's id; mutex-wrapped because the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns all client connections.
    peer: Arc<Peer>,
    /// Registry of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Fired by `teardown`; each connection loop in `handle_connection` subscribes to it and
    /// exits when it changes.
    teardown: watch::Sender<()>,
}
 162
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`. The compiler then rejects
    /// holding it across an `.await` in any future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 167
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the pool lock for as long as it lives. Field order here determines field
/// order in the serialized output.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the value the guard dereferences to (see `serialize_deref`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 174
 175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 176where
 177    S: Serializer,
 178    T: Deref<Target = U>,
 179    U: Serialize,
 180{
 181    Serialize::serialize(value.deref(), serializer)
 182}
 183
 184impl Server {
    /// Creates a server identified by `id` and registers every RPC handler it serves.
    ///
    /// Handler kinds:
    /// - Handlers added with `add_request_handler` must reply through their `Response`; a
    ///   missing reply is reported as an error.
    /// - Handlers added with `add_message_handler` receive only the payload and do not reply.
    ///
    /// Registering two handlers for the same message type panics inside `add_handler`.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if `id.0` is negative or wider
            // than u32 — confirm ServerId's range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections subscribe to it in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests routed through `forward_project_request` (defined elsewhere).
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_notifications)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 287
    /// Spawns a detached background task that cleans up after stale server instances.
    ///
    /// The task waits for `CLEANUP_TIMEOUT`, then takes the rooms and channel buffers that
    /// `stale_server_resource_ids` reports for this environment and server id, and handles
    /// each kind as follows:
    /// - Channel buffers: stale collaborators are cleared, and every remaining connection gets
    ///   an `UpdateChannelBufferCollaborators`.
    /// - Rooms: stale participants are cleared and room/channel updates are broadcast.
    ///   Canceled callees receive `CallCanceled`, and affected users' contacts are refreshed.
    ///   The LiveKit room is deleted once no participants remain.
    ///
    /// Finally the stale servers are deleted via `delete_stale_servers`.
    ///
    /// Individual failures are turned into `None`/ignored via `trace_err` (which presumably
    /// logs them) and do not stop the task. This method returns `Ok(())` as soon as the task
    /// is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Presumably gives clients of the previous server instance time to reconnect before
        // their leftover state is treated as stale.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, tell the remaining ones.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants and propagate the consequences.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted
                        // contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 439
 440    pub fn teardown(&self) {
 441        self.peer.teardown();
 442        self.connection_pool.lock().reset();
 443        let _ = self.teardown.send(());
 444    }
 445
 446    #[cfg(test)]
 447    pub fn reset(&self, id: ServerId) {
 448        self.teardown();
 449        *self.id.lock() = id;
 450        self.peer.reset(id.0 as u32);
 451    }
 452
 453    #[cfg(test)]
 454    pub fn id(&self) -> ServerId {
 455        *self.id.lock()
 456    }
 457
 458    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 459    where
 460        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 461        Fut: 'static + Send + Future<Output = Result<()>>,
 462        M: EnvelopedMessage,
 463    {
 464        let prev_handler = self.handlers.insert(
 465            TypeId::of::<M>(),
 466            Box::new(move |envelope, session| {
 467                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 468                let span = info_span!(
 469                    "handle message",
 470                    payload_type = envelope.payload_type_name()
 471                );
 472                span.in_scope(|| {
 473                    tracing::info!(
 474                        payload_type = envelope.payload_type_name(),
 475                        "message received"
 476                    );
 477                });
 478                let start_time = Instant::now();
 479                let future = (handler)(*envelope, session);
 480                async move {
 481                    let result = future.await;
 482                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 483                    match result {
 484                        Err(error) => {
 485                            tracing::error!(%error, ?duration_ms, "error handling message")
 486                        }
 487                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 488                    }
 489                }
 490                .instrument(span)
 491                .boxed()
 492            }),
 493        );
 494        if prev_handler.is_some() {
 495            panic!("registered a handler for the same message twice");
 496        }
 497        self
 498    }
 499
 500    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 501    where
 502        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 503        Fut: 'static + Send + Future<Output = Result<()>>,
 504        M: EnvelopedMessage,
 505    {
 506        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 507        self
 508    }
 509
 510    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 511    where
 512        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 513        Fut: Send + Future<Output = Result<()>>,
 514        M: RequestMessage,
 515    {
 516        let handler = Arc::new(handler);
 517        self.add_handler(move |envelope, session| {
 518            let receipt = envelope.receipt();
 519            let handler = handler.clone();
 520            async move {
 521                let peer = session.peer.clone();
 522                let responded = Arc::new(AtomicBool::default());
 523                let response = Response {
 524                    peer: peer.clone(),
 525                    responded: responded.clone(),
 526                    receipt,
 527                };
 528                match (handler)(envelope.payload, response, session).await {
 529                    Ok(()) => {
 530                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 531                            Ok(())
 532                        } else {
 533                            Err(anyhow!("handler did not send a response"))?
 534                        }
 535                    }
 536                    Err(error) => {
 537                        peer.respond_with_error(
 538                            receipt,
 539                            proto::Error {
 540                                message: error.to_string(),
 541                            },
 542                        )?;
 543                        Err(error)
 544                    }
 545                }
 546            }
 547        })
 548    }
 549
    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// Sequence:
    /// 1. Registers the connection with the peer and sends `Hello`.
    /// 2. Reports the new `ConnectionId` through `send_connection_id`, if one was given.
    /// 3. For a user that has never connected before, sends `ShowContacts` and records the
    ///    first connection.
    /// 4. Adds the connection to the pool and sends the initial contacts and channels.
    /// 5. Forwards any pending incoming call.
    /// 6. Dispatches incoming messages until the stream ends, I/O fails, or the server tears
    ///    down.
    ///
    /// When the stream ends or I/O fails, `connection_lost` runs. On teardown the future
    /// returns `Ok(())` immediately and skips it.
    ///
    /// Background messages are spawned detached; foreground messages run concurrently on this
    /// task through a `FuturesUnordered`. A semaphore caps in-flight handlers at 256.
    ///
    /// NOTE(review): a `?` failure after `pool.add_connection` but before the message loop
    /// returns early without calling `connection_lost`. The pool entry may then be left behind;
    /// confirm whether anything else removes it.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed eagerly (outside the returned future); fires on `Server::teardown`.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error for this connection.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Wait for a free handler slot before pulling the next message. The permit
                // travels with the message and is released when its handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are checked in order: teardown, I/O, finished foreground handlers, then
                // new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 684
 685    pub async fn invite_code_redeemed(
 686        self: &Arc<Self>,
 687        inviter_id: UserId,
 688        invitee_id: UserId,
 689    ) -> Result<()> {
 690        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 691            if let Some(code) = &user.invite_code {
 692                let pool = self.connection_pool.lock();
 693                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 694                for connection_id in pool.user_connection_ids(inviter_id) {
 695                    self.peer.send(
 696                        connection_id,
 697                        proto::UpdateContacts {
 698                            contacts: vec![invitee_contact.clone()],
 699                            ..Default::default()
 700                        },
 701                    )?;
 702                    self.peer.send(
 703                        connection_id,
 704                        proto::UpdateInviteInfo {
 705                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 706                            count: user.invite_count as u32,
 707                        },
 708                    )?;
 709                }
 710            }
 711        }
 712        Ok(())
 713    }
 714
 715    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 716        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 717            if let Some(invite_code) = &user.invite_code {
 718                let pool = self.connection_pool.lock();
 719                for connection_id in pool.user_connection_ids(user_id) {
 720                    self.peer.send(
 721                        connection_id,
 722                        proto::UpdateInviteInfo {
 723                            url: format!(
 724                                "{}{}",
 725                                self.app_state.config.invite_link_prefix, invite_code
 726                            ),
 727                            count: user.invite_count as u32,
 728                        },
 729                    )?;
 730                }
 731            }
 732        }
 733        Ok(())
 734    }
 735
 736    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 737        ServerSnapshot {
 738            connection_pool: ConnectionPoolGuard {
 739                guard: self.connection_pool.lock(),
 740                _not_send: PhantomData,
 741            },
 742            peer: &self.peer,
 743        }
 744    }
 745}
 746
 747impl<'a> Deref for ConnectionPoolGuard<'a> {
 748    type Target = ConnectionPool;
 749
 750    fn deref(&self) -> &Self::Target {
 751        &*self.guard
 752    }
 753}
 754
 755impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 756    fn deref_mut(&mut self) -> &mut Self::Target {
 757        &mut *self.guard
 758    }
 759}
 760
 761impl<'a> Drop for ConnectionPoolGuard<'a> {
 762    fn drop(&mut self) {
 763        #[cfg(test)]
 764        self.check_invariants();
 765    }
 766}
 767
 768fn broadcast<F>(
 769    sender_id: Option<ConnectionId>,
 770    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 771    mut f: F,
 772) where
 773    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 774{
 775    for receiver_id in receiver_ids {
 776        if Some(receiver_id) != sender_id {
 777            if let Err(error) = f(receiver_id) {
 778                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 779            }
 780        }
 781    }
 782}
 783
// Name of the request header that carries the client's RPC protocol version.
// It is decoded by `ProtocolVersion` and compared against `rpc::PROTOCOL_VERSION`
// in `handle_websocket_request`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 787
 788pub struct ProtocolVersion(u32);
 789
 790impl Header for ProtocolVersion {
 791    fn name() -> &'static HeaderName {
 792        &ZED_PROTOCOL_VERSION
 793    }
 794
 795    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 796    where
 797        Self: Sized,
 798        I: Iterator<Item = &'i axum::http::HeaderValue>,
 799    {
 800        let version = values
 801            .next()
 802            .ok_or_else(axum::headers::Error::invalid)?
 803            .to_str()
 804            .map_err(|_| axum::headers::Error::invalid())?
 805            .parse()
 806            .map_err(|_| axum::headers::Error::invalid())?;
 807        Ok(Self(version))
 808    }
 809
 810    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 811        values.extend([self.0.to_string().parse().unwrap()]);
 812    }
 813}
 814
 815pub fn routes(server: Arc<Server>) -> Router<Body> {
 816    Router::new()
 817        .route("/rpc", get(handle_websocket_request))
 818        .layer(
 819            ServiceBuilder::new()
 820                .layer(Extension(server.app_state.clone()))
 821                .layer(middleware::from_fn(auth::validate_header)),
 822        )
 823        .route("/metrics", get(handle_metrics))
 824        .layer(Extension(server))
 825}
 826
 827pub async fn handle_websocket_request(
 828    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 829    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 830    Extension(server): Extension<Arc<Server>>,
 831    Extension(user): Extension<User>,
 832    ws: WebSocketUpgrade,
 833) -> axum::response::Response {
 834    if protocol_version != rpc::PROTOCOL_VERSION {
 835        return (
 836            StatusCode::UPGRADE_REQUIRED,
 837            "client must be upgraded".to_string(),
 838        )
 839            .into_response();
 840    }
 841    let socket_address = socket_address.to_string();
 842    ws.on_upgrade(move |socket| {
 843        use util::ResultExt;
 844        let socket = socket
 845            .map_ok(to_tungstenite_message)
 846            .err_into()
 847            .with(|message| async move { Ok(to_axum_message(message)) });
 848        let connection = Connection::new(Box::pin(socket));
 849        async move {
 850            server
 851                .handle_connection(connection, socket_address, user, None, Executor::Production)
 852                .await
 853                .log_err();
 854        }
 855    })
 856}
 857
 858pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 859    let connections = server
 860        .connection_pool
 861        .lock()
 862        .connections()
 863        .filter(|connection| !connection.admin)
 864        .count();
 865
 866    METRIC_CONNECTIONS.set(connections as _);
 867
 868    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 869    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 870
 871    let encoder = prometheus::TextEncoder::new();
 872    let metric_families = prometheus::gather();
 873    let encoded_metrics = encoder
 874        .encode_to_string(&metric_families)
 875        .map_err(|err| anyhow!("{}", err))?;
 876    Ok(encoded_metrics)
 877}
 878
 879#[instrument(err, skip(executor))]
 880async fn connection_lost(
 881    session: Session,
 882    mut teardown: watch::Receiver<()>,
 883    executor: Executor,
 884) -> Result<()> {
 885    session.peer.disconnect(session.connection_id);
 886    session
 887        .connection_pool()
 888        .await
 889        .remove_connection(session.connection_id)?;
 890
 891    session
 892        .db()
 893        .await
 894        .connection_lost(session.connection_id)
 895        .await
 896        .trace_err();
 897
 898    futures::select_biased! {
 899        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 900            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 901            leave_room_for_session(&session).await.trace_err();
 902            leave_channel_buffers_for_session(&session)
 903                .await
 904                .trace_err();
 905
 906            if !session
 907                .connection_pool()
 908                .await
 909                .is_user_online(session.user_id)
 910            {
 911                let db = session.db().await;
 912                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 913                    room_updated(&room, &session.peer);
 914                }
 915            }
 916
 917            update_user_contacts(session.user_id, &session).await?;
 918        }
 919        _ = teardown.changed().fuse() => {}
 920    }
 921
 922    Ok(())
 923}
 924
 925async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 926    response.send(proto::Ack {})?;
 927    Ok(())
 928}
 929
 930async fn create_room(
 931    _request: proto::CreateRoom,
 932    response: Response<proto::CreateRoom>,
 933    session: Session,
 934) -> Result<()> {
 935    let live_kit_room = nanoid::nanoid!(30);
 936
 937    let live_kit_connection_info = {
 938        let live_kit_room = live_kit_room.clone();
 939        let live_kit = session.live_kit_client.as_ref();
 940
 941        util::async_iife!({
 942            let live_kit = live_kit?;
 943
 944            let token = live_kit
 945                .room_token(&live_kit_room, &session.user_id.to_string())
 946                .trace_err()?;
 947
 948            Some(proto::LiveKitConnectionInfo {
 949                server_url: live_kit.url().into(),
 950                token,
 951            })
 952        })
 953    }
 954    .await;
 955
 956    let room = session
 957        .db()
 958        .await
 959        .create_room(
 960            session.user_id,
 961            session.connection_id,
 962            &live_kit_room,
 963            RELEASE_CHANNEL_NAME.as_str(),
 964        )
 965        .await?;
 966
 967    response.send(proto::CreateRoomResponse {
 968        room: Some(room.clone()),
 969        live_kit_connection_info,
 970    })?;
 971
 972    update_user_contacts(session.user_id, &session).await?;
 973    Ok(())
 974}
 975
 976async fn join_room(
 977    request: proto::JoinRoom,
 978    response: Response<proto::JoinRoom>,
 979    session: Session,
 980) -> Result<()> {
 981    let room_id = RoomId::from_proto(request.id);
 982    let joined_room = {
 983        let room = session
 984            .db()
 985            .await
 986            .join_room(
 987                room_id,
 988                session.user_id,
 989                session.connection_id,
 990                RELEASE_CHANNEL_NAME.as_str(),
 991            )
 992            .await?;
 993        room_updated(&room.room, &session.peer);
 994        room.into_inner()
 995    };
 996
 997    if let Some(channel_id) = joined_room.channel_id {
 998        channel_updated(
 999            channel_id,
1000            &joined_room.room,
1001            &joined_room.channel_members,
1002            &session.peer,
1003            &*session.connection_pool().await,
1004        )
1005    }
1006
1007    for connection_id in session
1008        .connection_pool()
1009        .await
1010        .user_connection_ids(session.user_id)
1011    {
1012        session
1013            .peer
1014            .send(
1015                connection_id,
1016                proto::CallCanceled {
1017                    room_id: room_id.to_proto(),
1018                },
1019            )
1020            .trace_err();
1021    }
1022
1023    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1024        if let Some(token) = live_kit
1025            .room_token(
1026                &joined_room.room.live_kit_room,
1027                &session.user_id.to_string(),
1028            )
1029            .trace_err()
1030        {
1031            Some(proto::LiveKitConnectionInfo {
1032                server_url: live_kit.url().into(),
1033                token,
1034            })
1035        } else {
1036            None
1037        }
1038    } else {
1039        None
1040    };
1041
1042    response.send(proto::JoinRoomResponse {
1043        room: Some(joined_room.room),
1044        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1045        live_kit_connection_info,
1046    })?;
1047
1048    update_user_contacts(session.user_id, &session).await?;
1049    Ok(())
1050}
1051
1052async fn rejoin_room(
1053    request: proto::RejoinRoom,
1054    response: Response<proto::RejoinRoom>,
1055    session: Session,
1056) -> Result<()> {
1057    let room;
1058    let channel_id;
1059    let channel_members;
1060    {
1061        let mut rejoined_room = session
1062            .db()
1063            .await
1064            .rejoin_room(request, session.user_id, session.connection_id)
1065            .await?;
1066
1067        response.send(proto::RejoinRoomResponse {
1068            room: Some(rejoined_room.room.clone()),
1069            reshared_projects: rejoined_room
1070                .reshared_projects
1071                .iter()
1072                .map(|project| proto::ResharedProject {
1073                    id: project.id.to_proto(),
1074                    collaborators: project
1075                        .collaborators
1076                        .iter()
1077                        .map(|collaborator| collaborator.to_proto())
1078                        .collect(),
1079                })
1080                .collect(),
1081            rejoined_projects: rejoined_room
1082                .rejoined_projects
1083                .iter()
1084                .map(|rejoined_project| proto::RejoinedProject {
1085                    id: rejoined_project.id.to_proto(),
1086                    worktrees: rejoined_project
1087                        .worktrees
1088                        .iter()
1089                        .map(|worktree| proto::WorktreeMetadata {
1090                            id: worktree.id,
1091                            root_name: worktree.root_name.clone(),
1092                            visible: worktree.visible,
1093                            abs_path: worktree.abs_path.clone(),
1094                        })
1095                        .collect(),
1096                    collaborators: rejoined_project
1097                        .collaborators
1098                        .iter()
1099                        .map(|collaborator| collaborator.to_proto())
1100                        .collect(),
1101                    language_servers: rejoined_project.language_servers.clone(),
1102                })
1103                .collect(),
1104        })?;
1105        room_updated(&rejoined_room.room, &session.peer);
1106
1107        for project in &rejoined_room.reshared_projects {
1108            for collaborator in &project.collaborators {
1109                session
1110                    .peer
1111                    .send(
1112                        collaborator.connection_id,
1113                        proto::UpdateProjectCollaborator {
1114                            project_id: project.id.to_proto(),
1115                            old_peer_id: Some(project.old_connection_id.into()),
1116                            new_peer_id: Some(session.connection_id.into()),
1117                        },
1118                    )
1119                    .trace_err();
1120            }
1121
1122            broadcast(
1123                Some(session.connection_id),
1124                project
1125                    .collaborators
1126                    .iter()
1127                    .map(|collaborator| collaborator.connection_id),
1128                |connection_id| {
1129                    session.peer.forward_send(
1130                        session.connection_id,
1131                        connection_id,
1132                        proto::UpdateProject {
1133                            project_id: project.id.to_proto(),
1134                            worktrees: project.worktrees.clone(),
1135                        },
1136                    )
1137                },
1138            );
1139        }
1140
1141        for project in &rejoined_room.rejoined_projects {
1142            for collaborator in &project.collaborators {
1143                session
1144                    .peer
1145                    .send(
1146                        collaborator.connection_id,
1147                        proto::UpdateProjectCollaborator {
1148                            project_id: project.id.to_proto(),
1149                            old_peer_id: Some(project.old_connection_id.into()),
1150                            new_peer_id: Some(session.connection_id.into()),
1151                        },
1152                    )
1153                    .trace_err();
1154            }
1155        }
1156
1157        for project in &mut rejoined_room.rejoined_projects {
1158            for worktree in mem::take(&mut project.worktrees) {
1159                #[cfg(any(test, feature = "test-support"))]
1160                const MAX_CHUNK_SIZE: usize = 2;
1161                #[cfg(not(any(test, feature = "test-support")))]
1162                const MAX_CHUNK_SIZE: usize = 256;
1163
1164                // Stream this worktree's entries.
1165                let message = proto::UpdateWorktree {
1166                    project_id: project.id.to_proto(),
1167                    worktree_id: worktree.id,
1168                    abs_path: worktree.abs_path.clone(),
1169                    root_name: worktree.root_name,
1170                    updated_entries: worktree.updated_entries,
1171                    removed_entries: worktree.removed_entries,
1172                    scan_id: worktree.scan_id,
1173                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1174                    updated_repositories: worktree.updated_repositories,
1175                    removed_repositories: worktree.removed_repositories,
1176                };
1177                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1178                    session.peer.send(session.connection_id, update.clone())?;
1179                }
1180
1181                // Stream this worktree's diagnostics.
1182                for summary in worktree.diagnostic_summaries {
1183                    session.peer.send(
1184                        session.connection_id,
1185                        proto::UpdateDiagnosticSummary {
1186                            project_id: project.id.to_proto(),
1187                            worktree_id: worktree.id,
1188                            summary: Some(summary),
1189                        },
1190                    )?;
1191                }
1192
1193                for settings_file in worktree.settings_files {
1194                    session.peer.send(
1195                        session.connection_id,
1196                        proto::UpdateWorktreeSettings {
1197                            project_id: project.id.to_proto(),
1198                            worktree_id: worktree.id,
1199                            path: settings_file.path,
1200                            content: Some(settings_file.content),
1201                        },
1202                    )?;
1203                }
1204            }
1205
1206            for language_server in &project.language_servers {
1207                session.peer.send(
1208                    session.connection_id,
1209                    proto::UpdateLanguageServer {
1210                        project_id: project.id.to_proto(),
1211                        language_server_id: language_server.id,
1212                        variant: Some(
1213                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1214                                proto::LspDiskBasedDiagnosticsUpdated {},
1215                            ),
1216                        ),
1217                    },
1218                )?;
1219            }
1220        }
1221
1222        let rejoined_room = rejoined_room.into_inner();
1223
1224        room = rejoined_room.room;
1225        channel_id = rejoined_room.channel_id;
1226        channel_members = rejoined_room.channel_members;
1227    }
1228
1229    if let Some(channel_id) = channel_id {
1230        channel_updated(
1231            channel_id,
1232            &room,
1233            &channel_members,
1234            &session.peer,
1235            &*session.connection_pool().await,
1236        );
1237    }
1238
1239    update_user_contacts(session.user_id, &session).await?;
1240    Ok(())
1241}
1242
1243async fn leave_room(
1244    _: proto::LeaveRoom,
1245    response: Response<proto::LeaveRoom>,
1246    session: Session,
1247) -> Result<()> {
1248    leave_room_for_session(&session).await?;
1249    response.send(proto::Ack {})?;
1250    Ok(())
1251}
1252
1253async fn call(
1254    request: proto::Call,
1255    response: Response<proto::Call>,
1256    session: Session,
1257) -> Result<()> {
1258    let room_id = RoomId::from_proto(request.room_id);
1259    let calling_user_id = session.user_id;
1260    let calling_connection_id = session.connection_id;
1261    let called_user_id = UserId::from_proto(request.called_user_id);
1262    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1263    if !session
1264        .db()
1265        .await
1266        .has_contact(calling_user_id, called_user_id)
1267        .await?
1268    {
1269        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1270    }
1271
1272    let incoming_call = {
1273        let (room, incoming_call) = &mut *session
1274            .db()
1275            .await
1276            .call(
1277                room_id,
1278                calling_user_id,
1279                calling_connection_id,
1280                called_user_id,
1281                initial_project_id,
1282            )
1283            .await?;
1284        room_updated(&room, &session.peer);
1285        mem::take(incoming_call)
1286    };
1287    update_user_contacts(called_user_id, &session).await?;
1288
1289    let mut calls = session
1290        .connection_pool()
1291        .await
1292        .user_connection_ids(called_user_id)
1293        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1294        .collect::<FuturesUnordered<_>>();
1295
1296    while let Some(call_response) = calls.next().await {
1297        match call_response.as_ref() {
1298            Ok(_) => {
1299                response.send(proto::Ack {})?;
1300                return Ok(());
1301            }
1302            Err(_) => {
1303                call_response.trace_err();
1304            }
1305        }
1306    }
1307
1308    {
1309        let room = session
1310            .db()
1311            .await
1312            .call_failed(room_id, called_user_id)
1313            .await?;
1314        room_updated(&room, &session.peer);
1315    }
1316    update_user_contacts(called_user_id, &session).await?;
1317
1318    Err(anyhow!("failed to ring user"))?
1319}
1320
1321async fn cancel_call(
1322    request: proto::CancelCall,
1323    response: Response<proto::CancelCall>,
1324    session: Session,
1325) -> Result<()> {
1326    let called_user_id = UserId::from_proto(request.called_user_id);
1327    let room_id = RoomId::from_proto(request.room_id);
1328    {
1329        let room = session
1330            .db()
1331            .await
1332            .cancel_call(room_id, session.connection_id, called_user_id)
1333            .await?;
1334        room_updated(&room, &session.peer);
1335    }
1336
1337    for connection_id in session
1338        .connection_pool()
1339        .await
1340        .user_connection_ids(called_user_id)
1341    {
1342        session
1343            .peer
1344            .send(
1345                connection_id,
1346                proto::CallCanceled {
1347                    room_id: room_id.to_proto(),
1348                },
1349            )
1350            .trace_err();
1351    }
1352    response.send(proto::Ack {})?;
1353
1354    update_user_contacts(called_user_id, &session).await?;
1355    Ok(())
1356}
1357
1358async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1359    let room_id = RoomId::from_proto(message.room_id);
1360    {
1361        let room = session
1362            .db()
1363            .await
1364            .decline_call(Some(room_id), session.user_id)
1365            .await?
1366            .ok_or_else(|| anyhow!("failed to decline call"))?;
1367        room_updated(&room, &session.peer);
1368    }
1369
1370    for connection_id in session
1371        .connection_pool()
1372        .await
1373        .user_connection_ids(session.user_id)
1374    {
1375        session
1376            .peer
1377            .send(
1378                connection_id,
1379                proto::CallCanceled {
1380                    room_id: room_id.to_proto(),
1381                },
1382            )
1383            .trace_err();
1384    }
1385    update_user_contacts(session.user_id, &session).await?;
1386    Ok(())
1387}
1388
1389async fn update_participant_location(
1390    request: proto::UpdateParticipantLocation,
1391    response: Response<proto::UpdateParticipantLocation>,
1392    session: Session,
1393) -> Result<()> {
1394    let room_id = RoomId::from_proto(request.room_id);
1395    let location = request
1396        .location
1397        .ok_or_else(|| anyhow!("invalid location"))?;
1398
1399    let db = session.db().await;
1400    let room = db
1401        .update_room_participant_location(room_id, session.connection_id, location)
1402        .await?;
1403
1404    room_updated(&room, &session.peer);
1405    response.send(proto::Ack {})?;
1406    Ok(())
1407}
1408
1409async fn share_project(
1410    request: proto::ShareProject,
1411    response: Response<proto::ShareProject>,
1412    session: Session,
1413) -> Result<()> {
1414    let (project_id, room) = &*session
1415        .db()
1416        .await
1417        .share_project(
1418            RoomId::from_proto(request.room_id),
1419            session.connection_id,
1420            &request.worktrees,
1421        )
1422        .await?;
1423    response.send(proto::ShareProjectResponse {
1424        project_id: project_id.to_proto(),
1425    })?;
1426    room_updated(&room, &session.peer);
1427
1428    Ok(())
1429}
1430
1431async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1432    let project_id = ProjectId::from_proto(message.project_id);
1433
1434    let (room, guest_connection_ids) = &*session
1435        .db()
1436        .await
1437        .unshare_project(project_id, session.connection_id)
1438        .await?;
1439
1440    broadcast(
1441        Some(session.connection_id),
1442        guest_connection_ids.iter().copied(),
1443        |conn_id| session.peer.send(conn_id, message.clone()),
1444    );
1445    room_updated(&room, &session.peer);
1446
1447    Ok(())
1448}
1449
1450async fn join_project(
1451    request: proto::JoinProject,
1452    response: Response<proto::JoinProject>,
1453    session: Session,
1454) -> Result<()> {
1455    let project_id = ProjectId::from_proto(request.project_id);
1456    let guest_user_id = session.user_id;
1457
1458    tracing::info!(%project_id, "join project");
1459
1460    let (project, replica_id) = &mut *session
1461        .db()
1462        .await
1463        .join_project(project_id, session.connection_id)
1464        .await?;
1465
1466    let collaborators = project
1467        .collaborators
1468        .iter()
1469        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1470        .map(|collaborator| collaborator.to_proto())
1471        .collect::<Vec<_>>();
1472
1473    let worktrees = project
1474        .worktrees
1475        .iter()
1476        .map(|(id, worktree)| proto::WorktreeMetadata {
1477            id: *id,
1478            root_name: worktree.root_name.clone(),
1479            visible: worktree.visible,
1480            abs_path: worktree.abs_path.clone(),
1481        })
1482        .collect::<Vec<_>>();
1483
1484    for collaborator in &collaborators {
1485        session
1486            .peer
1487            .send(
1488                collaborator.peer_id.unwrap().into(),
1489                proto::AddProjectCollaborator {
1490                    project_id: project_id.to_proto(),
1491                    collaborator: Some(proto::Collaborator {
1492                        peer_id: Some(session.connection_id.into()),
1493                        replica_id: replica_id.0 as u32,
1494                        user_id: guest_user_id.to_proto(),
1495                    }),
1496                },
1497            )
1498            .trace_err();
1499    }
1500
1501    // First, we send the metadata associated with each worktree.
1502    response.send(proto::JoinProjectResponse {
1503        worktrees: worktrees.clone(),
1504        replica_id: replica_id.0 as u32,
1505        collaborators: collaborators.clone(),
1506        language_servers: project.language_servers.clone(),
1507    })?;
1508
1509    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1510        #[cfg(any(test, feature = "test-support"))]
1511        const MAX_CHUNK_SIZE: usize = 2;
1512        #[cfg(not(any(test, feature = "test-support")))]
1513        const MAX_CHUNK_SIZE: usize = 256;
1514
1515        // Stream this worktree's entries.
1516        let message = proto::UpdateWorktree {
1517            project_id: project_id.to_proto(),
1518            worktree_id,
1519            abs_path: worktree.abs_path.clone(),
1520            root_name: worktree.root_name,
1521            updated_entries: worktree.entries,
1522            removed_entries: Default::default(),
1523            scan_id: worktree.scan_id,
1524            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1525            updated_repositories: worktree.repository_entries.into_values().collect(),
1526            removed_repositories: Default::default(),
1527        };
1528        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1529            session.peer.send(session.connection_id, update.clone())?;
1530        }
1531
1532        // Stream this worktree's diagnostics.
1533        for summary in worktree.diagnostic_summaries {
1534            session.peer.send(
1535                session.connection_id,
1536                proto::UpdateDiagnosticSummary {
1537                    project_id: project_id.to_proto(),
1538                    worktree_id: worktree.id,
1539                    summary: Some(summary),
1540                },
1541            )?;
1542        }
1543
1544        for settings_file in worktree.settings_files {
1545            session.peer.send(
1546                session.connection_id,
1547                proto::UpdateWorktreeSettings {
1548                    project_id: project_id.to_proto(),
1549                    worktree_id: worktree.id,
1550                    path: settings_file.path,
1551                    content: Some(settings_file.content),
1552                },
1553            )?;
1554        }
1555    }
1556
1557    for language_server in &project.language_servers {
1558        session.peer.send(
1559            session.connection_id,
1560            proto::UpdateLanguageServer {
1561                project_id: project_id.to_proto(),
1562                language_server_id: language_server.id,
1563                variant: Some(
1564                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1565                        proto::LspDiskBasedDiagnosticsUpdated {},
1566                    ),
1567                ),
1568            },
1569        )?;
1570    }
1571
1572    Ok(())
1573}
1574
/// Removes the sender's connection from a shared project.
///
/// After the database update, `project_left` is invoked with the project returned by the
/// database, and the updated room is re-broadcast through `room_updated`.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);

    // `&*` borrows through the value returned by `leave_project`; temporary lifetime
    // extension keeps that value alive until the end of this function, so the calls
    // below run while it is still held (NOTE(review): presumably a room guard — confirm).
    let (room, project) = &*session
        .db()
        .await
        .leave_project(project_id, sender_id)
        .await?;
    tracing::info!(
        %project_id,
        host_user_id = %project.host_user_id,
        host_connection_id = %project.host_connection_id,
        "leave project"
    );

    project_left(&project, &session);
    room_updated(&room, &session.peer);

    Ok(())
}
1596
/// Persists an `UpdateProject` message (its `worktrees`) and relays it to the project's guests.
///
/// The message is forwarded to the guest connections returned by the database, the room
/// is re-broadcast, and only then is the sender acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // `&*` borrows through the value returned by `update_project`; temporary lifetime
    // extension keeps that value alive until the end of this function, so the broadcast,
    // room update and ack all happen while it is held (NOTE(review): presumably a room
    // guard — confirm).
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;

    Ok(())
}
1622
1623async fn update_worktree(
1624    request: proto::UpdateWorktree,
1625    response: Response<proto::UpdateWorktree>,
1626    session: Session,
1627) -> Result<()> {
1628    let guest_connection_ids = session
1629        .db()
1630        .await
1631        .update_worktree(&request, session.connection_id)
1632        .await?;
1633
1634    broadcast(
1635        Some(session.connection_id),
1636        guest_connection_ids.iter().copied(),
1637        |connection_id| {
1638            session
1639                .peer
1640                .forward_send(session.connection_id, connection_id, request.clone())
1641        },
1642    );
1643    response.send(proto::Ack {})?;
1644    Ok(())
1645}
1646
1647async fn update_diagnostic_summary(
1648    message: proto::UpdateDiagnosticSummary,
1649    session: Session,
1650) -> Result<()> {
1651    let guest_connection_ids = session
1652        .db()
1653        .await
1654        .update_diagnostic_summary(&message, session.connection_id)
1655        .await?;
1656
1657    broadcast(
1658        Some(session.connection_id),
1659        guest_connection_ids.iter().copied(),
1660        |connection_id| {
1661            session
1662                .peer
1663                .forward_send(session.connection_id, connection_id, message.clone())
1664        },
1665    );
1666
1667    Ok(())
1668}
1669
1670async fn update_worktree_settings(
1671    message: proto::UpdateWorktreeSettings,
1672    session: Session,
1673) -> Result<()> {
1674    let guest_connection_ids = session
1675        .db()
1676        .await
1677        .update_worktree_settings(&message, session.connection_id)
1678        .await?;
1679
1680    broadcast(
1681        Some(session.connection_id),
1682        guest_connection_ids.iter().copied(),
1683        |connection_id| {
1684            session
1685                .peer
1686                .forward_send(session.connection_id, connection_id, message.clone())
1687        },
1688    );
1689
1690    Ok(())
1691}
1692
1693async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1694    broadcast_project_message(request.project_id, request, session).await
1695}
1696
1697async fn start_language_server(
1698    request: proto::StartLanguageServer,
1699    session: Session,
1700) -> Result<()> {
1701    let guest_connection_ids = session
1702        .db()
1703        .await
1704        .start_language_server(&request, session.connection_id)
1705        .await?;
1706
1707    broadcast(
1708        Some(session.connection_id),
1709        guest_connection_ids.iter().copied(),
1710        |connection_id| {
1711            session
1712                .peer
1713                .forward_send(session.connection_id, connection_id, request.clone())
1714        },
1715    );
1716    Ok(())
1717}
1718
1719async fn update_language_server(
1720    request: proto::UpdateLanguageServer,
1721    session: Session,
1722) -> Result<()> {
1723    session.executor.record_backtrace();
1724    let project_id = ProjectId::from_proto(request.project_id);
1725    let project_connection_ids = session
1726        .db()
1727        .await
1728        .project_connection_ids(project_id, session.connection_id)
1729        .await?;
1730    broadcast(
1731        Some(session.connection_id),
1732        project_connection_ids.iter().copied(),
1733        |connection_id| {
1734            session
1735                .peer
1736                .forward_send(session.connection_id, connection_id, request.clone())
1737        },
1738    );
1739    Ok(())
1740}
1741
1742async fn forward_project_request<T>(
1743    request: T,
1744    response: Response<T>,
1745    session: Session,
1746) -> Result<()>
1747where
1748    T: EntityMessage + RequestMessage,
1749{
1750    session.executor.record_backtrace();
1751    let project_id = ProjectId::from_proto(request.remote_entity_id());
1752    let host_connection_id = {
1753        let collaborators = session
1754            .db()
1755            .await
1756            .project_collaborators(project_id, session.connection_id)
1757            .await?;
1758        collaborators
1759            .iter()
1760            .find(|collaborator| collaborator.is_host)
1761            .ok_or_else(|| anyhow!("host not found"))?
1762            .connection_id
1763    };
1764
1765    let payload = session
1766        .peer
1767        .forward_request(session.connection_id, host_connection_id, request)
1768        .await?;
1769
1770    response.send(payload)?;
1771    Ok(())
1772}
1773
1774async fn create_buffer_for_peer(
1775    request: proto::CreateBufferForPeer,
1776    session: Session,
1777) -> Result<()> {
1778    session.executor.record_backtrace();
1779    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1780    session
1781        .peer
1782        .forward_send(session.connection_id, peer_id.into(), request)?;
1783    Ok(())
1784}
1785
1786async fn update_buffer(
1787    request: proto::UpdateBuffer,
1788    response: Response<proto::UpdateBuffer>,
1789    session: Session,
1790) -> Result<()> {
1791    session.executor.record_backtrace();
1792    let project_id = ProjectId::from_proto(request.project_id);
1793    let mut guest_connection_ids;
1794    let mut host_connection_id = None;
1795    {
1796        let collaborators = session
1797            .db()
1798            .await
1799            .project_collaborators(project_id, session.connection_id)
1800            .await?;
1801        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1802        for collaborator in collaborators.iter() {
1803            if collaborator.is_host {
1804                host_connection_id = Some(collaborator.connection_id);
1805            } else {
1806                guest_connection_ids.push(collaborator.connection_id);
1807            }
1808        }
1809    }
1810    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1811
1812    session.executor.record_backtrace();
1813    broadcast(
1814        Some(session.connection_id),
1815        guest_connection_ids,
1816        |connection_id| {
1817            session
1818                .peer
1819                .forward_send(session.connection_id, connection_id, request.clone())
1820        },
1821    );
1822    if host_connection_id != session.connection_id {
1823        session
1824            .peer
1825            .forward_request(session.connection_id, host_connection_id, request.clone())
1826            .await?;
1827    }
1828
1829    response.send(proto::Ack {})?;
1830    Ok(())
1831}
1832
1833async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1834    let project_id = ProjectId::from_proto(request.project_id);
1835    let project_connection_ids = session
1836        .db()
1837        .await
1838        .project_connection_ids(project_id, session.connection_id)
1839        .await?;
1840
1841    broadcast(
1842        Some(session.connection_id),
1843        project_connection_ids.iter().copied(),
1844        |connection_id| {
1845            session
1846                .peer
1847                .forward_send(session.connection_id, connection_id, request.clone())
1848        },
1849    );
1850    Ok(())
1851}
1852
1853async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1854    let project_id = ProjectId::from_proto(request.project_id);
1855    let project_connection_ids = session
1856        .db()
1857        .await
1858        .project_connection_ids(project_id, session.connection_id)
1859        .await?;
1860    broadcast(
1861        Some(session.connection_id),
1862        project_connection_ids.iter().copied(),
1863        |connection_id| {
1864            session
1865                .peer
1866                .forward_send(session.connection_id, connection_id, request.clone())
1867        },
1868    );
1869    Ok(())
1870}
1871
1872async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1873    broadcast_project_message(request.project_id, request, session).await
1874}
1875
1876async fn broadcast_project_message<T: EnvelopedMessage>(
1877    project_id: u64,
1878    request: T,
1879    session: Session,
1880) -> Result<()> {
1881    let project_id = ProjectId::from_proto(project_id);
1882    let project_connection_ids = session
1883        .db()
1884        .await
1885        .project_connection_ids(project_id, session.connection_id)
1886        .await?;
1887    broadcast(
1888        Some(session.connection_id),
1889        project_connection_ids.iter().copied(),
1890        |connection_id| {
1891            session
1892                .peer
1893                .forward_send(session.connection_id, connection_id, request.clone())
1894        },
1895    );
1896    Ok(())
1897}
1898
1899async fn follow(
1900    request: proto::Follow,
1901    response: Response<proto::Follow>,
1902    session: Session,
1903) -> Result<()> {
1904    let room_id = RoomId::from_proto(request.room_id);
1905    let project_id = request.project_id.map(ProjectId::from_proto);
1906    let leader_id = request
1907        .leader_id
1908        .ok_or_else(|| anyhow!("invalid leader id"))?
1909        .into();
1910    let follower_id = session.connection_id;
1911
1912    session
1913        .db()
1914        .await
1915        .check_room_participants(room_id, leader_id, session.connection_id)
1916        .await?;
1917
1918    let response_payload = session
1919        .peer
1920        .forward_request(session.connection_id, leader_id, request)
1921        .await?;
1922    response.send(response_payload)?;
1923
1924    if let Some(project_id) = project_id {
1925        let room = session
1926            .db()
1927            .await
1928            .follow(room_id, project_id, leader_id, follower_id)
1929            .await?;
1930        room_updated(&room, &session.peer);
1931    }
1932
1933    Ok(())
1934}
1935
1936async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1937    let room_id = RoomId::from_proto(request.room_id);
1938    let project_id = request.project_id.map(ProjectId::from_proto);
1939    let leader_id = request
1940        .leader_id
1941        .ok_or_else(|| anyhow!("invalid leader id"))?
1942        .into();
1943    let follower_id = session.connection_id;
1944
1945    session
1946        .db()
1947        .await
1948        .check_room_participants(room_id, leader_id, session.connection_id)
1949        .await?;
1950
1951    session
1952        .peer
1953        .forward_send(session.connection_id, leader_id, request)?;
1954
1955    if let Some(project_id) = project_id {
1956        let room = session
1957            .db()
1958            .await
1959            .unfollow(room_id, project_id, leader_id, follower_id)
1960            .await?;
1961        room_updated(&room, &session.peer);
1962    }
1963
1964    Ok(())
1965}
1966
1967async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1968    let room_id = RoomId::from_proto(request.room_id);
1969    let database = session.db.lock().await;
1970
1971    let connection_ids = if let Some(project_id) = request.project_id {
1972        let project_id = ProjectId::from_proto(project_id);
1973        database
1974            .project_connection_ids(project_id, session.connection_id)
1975            .await?
1976    } else {
1977        database
1978            .room_connection_ids(room_id, session.connection_id)
1979            .await?
1980    };
1981
1982    // For now, don't send view update messages back to that view's current leader.
1983    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1984        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1985        _ => None,
1986    });
1987
1988    for follower_peer_id in request.follower_ids.iter().copied() {
1989        let follower_connection_id = follower_peer_id.into();
1990        if Some(follower_peer_id) != connection_id_to_omit
1991            && connection_ids.contains(&follower_connection_id)
1992        {
1993            session.peer.forward_send(
1994                session.connection_id,
1995                follower_connection_id,
1996                request.clone(),
1997            )?;
1998        }
1999    }
2000    Ok(())
2001}
2002
2003async fn get_users(
2004    request: proto::GetUsers,
2005    response: Response<proto::GetUsers>,
2006    session: Session,
2007) -> Result<()> {
2008    let user_ids = request
2009        .user_ids
2010        .into_iter()
2011        .map(UserId::from_proto)
2012        .collect();
2013    let users = session
2014        .db()
2015        .await
2016        .get_users_by_ids(user_ids)
2017        .await?
2018        .into_iter()
2019        .map(|user| proto::User {
2020            id: user.id.to_proto(),
2021            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2022            github_login: user.github_login,
2023        })
2024        .collect();
2025    response.send(proto::UsersResponse { users })?;
2026    Ok(())
2027}
2028
2029async fn fuzzy_search_users(
2030    request: proto::FuzzySearchUsers,
2031    response: Response<proto::FuzzySearchUsers>,
2032    session: Session,
2033) -> Result<()> {
2034    let query = request.query;
2035    let users = match query.len() {
2036        0 => vec![],
2037        1 | 2 => session
2038            .db()
2039            .await
2040            .get_user_by_github_login(&query)
2041            .await?
2042            .into_iter()
2043            .collect(),
2044        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2045    };
2046    let users = users
2047        .into_iter()
2048        .filter(|user| user.id != session.user_id)
2049        .map(|user| proto::User {
2050            id: user.id.to_proto(),
2051            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2052            github_login: user.github_login,
2053        })
2054        .collect();
2055    response.send(proto::UsersResponse { users })?;
2056    Ok(())
2057}
2058
2059async fn request_contact(
2060    request: proto::RequestContact,
2061    response: Response<proto::RequestContact>,
2062    session: Session,
2063) -> Result<()> {
2064    let requester_id = session.user_id;
2065    let responder_id = UserId::from_proto(request.responder_id);
2066    if requester_id == responder_id {
2067        return Err(anyhow!("cannot add yourself as a contact"))?;
2068    }
2069
2070    let notification = session
2071        .db()
2072        .await
2073        .send_contact_request(requester_id, responder_id)
2074        .await?;
2075
2076    // Update outgoing contact requests of requester
2077    let mut update = proto::UpdateContacts::default();
2078    update.outgoing_requests.push(responder_id.to_proto());
2079    for connection_id in session
2080        .connection_pool()
2081        .await
2082        .user_connection_ids(requester_id)
2083    {
2084        session.peer.send(connection_id, update.clone())?;
2085    }
2086
2087    // Update incoming contact requests of responder
2088    let mut update = proto::UpdateContacts::default();
2089    update
2090        .incoming_requests
2091        .push(proto::IncomingContactRequest {
2092            requester_id: requester_id.to_proto(),
2093        });
2094    for connection_id in session
2095        .connection_pool()
2096        .await
2097        .user_connection_ids(responder_id)
2098    {
2099        session.peer.send(connection_id, update.clone())?;
2100        if let Some(notification) = &notification {
2101            session.peer.send(
2102                connection_id,
2103                proto::NewNotification {
2104                    notification: Some(notification.clone()),
2105                },
2106            )?;
2107        }
2108    }
2109
2110    response.send(proto::Ack {})?;
2111    Ok(())
2112}
2113
2114async fn respond_to_contact_request(
2115    request: proto::RespondToContactRequest,
2116    response: Response<proto::RespondToContactRequest>,
2117    session: Session,
2118) -> Result<()> {
2119    let responder_id = session.user_id;
2120    let requester_id = UserId::from_proto(request.requester_id);
2121    let db = session.db().await;
2122    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2123        db.dismiss_contact_notification(responder_id, requester_id)
2124            .await?;
2125    } else {
2126        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2127
2128        let notification = db
2129            .respond_to_contact_request(responder_id, requester_id, accept)
2130            .await?;
2131        let requester_busy = db.is_user_busy(requester_id).await?;
2132        let responder_busy = db.is_user_busy(responder_id).await?;
2133
2134        let pool = session.connection_pool().await;
2135        // Update responder with new contact
2136        let mut update = proto::UpdateContacts::default();
2137        if accept {
2138            update
2139                .contacts
2140                .push(contact_for_user(requester_id, requester_busy, &pool));
2141        }
2142        update
2143            .remove_incoming_requests
2144            .push(requester_id.to_proto());
2145        for connection_id in pool.user_connection_ids(responder_id) {
2146            session.peer.send(connection_id, update.clone())?;
2147        }
2148
2149        // Update requester with new contact
2150        let mut update = proto::UpdateContacts::default();
2151        if accept {
2152            update
2153                .contacts
2154                .push(contact_for_user(responder_id, responder_busy, &pool));
2155        }
2156        update
2157            .remove_outgoing_requests
2158            .push(responder_id.to_proto());
2159        for connection_id in pool.user_connection_ids(requester_id) {
2160            session.peer.send(connection_id, update.clone())?;
2161            if let Some(notification) = &notification {
2162                session.peer.send(
2163                    connection_id,
2164                    proto::NewNotification {
2165                        notification: Some(notification.clone()),
2166                    },
2167                )?;
2168            }
2169        }
2170    }
2171
2172    response.send(proto::Ack {})?;
2173    Ok(())
2174}
2175
2176async fn remove_contact(
2177    request: proto::RemoveContact,
2178    response: Response<proto::RemoveContact>,
2179    session: Session,
2180) -> Result<()> {
2181    let requester_id = session.user_id;
2182    let responder_id = UserId::from_proto(request.user_id);
2183    let db = session.db().await;
2184    let (contact_accepted, deleted_notification_id) =
2185        db.remove_contact(requester_id, responder_id).await?;
2186
2187    let pool = session.connection_pool().await;
2188    // Update outgoing contact requests of requester
2189    let mut update = proto::UpdateContacts::default();
2190    if contact_accepted {
2191        update.remove_contacts.push(responder_id.to_proto());
2192    } else {
2193        update
2194            .remove_outgoing_requests
2195            .push(responder_id.to_proto());
2196    }
2197    for connection_id in pool.user_connection_ids(requester_id) {
2198        session.peer.send(connection_id, update.clone())?;
2199    }
2200
2201    // Update incoming contact requests of responder
2202    let mut update = proto::UpdateContacts::default();
2203    if contact_accepted {
2204        update.remove_contacts.push(requester_id.to_proto());
2205    } else {
2206        update
2207            .remove_incoming_requests
2208            .push(requester_id.to_proto());
2209    }
2210    for connection_id in pool.user_connection_ids(responder_id) {
2211        session.peer.send(connection_id, update.clone())?;
2212        if let Some(notification_id) = deleted_notification_id {
2213            session.peer.send(
2214                connection_id,
2215                proto::DeleteNotification {
2216                    notification_id: notification_id.to_proto(),
2217                },
2218            )?;
2219        }
2220    }
2221
2222    response.send(proto::Ack {})?;
2223    Ok(())
2224}
2225
2226async fn create_channel(
2227    request: proto::CreateChannel,
2228    response: Response<proto::CreateChannel>,
2229    session: Session,
2230) -> Result<()> {
2231    let db = session.db().await;
2232
2233    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2234    let id = db
2235        .create_channel(&request.name, parent_id, session.user_id)
2236        .await?;
2237
2238    let channel = proto::Channel {
2239        id: id.to_proto(),
2240        name: request.name,
2241    };
2242
2243    response.send(proto::CreateChannelResponse {
2244        channel: Some(channel.clone()),
2245        parent_id: request.parent_id,
2246    })?;
2247
2248    let Some(parent_id) = parent_id else {
2249        return Ok(());
2250    };
2251
2252    let update = proto::UpdateChannels {
2253        channels: vec![channel],
2254        insert_edge: vec![ChannelEdge {
2255            parent_id: parent_id.to_proto(),
2256            channel_id: id.to_proto(),
2257        }],
2258        ..Default::default()
2259    };
2260
2261    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2262
2263    let connection_pool = session.connection_pool().await;
2264    for user_id in user_ids_to_notify {
2265        for connection_id in connection_pool.user_connection_ids(user_id) {
2266            if user_id == session.user_id {
2267                continue;
2268            }
2269            session.peer.send(connection_id, update.clone())?;
2270        }
2271    }
2272
2273    Ok(())
2274}
2275
2276async fn delete_channel(
2277    request: proto::DeleteChannel,
2278    response: Response<proto::DeleteChannel>,
2279    session: Session,
2280) -> Result<()> {
2281    let db = session.db().await;
2282
2283    let channel_id = request.channel_id;
2284    let (removed_channels, member_ids) = db
2285        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2286        .await?;
2287    response.send(proto::Ack {})?;
2288
2289    // Notify members of removed channels
2290    let mut update = proto::UpdateChannels::default();
2291    update
2292        .delete_channels
2293        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2294
2295    let connection_pool = session.connection_pool().await;
2296    for member_id in member_ids {
2297        for connection_id in connection_pool.user_connection_ids(member_id) {
2298            session.peer.send(connection_id, update.clone())?;
2299        }
2300    }
2301
2302    Ok(())
2303}
2304
2305async fn invite_channel_member(
2306    request: proto::InviteChannelMember,
2307    response: Response<proto::InviteChannelMember>,
2308    session: Session,
2309) -> Result<()> {
2310    let db = session.db().await;
2311    let channel_id = ChannelId::from_proto(request.channel_id);
2312    let invitee_id = UserId::from_proto(request.user_id);
2313    let notification = db
2314        .invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2315        .await?;
2316
2317    let (channel, _) = db
2318        .get_channel(channel_id, session.user_id)
2319        .await?
2320        .ok_or_else(|| anyhow!("channel not found"))?;
2321
2322    let mut update = proto::UpdateChannels::default();
2323    update.channel_invitations.push(proto::Channel {
2324        id: channel.id.to_proto(),
2325        name: channel.name,
2326    });
2327
2328    for connection_id in session
2329        .connection_pool()
2330        .await
2331        .user_connection_ids(invitee_id)
2332    {
2333        session.peer.send(connection_id, update.clone())?;
2334        if let Some(notification) = &notification {
2335            session.peer.send(
2336                connection_id,
2337                proto::NewNotification {
2338                    notification: Some(notification.clone()),
2339                },
2340            )?;
2341        }
2342    }
2343
2344    response.send(proto::Ack {})?;
2345    Ok(())
2346}
2347
2348async fn remove_channel_member(
2349    request: proto::RemoveChannelMember,
2350    response: Response<proto::RemoveChannelMember>,
2351    session: Session,
2352) -> Result<()> {
2353    let db = session.db().await;
2354    let channel_id = ChannelId::from_proto(request.channel_id);
2355    let member_id = UserId::from_proto(request.user_id);
2356
2357    db.remove_channel_member(channel_id, member_id, session.user_id)
2358        .await?;
2359
2360    let mut update = proto::UpdateChannels::default();
2361    update.delete_channels.push(channel_id.to_proto());
2362
2363    for connection_id in session
2364        .connection_pool()
2365        .await
2366        .user_connection_ids(member_id)
2367    {
2368        session.peer.send(connection_id, update.clone())?;
2369    }
2370
2371    response.send(proto::Ack {})?;
2372    Ok(())
2373}
2374
2375async fn set_channel_member_admin(
2376    request: proto::SetChannelMemberAdmin,
2377    response: Response<proto::SetChannelMemberAdmin>,
2378    session: Session,
2379) -> Result<()> {
2380    let db = session.db().await;
2381    let channel_id = ChannelId::from_proto(request.channel_id);
2382    let member_id = UserId::from_proto(request.user_id);
2383    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2384        .await?;
2385
2386    let (channel, has_accepted) = db
2387        .get_channel(channel_id, member_id)
2388        .await?
2389        .ok_or_else(|| anyhow!("channel not found"))?;
2390
2391    let mut update = proto::UpdateChannels::default();
2392    if has_accepted {
2393        update.channel_permissions.push(proto::ChannelPermission {
2394            channel_id: channel.id.to_proto(),
2395            is_admin: request.admin,
2396        });
2397    }
2398
2399    for connection_id in session
2400        .connection_pool()
2401        .await
2402        .user_connection_ids(member_id)
2403    {
2404        session.peer.send(connection_id, update.clone())?;
2405    }
2406
2407    response.send(proto::Ack {})?;
2408    Ok(())
2409}
2410
2411async fn rename_channel(
2412    request: proto::RenameChannel,
2413    response: Response<proto::RenameChannel>,
2414    session: Session,
2415) -> Result<()> {
2416    let db = session.db().await;
2417    let channel_id = ChannelId::from_proto(request.channel_id);
2418    let new_name = db
2419        .rename_channel(channel_id, session.user_id, &request.name)
2420        .await?;
2421
2422    let channel = proto::Channel {
2423        id: request.channel_id,
2424        name: new_name,
2425    };
2426    response.send(proto::RenameChannelResponse {
2427        channel: Some(channel.clone()),
2428    })?;
2429    let mut update = proto::UpdateChannels::default();
2430    update.channels.push(channel);
2431
2432    let member_ids = db.get_channel_members(channel_id).await?;
2433
2434    let connection_pool = session.connection_pool().await;
2435    for member_id in member_ids {
2436        for connection_id in connection_pool.user_connection_ids(member_id) {
2437            session.peer.send(connection_id, update.clone())?;
2438        }
2439    }
2440
2441    Ok(())
2442}
2443
2444async fn link_channel(
2445    request: proto::LinkChannel,
2446    response: Response<proto::LinkChannel>,
2447    session: Session,
2448) -> Result<()> {
2449    let db = session.db().await;
2450    let channel_id = ChannelId::from_proto(request.channel_id);
2451    let to = ChannelId::from_proto(request.to);
2452    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2453
2454    let members = db.get_channel_members(to).await?;
2455    let connection_pool = session.connection_pool().await;
2456    let update = proto::UpdateChannels {
2457        channels: channels_to_send
2458            .channels
2459            .into_iter()
2460            .map(|channel| proto::Channel {
2461                id: channel.id.to_proto(),
2462                name: channel.name,
2463            })
2464            .collect(),
2465        insert_edge: channels_to_send.edges,
2466        ..Default::default()
2467    };
2468    for member_id in members {
2469        for connection_id in connection_pool.user_connection_ids(member_id) {
2470            session.peer.send(connection_id, update.clone())?;
2471        }
2472    }
2473
2474    response.send(Ack {})?;
2475
2476    Ok(())
2477}
2478
2479async fn unlink_channel(
2480    request: proto::UnlinkChannel,
2481    response: Response<proto::UnlinkChannel>,
2482    session: Session,
2483) -> Result<()> {
2484    let db = session.db().await;
2485    let channel_id = ChannelId::from_proto(request.channel_id);
2486    let from = ChannelId::from_proto(request.from);
2487
2488    db.unlink_channel(session.user_id, channel_id, from).await?;
2489
2490    let members = db.get_channel_members(from).await?;
2491
2492    let update = proto::UpdateChannels {
2493        delete_edge: vec![proto::ChannelEdge {
2494            channel_id: channel_id.to_proto(),
2495            parent_id: from.to_proto(),
2496        }],
2497        ..Default::default()
2498    };
2499    let connection_pool = session.connection_pool().await;
2500    for member_id in members {
2501        for connection_id in connection_pool.user_connection_ids(member_id) {
2502            session.peer.send(connection_id, update.clone())?;
2503        }
2504    }
2505
2506    response.send(Ack {})?;
2507
2508    Ok(())
2509}
2510
2511async fn move_channel(
2512    request: proto::MoveChannel,
2513    response: Response<proto::MoveChannel>,
2514    session: Session,
2515) -> Result<()> {
2516    let db = session.db().await;
2517    let channel_id = ChannelId::from_proto(request.channel_id);
2518    let from_parent = ChannelId::from_proto(request.from);
2519    let to = ChannelId::from_proto(request.to);
2520
2521    let channels_to_send = db
2522        .move_channel(session.user_id, channel_id, from_parent, to)
2523        .await?;
2524
2525    if channels_to_send.is_empty() {
2526        response.send(Ack {})?;
2527        return Ok(());
2528    }
2529
2530    let members_from = db.get_channel_members(from_parent).await?;
2531    let members_to = db.get_channel_members(to).await?;
2532
2533    let update = proto::UpdateChannels {
2534        delete_edge: vec![proto::ChannelEdge {
2535            channel_id: channel_id.to_proto(),
2536            parent_id: from_parent.to_proto(),
2537        }],
2538        ..Default::default()
2539    };
2540    let connection_pool = session.connection_pool().await;
2541    for member_id in members_from {
2542        for connection_id in connection_pool.user_connection_ids(member_id) {
2543            session.peer.send(connection_id, update.clone())?;
2544        }
2545    }
2546
2547    let update = proto::UpdateChannels {
2548        channels: channels_to_send
2549            .channels
2550            .into_iter()
2551            .map(|channel| proto::Channel {
2552                id: channel.id.to_proto(),
2553                name: channel.name,
2554            })
2555            .collect(),
2556        insert_edge: channels_to_send.edges,
2557        ..Default::default()
2558    };
2559    for member_id in members_to {
2560        for connection_id in connection_pool.user_connection_ids(member_id) {
2561            session.peer.send(connection_id, update.clone())?;
2562        }
2563    }
2564
2565    response.send(Ack {})?;
2566
2567    Ok(())
2568}
2569
2570async fn get_channel_members(
2571    request: proto::GetChannelMembers,
2572    response: Response<proto::GetChannelMembers>,
2573    session: Session,
2574) -> Result<()> {
2575    let db = session.db().await;
2576    let channel_id = ChannelId::from_proto(request.channel_id);
2577    let members = db
2578        .get_channel_member_details(channel_id, session.user_id)
2579        .await?;
2580    response.send(proto::GetChannelMembersResponse { members })?;
2581    Ok(())
2582}
2583
2584async fn respond_to_channel_invite(
2585    request: proto::RespondToChannelInvite,
2586    response: Response<proto::RespondToChannelInvite>,
2587    session: Session,
2588) -> Result<()> {
2589    let db = session.db().await;
2590    let channel_id = ChannelId::from_proto(request.channel_id);
2591    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2592        .await?;
2593
2594    let mut update = proto::UpdateChannels::default();
2595    update
2596        .remove_channel_invitations
2597        .push(channel_id.to_proto());
2598    if request.accept {
2599        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2600        update
2601            .channels
2602            .extend(
2603                result
2604                    .channels
2605                    .channels
2606                    .into_iter()
2607                    .map(|channel| proto::Channel {
2608                        id: channel.id.to_proto(),
2609                        name: channel.name,
2610                    }),
2611            );
2612        update.unseen_channel_messages = result.channel_messages;
2613        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2614        update.insert_edge = result.channels.edges;
2615        update
2616            .channel_participants
2617            .extend(
2618                result
2619                    .channel_participants
2620                    .into_iter()
2621                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2622                        channel_id: channel_id.to_proto(),
2623                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2624                    }),
2625            );
2626        update
2627            .channel_permissions
2628            .extend(
2629                result
2630                    .channels_with_admin_privileges
2631                    .into_iter()
2632                    .map(|channel_id| proto::ChannelPermission {
2633                        channel_id: channel_id.to_proto(),
2634                        is_admin: true,
2635                    }),
2636            );
2637    }
2638    session.peer.send(session.connection_id, update)?;
2639    response.send(proto::Ack {})?;
2640
2641    Ok(())
2642}
2643
/// Moves this connection into the call room of `request.channel_id`.
///
/// Steps: leave any room the session is currently in, get (or create) the
/// channel's room, join it, reply with the room plus optional LiveKit
/// connection info, then broadcast the new room state to the room's
/// participants and the new participant list to the channel's members.
/// Finally the user's contacts are refreshed.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Candidate LiveKit room name, handed to `get_or_create_channel_room`.
    // Presumably only used when the room has to be created — confirm there.
    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));

    // Inner block so the `db` handle is dropped before the channel/contacts
    // broadcasts below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db
            .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
            .await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // The token is minted for the joined room's own `live_kit_room`, not
        // the candidate name above. If minting fails, the error is logged and
        // the client simply gets no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Unwrap the joined-room value; presumably this releases a room guard
        // returned by `join_room` — TODO confirm against the db layer.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2706
2707async fn join_channel_buffer(
2708    request: proto::JoinChannelBuffer,
2709    response: Response<proto::JoinChannelBuffer>,
2710    session: Session,
2711) -> Result<()> {
2712    let db = session.db().await;
2713    let channel_id = ChannelId::from_proto(request.channel_id);
2714
2715    let open_response = db
2716        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2717        .await?;
2718
2719    let collaborators = open_response.collaborators.clone();
2720    response.send(open_response)?;
2721
2722    let update = UpdateChannelBufferCollaborators {
2723        channel_id: channel_id.to_proto(),
2724        collaborators: collaborators.clone(),
2725    };
2726    channel_buffer_updated(
2727        session.connection_id,
2728        collaborators
2729            .iter()
2730            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2731        &update,
2732        &session.peer,
2733    );
2734
2735    Ok(())
2736}
2737
2738async fn update_channel_buffer(
2739    request: proto::UpdateChannelBuffer,
2740    session: Session,
2741) -> Result<()> {
2742    let db = session.db().await;
2743    let channel_id = ChannelId::from_proto(request.channel_id);
2744
2745    let (collaborators, non_collaborators, epoch, version) = db
2746        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2747        .await?;
2748
2749    channel_buffer_updated(
2750        session.connection_id,
2751        collaborators,
2752        &proto::UpdateChannelBuffer {
2753            channel_id: channel_id.to_proto(),
2754            operations: request.operations,
2755        },
2756        &session.peer,
2757    );
2758
2759    let pool = &*session.connection_pool().await;
2760
2761    broadcast(
2762        None,
2763        non_collaborators
2764            .iter()
2765            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2766        |peer_id| {
2767            session.peer.send(
2768                peer_id.into(),
2769                proto::UpdateChannels {
2770                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2771                        channel_id: channel_id.to_proto(),
2772                        epoch: epoch as u64,
2773                        version: version.clone(),
2774                    }],
2775                    ..Default::default()
2776                },
2777            )
2778        },
2779    );
2780
2781    Ok(())
2782}
2783
2784async fn rejoin_channel_buffers(
2785    request: proto::RejoinChannelBuffers,
2786    response: Response<proto::RejoinChannelBuffers>,
2787    session: Session,
2788) -> Result<()> {
2789    let db = session.db().await;
2790    let buffers = db
2791        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2792        .await?;
2793
2794    for rejoined_buffer in &buffers {
2795        let collaborators_to_notify = rejoined_buffer
2796            .buffer
2797            .collaborators
2798            .iter()
2799            .filter_map(|c| Some(c.peer_id?.into()));
2800        channel_buffer_updated(
2801            session.connection_id,
2802            collaborators_to_notify,
2803            &proto::UpdateChannelBufferCollaborators {
2804                channel_id: rejoined_buffer.buffer.channel_id,
2805                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2806            },
2807            &session.peer,
2808        );
2809    }
2810
2811    response.send(proto::RejoinChannelBuffersResponse {
2812        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2813    })?;
2814
2815    Ok(())
2816}
2817
2818async fn leave_channel_buffer(
2819    request: proto::LeaveChannelBuffer,
2820    response: Response<proto::LeaveChannelBuffer>,
2821    session: Session,
2822) -> Result<()> {
2823    let db = session.db().await;
2824    let channel_id = ChannelId::from_proto(request.channel_id);
2825
2826    let left_buffer = db
2827        .leave_channel_buffer(channel_id, session.connection_id)
2828        .await?;
2829
2830    response.send(Ack {})?;
2831
2832    channel_buffer_updated(
2833        session.connection_id,
2834        left_buffer.connections,
2835        &proto::UpdateChannelBufferCollaborators {
2836            channel_id: channel_id.to_proto(),
2837            collaborators: left_buffer.collaborators,
2838        },
2839        &session.peer,
2840    );
2841
2842    Ok(())
2843}
2844
2845fn channel_buffer_updated<T: EnvelopedMessage>(
2846    sender_id: ConnectionId,
2847    collaborators: impl IntoIterator<Item = ConnectionId>,
2848    message: &T,
2849    peer: &Peer,
2850) {
2851    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2852        peer.send(peer_id.into(), message.clone())
2853    });
2854}
2855
2856async fn send_channel_message(
2857    request: proto::SendChannelMessage,
2858    response: Response<proto::SendChannelMessage>,
2859    session: Session,
2860) -> Result<()> {
2861    // Validate the message body.
2862    let body = request.body.trim().to_string();
2863    if body.len() > MAX_MESSAGE_LEN {
2864        return Err(anyhow!("message is too long"))?;
2865    }
2866    if body.is_empty() {
2867        return Err(anyhow!("message can't be blank"))?;
2868    }
2869
2870    let timestamp = OffsetDateTime::now_utc();
2871    let nonce = request
2872        .nonce
2873        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2874
2875    let channel_id = ChannelId::from_proto(request.channel_id);
2876    let (message_id, connection_ids, non_participants) = session
2877        .db()
2878        .await
2879        .create_channel_message(
2880            channel_id,
2881            session.user_id,
2882            &body,
2883            timestamp,
2884            nonce.clone().into(),
2885        )
2886        .await?;
2887    let message = proto::ChannelMessage {
2888        sender_id: session.user_id.to_proto(),
2889        id: message_id.to_proto(),
2890        body,
2891        timestamp: timestamp.unix_timestamp() as u64,
2892        nonce: Some(nonce),
2893    };
2894    broadcast(Some(session.connection_id), connection_ids, |connection| {
2895        session.peer.send(
2896            connection,
2897            proto::ChannelMessageSent {
2898                channel_id: channel_id.to_proto(),
2899                message: Some(message.clone()),
2900            },
2901        )
2902    });
2903    response.send(proto::SendChannelMessageResponse {
2904        message: Some(message),
2905    })?;
2906
2907    let pool = &*session.connection_pool().await;
2908    broadcast(
2909        None,
2910        non_participants
2911            .iter()
2912            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2913        |peer_id| {
2914            session.peer.send(
2915                peer_id.into(),
2916                proto::UpdateChannels {
2917                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2918                        channel_id: channel_id.to_proto(),
2919                        message_id: message_id.to_proto(),
2920                    }],
2921                    ..Default::default()
2922                },
2923            )
2924        },
2925    );
2926
2927    Ok(())
2928}
2929
2930async fn remove_channel_message(
2931    request: proto::RemoveChannelMessage,
2932    response: Response<proto::RemoveChannelMessage>,
2933    session: Session,
2934) -> Result<()> {
2935    let channel_id = ChannelId::from_proto(request.channel_id);
2936    let message_id = MessageId::from_proto(request.message_id);
2937    let connection_ids = session
2938        .db()
2939        .await
2940        .remove_channel_message(channel_id, message_id, session.user_id)
2941        .await?;
2942    broadcast(Some(session.connection_id), connection_ids, |connection| {
2943        session.peer.send(connection, request.clone())
2944    });
2945    response.send(proto::Ack {})?;
2946    Ok(())
2947}
2948
2949async fn acknowledge_channel_message(
2950    request: proto::AckChannelMessage,
2951    session: Session,
2952) -> Result<()> {
2953    let channel_id = ChannelId::from_proto(request.channel_id);
2954    let message_id = MessageId::from_proto(request.message_id);
2955    session
2956        .db()
2957        .await
2958        .observe_channel_message(channel_id, session.user_id, message_id)
2959        .await?;
2960    Ok(())
2961}
2962
2963async fn acknowledge_buffer_version(
2964    request: proto::AckBufferOperation,
2965    session: Session,
2966) -> Result<()> {
2967    let buffer_id = BufferId::from_proto(request.buffer_id);
2968    session
2969        .db()
2970        .await
2971        .observe_buffer_version(
2972            buffer_id,
2973            session.user_id,
2974            request.epoch as i32,
2975            &request.version,
2976        )
2977        .await?;
2978    Ok(())
2979}
2980
2981async fn join_channel_chat(
2982    request: proto::JoinChannelChat,
2983    response: Response<proto::JoinChannelChat>,
2984    session: Session,
2985) -> Result<()> {
2986    let channel_id = ChannelId::from_proto(request.channel_id);
2987
2988    let db = session.db().await;
2989    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2990        .await?;
2991    let messages = db
2992        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2993        .await?;
2994    response.send(proto::JoinChannelChatResponse {
2995        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2996        messages,
2997    })?;
2998    Ok(())
2999}
3000
3001async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3002    let channel_id = ChannelId::from_proto(request.channel_id);
3003    session
3004        .db()
3005        .await
3006        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3007        .await?;
3008    Ok(())
3009}
3010
3011async fn get_channel_messages(
3012    request: proto::GetChannelMessages,
3013    response: Response<proto::GetChannelMessages>,
3014    session: Session,
3015) -> Result<()> {
3016    let channel_id = ChannelId::from_proto(request.channel_id);
3017    let messages = session
3018        .db()
3019        .await
3020        .get_channel_messages(
3021            channel_id,
3022            session.user_id,
3023            MESSAGE_COUNT_PER_PAGE,
3024            Some(MessageId::from_proto(request.before_message_id)),
3025        )
3026        .await?;
3027    response.send(proto::GetChannelMessagesResponse {
3028        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3029        messages,
3030    })?;
3031    Ok(())
3032}
3033
3034async fn get_notifications(
3035    request: proto::GetNotifications,
3036    response: Response<proto::GetNotifications>,
3037    session: Session,
3038) -> Result<()> {
3039    let notifications = session
3040        .db()
3041        .await
3042        .get_notifications(
3043            session.user_id,
3044            NOTIFICATION_COUNT_PER_PAGE,
3045            request
3046                .before_id
3047                .map(|id| db::NotificationId::from_proto(id)),
3048        )
3049        .await?;
3050    response.send(proto::GetNotificationsResponse { notifications })?;
3051    Ok(())
3052}
3053
3054async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3055    let project_id = ProjectId::from_proto(request.project_id);
3056    let project_connection_ids = session
3057        .db()
3058        .await
3059        .project_connection_ids(project_id, session.connection_id)
3060        .await?;
3061    broadcast(
3062        Some(session.connection_id),
3063        project_connection_ids.iter().copied(),
3064        |connection_id| {
3065            session
3066                .peer
3067                .forward_send(session.connection_id, connection_id, request.clone())
3068        },
3069    );
3070    Ok(())
3071}
3072
3073async fn get_private_user_info(
3074    _request: proto::GetPrivateUserInfo,
3075    response: Response<proto::GetPrivateUserInfo>,
3076    session: Session,
3077) -> Result<()> {
3078    let db = session.db().await;
3079
3080    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3081    let user = db
3082        .get_user_by_id(session.user_id)
3083        .await?
3084        .ok_or_else(|| anyhow!("user not found"))?;
3085    let flags = db.get_user_flags(session.user_id).await?;
3086
3087    response.send(proto::GetPrivateUserInfoResponse {
3088        metrics_id,
3089        staff: user.admin,
3090        flags,
3091    })?;
3092    Ok(())
3093}
3094
3095fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3096    match message {
3097        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3098        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3099        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3100        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3101        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3102            code: frame.code.into(),
3103            reason: frame.reason,
3104        })),
3105    }
3106}
3107
3108fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3109    match message {
3110        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3111        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3112        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3113        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3114        AxumMessage::Close(frame) => {
3115            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3116                code: frame.code.into(),
3117                reason: frame.reason,
3118            }))
3119        }
3120    }
3121}
3122
3123fn build_initial_channels_update(
3124    channels: ChannelsForUser,
3125    channel_invites: Vec<db::Channel>,
3126) -> proto::UpdateChannels {
3127    let mut update = proto::UpdateChannels::default();
3128
3129    for channel in channels.channels.channels {
3130        update.channels.push(proto::Channel {
3131            id: channel.id.to_proto(),
3132            name: channel.name,
3133        });
3134    }
3135
3136    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3137    update.unseen_channel_messages = channels.channel_messages;
3138    update.insert_edge = channels.channels.edges;
3139
3140    for (channel_id, participants) in channels.channel_participants {
3141        update
3142            .channel_participants
3143            .push(proto::ChannelParticipants {
3144                channel_id: channel_id.to_proto(),
3145                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3146            });
3147    }
3148
3149    update
3150        .channel_permissions
3151        .extend(
3152            channels
3153                .channels_with_admin_privileges
3154                .into_iter()
3155                .map(|id| proto::ChannelPermission {
3156                    channel_id: id.to_proto(),
3157                    is_admin: true,
3158                }),
3159        );
3160
3161    for channel in channel_invites {
3162        update.channel_invitations.push(proto::Channel {
3163            id: channel.id.to_proto(),
3164            name: channel.name,
3165        });
3166    }
3167
3168    update
3169}
3170
3171fn build_initial_contacts_update(
3172    contacts: Vec<db::Contact>,
3173    pool: &ConnectionPool,
3174) -> proto::UpdateContacts {
3175    let mut update = proto::UpdateContacts::default();
3176
3177    for contact in contacts {
3178        match contact {
3179            db::Contact::Accepted { user_id, busy } => {
3180                update.contacts.push(contact_for_user(user_id, busy, &pool));
3181            }
3182            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3183            db::Contact::Incoming { user_id } => {
3184                update
3185                    .incoming_requests
3186                    .push(proto::IncomingContactRequest {
3187                        requester_id: user_id.to_proto(),
3188                    })
3189            }
3190        }
3191    }
3192
3193    update
3194}
3195
3196fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3197    proto::Contact {
3198        user_id: user_id.to_proto(),
3199        online: pool.is_user_online(user_id),
3200        busy,
3201    }
3202}
3203
3204fn room_updated(room: &proto::Room, peer: &Peer) {
3205    broadcast(
3206        None,
3207        room.participants
3208            .iter()
3209            .filter_map(|participant| Some(participant.peer_id?.into())),
3210        |peer_id| {
3211            peer.send(
3212                peer_id.into(),
3213                proto::RoomUpdated {
3214                    room: Some(room.clone()),
3215                },
3216            )
3217        },
3218    );
3219}
3220
3221fn channel_updated(
3222    channel_id: ChannelId,
3223    room: &proto::Room,
3224    channel_members: &[UserId],
3225    peer: &Peer,
3226    pool: &ConnectionPool,
3227) {
3228    let participants = room
3229        .participants
3230        .iter()
3231        .map(|p| p.user_id)
3232        .collect::<Vec<_>>();
3233
3234    broadcast(
3235        None,
3236        channel_members
3237            .iter()
3238            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3239        |peer_id| {
3240            peer.send(
3241                peer_id.into(),
3242                proto::UpdateChannels {
3243                    channel_participants: vec![proto::ChannelParticipants {
3244                        channel_id: channel_id.to_proto(),
3245                        participant_user_ids: participants.clone(),
3246                    }],
3247                    ..Default::default()
3248                },
3249            )
3250        },
3251    );
3252}
3253
3254async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3255    let db = session.db().await;
3256
3257    let contacts = db.get_contacts(user_id).await?;
3258    let busy = db.is_user_busy(user_id).await?;
3259
3260    let pool = session.connection_pool().await;
3261    let updated_contact = contact_for_user(user_id, busy, &pool);
3262    for contact in contacts {
3263        if let db::Contact::Accepted {
3264            user_id: contact_user_id,
3265            ..
3266        } = contact
3267        {
3268            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3269                session
3270                    .peer
3271                    .send(
3272                        contact_conn_id,
3273                        proto::UpdateContacts {
3274                            contacts: vec![updated_contact.clone()],
3275                            remove_contacts: Default::default(),
3276                            incoming_requests: Default::default(),
3277                            remove_incoming_requests: Default::default(),
3278                            outgoing_requests: Default::default(),
3279                            remove_outgoing_requests: Default::default(),
3280                        },
3281                    )
3282                    .trace_err();
3283            }
3284        }
3285    }
3286    Ok(())
3287}
3288
/// Removes this session's connection from the room it is in, if any, and
/// performs all follow-up notifications.
///
/// Returns `Ok(())` without doing anything when the database reports no room
/// was left. Otherwise:
/// - each left project's connections are notified via `project_left`;
/// - room participants receive the updated room;
/// - channel members (for channel rooms) receive the participant list;
/// - users whose pending calls were canceled receive `CallCanceled`;
/// - contacts are refreshed for the leaving user and each canceled callee;
/// - LiveKit is told to remove the participant and, if the room was deleted,
///   to delete the LiveKit room. Both steps are best-effort (errors logged).
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` so the values outlive
    // it. Temporaries in the `if let` scrutinee (the database handle and
    // whatever wraps `left_room`) are released when that statement ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // the room itself is moved into `room`. The room broadcast by
        // `room_updated` below therefore carries the default (empty)
        // `live_kit_room` — confirm clients don't rely on that field here.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`,
    // which acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3365
3366async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3367    let left_channel_buffers = session
3368        .db()
3369        .await
3370        .leave_channel_buffers(session.connection_id)
3371        .await?;
3372
3373    for left_buffer in left_channel_buffers {
3374        channel_buffer_updated(
3375            session.connection_id,
3376            left_buffer.connections,
3377            &proto::UpdateChannelBufferCollaborators {
3378                channel_id: left_buffer.channel_id.to_proto(),
3379                collaborators: left_buffer.collaborators,
3380            },
3381            &session.peer,
3382        );
3383    }
3384
3385    Ok(())
3386}
3387
3388fn project_left(project: &db::LeftProject, session: &Session) {
3389    for connection_id in &project.connection_ids {
3390        if project.host_user_id == session.user_id {
3391            session
3392                .peer
3393                .send(
3394                    *connection_id,
3395                    proto::UnshareProject {
3396                        project_id: project.id.to_proto(),
3397                    },
3398                )
3399                .trace_err();
3400        } else {
3401            session
3402                .peer
3403                .send(
3404                    *connection_id,
3405                    proto::RemoveProjectCollaborator {
3406                        project_id: project.id.to_proto(),
3407                        peer_id: Some(session.connection_id.into()),
3408                    },
3409                )
3410                .trace_err();
3411        }
3412    }
3413}
3414
/// Extension for turning a fallible value into an `Option`, discarding the
/// error.
///
/// The `Result` implementation in this file logs the error with
/// `tracing::error!` before discarding it.
pub trait ResultExt {
    /// The success type yielded when the value is `Ok`.
    type Ok;

    /// Returns the success value, or `None` after handling the error.
    fn trace_err(self) -> Option<Self::Ok>;
}
3420
3421impl<T, E> ResultExt for Result<T, E>
3422where
3423    E: std::fmt::Debug,
3424{
3425    type Ok = T;
3426
3427    fn trace_err(self) -> Option<T> {
3428        match self {
3429            Ok(value) => Some(value),
3430            Err(error) => {
3431                tracing::error!("{:?}", error);
3432                None
3433            }
3434        }
3435    }
3436}