rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelVisibility, ChannelsForUser, Database,
   7        MessageId, ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  42        RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73
  74lazy_static! {
  75    static ref METRIC_CONNECTIONS: IntGauge =
  76        register_int_gauge!("connections", "number of connections").unwrap();
  77    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  78        "shared_projects",
  79        "number of open projects with one or more guests"
  80    )
  81    .unwrap();
  82}
  83
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
 101#[derive(Clone)]
 102struct Session {
 103    user_id: UserId,
 104    connection_id: ConnectionId,
 105    db: Arc<tokio::sync::Mutex<DbHandle>>,
 106    peer: Arc<Peer>,
 107    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 108    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 109    executor: Executor,
 110}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
 152pub struct Server {
 153    id: parking_lot::Mutex<ServerId>,
 154    peer: Arc<Peer>,
 155    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 156    app_state: Arc<AppState>,
 157    executor: Executor,
 158    handlers: HashMap<TypeId, MessageHandler>,
 159    teardown: watch::Sender<()>,
 160}
 161
 162pub(crate) struct ConnectionPoolGuard<'a> {
 163    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 164    _not_send: PhantomData<Rc<()>>,
 165}
 166
 167#[derive(Serialize)]
 168pub struct ServerSnapshot<'a> {
 169    peer: &'a Peer,
 170    #[serde(serialize_with = "serialize_deref")]
 171    connection_pool: ConnectionPoolGuard<'a>,
 172}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
 184    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 185        let mut server = Self {
 186            id: parking_lot::Mutex::new(id),
 187            peer: Peer::new(id.0 as u32),
 188            app_state,
 189            executor,
 190            connection_pool: Default::default(),
 191            handlers: Default::default(),
 192            teardown: watch::channel(()).0,
 193        };
 194
 195        server
 196            .add_request_handler(ping)
 197            .add_request_handler(create_room)
 198            .add_request_handler(join_room)
 199            .add_request_handler(rejoin_room)
 200            .add_request_handler(leave_room)
 201            .add_request_handler(call)
 202            .add_request_handler(cancel_call)
 203            .add_message_handler(decline_call)
 204            .add_request_handler(update_participant_location)
 205            .add_request_handler(share_project)
 206            .add_message_handler(unshare_project)
 207            .add_request_handler(join_project)
 208            .add_message_handler(leave_project)
 209            .add_request_handler(update_project)
 210            .add_request_handler(update_worktree)
 211            .add_message_handler(start_language_server)
 212            .add_message_handler(update_language_server)
 213            .add_message_handler(update_diagnostic_summary)
 214            .add_message_handler(update_worktree_settings)
 215            .add_message_handler(refresh_inlay_hints)
 216            .add_request_handler(forward_project_request::<proto::GetHover>)
 217            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetReferences>)
 220            .add_request_handler(forward_project_request::<proto::SearchProject>)
 221            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 222            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 226            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 227            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 228            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 229            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 230            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 231            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 232            .add_request_handler(forward_project_request::<proto::PerformRename>)
 233            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 234            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 235            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 236            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 242            .add_request_handler(forward_project_request::<proto::InlayHints>)
 243            .add_message_handler(create_buffer_for_peer)
 244            .add_request_handler(update_buffer)
 245            .add_message_handler(update_buffer_file)
 246            .add_message_handler(buffer_reloaded)
 247            .add_message_handler(buffer_saved)
 248            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 249            .add_request_handler(get_users)
 250            .add_request_handler(fuzzy_search_users)
 251            .add_request_handler(request_contact)
 252            .add_request_handler(remove_contact)
 253            .add_request_handler(respond_to_contact_request)
 254            .add_request_handler(create_channel)
 255            .add_request_handler(delete_channel)
 256            .add_request_handler(invite_channel_member)
 257            .add_request_handler(remove_channel_member)
 258            .add_request_handler(set_channel_member_role)
 259            .add_request_handler(set_channel_visibility)
 260            .add_request_handler(rename_channel)
 261            .add_request_handler(join_channel_buffer)
 262            .add_request_handler(leave_channel_buffer)
 263            .add_message_handler(update_channel_buffer)
 264            .add_request_handler(rejoin_channel_buffers)
 265            .add_request_handler(get_channel_members)
 266            .add_request_handler(respond_to_channel_invite)
 267            .add_request_handler(join_channel)
 268            .add_request_handler(join_channel_chat)
 269            .add_message_handler(leave_channel_chat)
 270            .add_request_handler(send_channel_message)
 271            .add_request_handler(remove_channel_message)
 272            .add_request_handler(get_channel_messages)
 273            .add_request_handler(link_channel)
 274            .add_request_handler(unlink_channel)
 275            .add_request_handler(move_channel)
 276            .add_request_handler(follow)
 277            .add_message_handler(unfollow)
 278            .add_message_handler(update_followers)
 279            .add_message_handler(update_diff_base)
 280            .add_request_handler(get_private_user_info)
 281            .add_message_handler(acknowledge_channel_message)
 282            .add_message_handler(acknowledge_buffer_version);
 283
 284        Arc::new(server)
 285    }
 286
         /// Schedules best-effort cleanup of resources left behind by stale servers and
         /// returns `Ok(())` immediately; the work runs in a task detached on `self.executor`.
         ///
         /// After awaiting `CLEANUP_TIMEOUT`, the task asks the database
         /// (`stale_server_resource_ids`) for stale room and channel ids in this environment,
         /// relative to this server's id, and then:
         /// * for each channel, clears stale buffer collaborators and sends the refreshed
         ///   collaborator list to every connection still on that buffer;
         /// * for each room, clears stale participants, broadcasts the updated room (and its
         ///   channel, if any), sends `CallCanceled` to users whose calls were canceled, pushes
         ///   refreshed contact entries for affected users to their accepted contacts, and
         ///   deletes the LiveKit room when a LiveKit client is configured and the room has no
         ///   participants left;
         /// * finally calls `delete_stale_servers`.
         ///
         /// Every fallible step goes through `trace_err()`: a failed step is skipped (or its
         /// result discarded) instead of aborting the task, so none of these errors reach the
         /// caller.
  287    pub async fn start(&self) -> Result<()> {
  288        let server_id = *self.id.lock();
  289        let app_state = self.app_state.clone();
  290        let peer = self.peer.clone();
  291        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
  292        let pool = self.connection_pool.clone();
  293        let live_kit_client = self.app_state.live_kit_client.clone();
  294
  295        let span = info_span!("start server");
  296        self.executor.spawn_detached(
  297            async move {
  298                tracing::info!("waiting for cleanup timeout");
  299                timeout.await;
  300                tracing::info!("cleanup timeout expired, retrieving stale rooms");
  301                if let Some((room_ids, channel_ids)) = app_state
  302                    .db
  303                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
  304                    .await
  305                    .trace_err()
  306                {
  307                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
  308                    tracing::info!(
  309                        stale_channel_buffer_count = channel_ids.len(),
  310                        "retrieved stale channel buffers"
  311                    );
  312
                         // Drop stale collaborators from each channel buffer and send the
                         // refreshed collaborator list to the connections that remain.
  313                    for channel_id in channel_ids {
  314                        if let Some(refreshed_channel_buffer) = app_state
  315                            .db
  316                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
  317                            .await
  318                            .trace_err()
  319                        {
  320                            for connection_id in refreshed_channel_buffer.connection_ids {
  321                                peer.send(
  322                                    connection_id,
  323                                    proto::UpdateChannelBufferCollaborators {
  324                                        channel_id: channel_id.to_proto(),
  325                                        collaborators: refreshed_channel_buffer
  326                                            .collaborators
  327                                            .clone(),
  328                                    },
  329                                )
  330                                .trace_err();
  331                            }
  332                        }
  333                    }
  334
                         // For each stale room: clear stale participants, then notify everyone
                         // affected. The defaults below stay in effect if the clear fails.
  335                    for room_id in room_ids {
  336                        let mut contacts_to_update = HashSet::default();
  337                        let mut canceled_calls_to_user_ids = Vec::new();
  338                        let mut live_kit_room = String::new();
  339                        let mut delete_live_kit_room = false;
  340
  341                        if let Some(mut refreshed_room) = app_state
  342                            .db
  343                            .clear_stale_room_participants(room_id, server_id)
  344                            .await
  345                            .trace_err()
  346                        {
  347                            tracing::info!(
  348                                room_id = room_id.0,
  349                                new_participant_count = refreshed_room.room.participants.len(),
  350                                "refreshed room"
  351                            );
  352                            room_updated(&refreshed_room.room, &peer);
  353                            if let Some(channel_id) = refreshed_room.channel_id {
  354                                channel_updated(
  355                                    channel_id,
  356                                    &refreshed_room.room,
  357                                    &refreshed_room.channel_members,
  358                                    &peer,
  359                                    &*pool.lock(),
  360                                );
  361                            }
  362                            contacts_to_update
  363                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
  364                            contacts_to_update
  365                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
  366                            canceled_calls_to_user_ids =
  367                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
  368                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
  369                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
  370                        }
  371
                             // The pool lock is scoped to these synchronous sends; it is never
                             // held across an `.await`.
  372                        {
  373                            let pool = pool.lock();
  374                            for canceled_user_id in canceled_calls_to_user_ids {
  375                                for connection_id in pool.user_connection_ids(canceled_user_id) {
  376                                    peer.send(
  377                                        connection_id,
  378                                        proto::CallCanceled {
  379                                            room_id: room_id.to_proto(),
  380                                        },
  381                                    )
  382                                    .trace_err();
  383                                }
  384                            }
  385                        }
  386
                             // Recompute each affected user's contact entry (including their busy
                             // status) and push it to the connections of all accepted contacts.
  387                        for user_id in contacts_to_update {
  388                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
  389                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
  390                            if let Some((busy, contacts)) = busy.zip(contacts) {
  391                                let pool = pool.lock();
  392                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
  393                                for contact in contacts {
  394                                    if let db::Contact::Accepted {
  395                                        user_id: contact_user_id,
  396                                        ..
  397                                    } = contact
  398                                    {
  399                                        for contact_conn_id in
  400                                            pool.user_connection_ids(contact_user_id)
  401                                        {
  402                                            peer.send(
  403                                                contact_conn_id,
  404                                                proto::UpdateContacts {
  405                                                    contacts: vec![updated_contact.clone()],
  406                                                    remove_contacts: Default::default(),
  407                                                    incoming_requests: Default::default(),
  408                                                    remove_incoming_requests: Default::default(),
  409                                                    outgoing_requests: Default::default(),
  410                                                    remove_outgoing_requests: Default::default(),
  411                                                },
  412                                            )
  413                                            .trace_err();
  414                                        }
  415                                    }
  416                                }
  417                            }
  418                        }
  419
                             // Delete the LiveKit room only if a client is configured and the
                             // room has no participants left.
  420                        if let Some(live_kit) = live_kit_client.as_ref() {
  421                            if delete_live_kit_room {
  422                                live_kit.delete_room(live_kit_room).await.trace_err();
  423                            }
  424                        }
  425                    }
  426                }
  427
  428                app_state
  429                    .db
  430                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
  431                    .await
  432                    .trace_err();
  433            }
  434            .instrument(span),
  435        );
  436        Ok(())
  437    }
 438
 439    pub fn teardown(&self) {
 440        self.peer.teardown();
 441        self.connection_pool.lock().reset();
 442        let _ = self.teardown.send(());
 443    }
 444
 445    #[cfg(test)]
 446    pub fn reset(&self, id: ServerId) {
 447        self.teardown();
 448        *self.id.lock() = id;
 449        self.peer.reset(id.0 as u32);
 450    }
 451
 452    #[cfg(test)]
 453    pub fn id(&self) -> ServerId {
 454        *self.id.lock()
 455    }
 456
 457    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 458    where
 459        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 460        Fut: 'static + Send + Future<Output = Result<()>>,
 461        M: EnvelopedMessage,
 462    {
 463        let prev_handler = self.handlers.insert(
 464            TypeId::of::<M>(),
 465            Box::new(move |envelope, session| {
 466                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 467                let span = info_span!(
 468                    "handle message",
 469                    payload_type = envelope.payload_type_name()
 470                );
 471                span.in_scope(|| {
 472                    tracing::info!(
 473                        payload_type = envelope.payload_type_name(),
 474                        "message received"
 475                    );
 476                });
 477                let start_time = Instant::now();
 478                let future = (handler)(*envelope, session);
 479                async move {
 480                    let result = future.await;
 481                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 482                    match result {
 483                        Err(error) => {
 484                            tracing::error!(%error, ?duration_ms, "error handling message")
 485                        }
 486                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 487                    }
 488                }
 489                .instrument(span)
 490                .boxed()
 491            }),
 492        );
 493        if prev_handler.is_some() {
 494            panic!("registered a handler for the same message twice");
 495        }
 496        self
 497    }
 498
 499    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 500    where
 501        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 502        Fut: 'static + Send + Future<Output = Result<()>>,
 503        M: EnvelopedMessage,
 504    {
 505        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 506        self
 507    }
 508
 509    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 512        Fut: Send + Future<Output = Result<()>>,
 513        M: RequestMessage,
 514    {
 515        let handler = Arc::new(handler);
 516        self.add_handler(move |envelope, session| {
 517            let receipt = envelope.receipt();
 518            let handler = handler.clone();
 519            async move {
 520                let peer = session.peer.clone();
 521                let responded = Arc::new(AtomicBool::default());
 522                let response = Response {
 523                    peer: peer.clone(),
 524                    responded: responded.clone(),
 525                    receipt,
 526                };
 527                match (handler)(envelope.payload, response, session).await {
 528                    Ok(()) => {
 529                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 530                            Ok(())
 531                        } else {
 532                            Err(anyhow!("handler did not send a response"))?
 533                        }
 534                    }
 535                    Err(error) => {
 536                        peer.respond_with_error(
 537                            receipt,
 538                            proto::Error {
 539                                message: error.to_string(),
 540                            },
 541                        )?;
 542                        Err(error)
 543                    }
 544                }
 545            }
 546        })
 547    }
 548
         /// Returns a future that drives a single client connection until it ends.
         ///
         /// Setup: registers `connection` with the peer and sends `Hello`. If
         /// `send_connection_id` is given, it then receives the assigned `ConnectionId`. On the
         /// user's first connection the future also sends `ShowContacts` and marks the user as
         /// connected once. It then adds the connection to the pool, sends the initial contacts
         /// and channels updates, sends any pending incoming call, and refreshes the user's
         /// contacts via `update_user_contacts`.
         ///
         /// Dispatch: each incoming message is routed to the handler registered for its
         /// payload type. Background messages are spawned on `executor`; foreground ones are
         /// polled by this loop. The loop ends when the I/O future completes or the incoming
         /// stream ends. Pending foreground handlers are then dropped and `connection_lost`
         /// runs. If the server's teardown signal fires first, the future returns `Ok(())`
         /// immediately without calling `connection_lost`.
         ///
         /// NOTE(review): setup errors propagated with `?` after `pool.add_connection` return
         /// without calling `connection_lost`, so the pool entry is not removed here. Confirm
         /// whether a caller cleans this up.
  549    pub fn handle_connection(
  550        self: &Arc<Self>,
  551        connection: Connection,
  552        address: String,
  553        user: User,
  554        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
  555        executor: Executor,
  556    ) -> impl Future<Output = Result<()>> {
  557        let this = self.clone();
  558        let user_id = user.id;
  559        let login = user.github_login;
  560        let span = info_span!("handle connection", %user_id, %login, %address);
  561        let mut teardown = self.teardown.subscribe();
  562        async move {
  563            let (connection_id, handle_io, mut incoming_rx) = this
  564                .peer
  565                .add_connection(connection, {
  566                    let executor = executor.clone();
  567                    move |duration| executor.sleep(duration)
  568                });
  569
  570            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
  571            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
  572            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
  573
  574            if let Some(send_connection_id) = send_connection_id.take() {
  575                let _ = send_connection_id.send(connection_id);
  576            }
  577
  578            if !user.connected_once {
  579                this.peer.send(connection_id, proto::ShowContacts {})?;
  580                this.app_state.db.set_user_connected_once(user_id, true).await?;
  581            }
  582
  583            let (contacts, channels_for_user, channel_invites) = future::try_join3(
  584                this.app_state.db.get_contacts(user_id),
  585                this.app_state.db.get_channels_for_user(user_id),
  586                this.app_state.db.get_channel_invites_for_user(user_id)
  587            ).await?;
  588
  589            {
  590                let mut pool = this.connection_pool.lock();
  591                pool.add_connection(connection_id, user_id, user.admin);
  592                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
  593                this.peer.send(connection_id, build_initial_channels_update(
  594                    channels_for_user,
  595                    channel_invites
  596                ))?;
  597            }
  598
  599            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
  600                this.peer.send(connection_id, incoming_call)?;
  601            }
  602
  603            let session = Session {
  604                user_id,
  605                connection_id,
  606                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
  607                peer: this.peer.clone(),
  608                connection_pool: this.connection_pool.clone(),
  609                live_kit_client: this.app_state.live_kit_client.clone(),
  610                executor: executor.clone(),
  611            };
  612            update_user_contacts(user_id, &session).await?;
  613
  614            let handle_io = handle_io.fuse();
  615            futures::pin_mut!(handle_io);
  616
  617            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
  618            // This prevents deadlocks when e.g., client A performs a request to client B and
  619            // client B performs a request to client A. If both clients stop processing further
  620            // messages until their respective request completes, they won't have a chance to
  621            // respond to the other client's request and cause a deadlock.
  622            //
  623            // This arrangement ensures we will attempt to process earlier messages first, but fall
  624            // back to processing messages arrived later in the spirit of making progress.
  625            let mut foreground_message_handlers = FuturesUnordered::new();
                 // At most 256 handlers from this connection may be in flight at once. A permit
                 // is taken before the next message is read and released when its handler
                 // completes (or immediately, if no handler is registered for the message).
  626            let concurrent_handlers = Arc::new(Semaphore::new(256));
  627            loop {
  628                let next_message = async {
  629                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
  630                    let message = incoming_rx.next().await;
  631                    (permit, message)
  632                }.fuse();
  633                futures::pin_mut!(next_message);
                     // `select_biased!` polls its arms in order: teardown first, then I/O, then
                     // completed foreground handlers, and finally the next incoming message.
  634                futures::select_biased! {
  635                    _ = teardown.changed().fuse() => return Ok(()),
  636                    result = handle_io => {
  637                        if let Err(error) = result {
  638                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
  639                        }
  640                        break;
  641                    }
  642                    _ = foreground_message_handlers.next() => {}
  643                    next_message = next_message => {
  644                        let (permit, message) = next_message;
  645                        if let Some(message) = message {
  646                            let type_name = message.payload_type_name();
  647                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
  648                            let span_enter = span.enter();
  649                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
  650                                let is_background = message.is_background();
  651                                let handle_message = (handler)(message, session.clone());
  652                                drop(span_enter);
  653
  654                                let handle_message = async move {
  655                                    handle_message.await;
  656                                    drop(permit);
  657                                }.instrument(span);
                                     // Background messages run detached on the executor;
                                     // foreground ones are polled by this loop.
  658                                if is_background {
  659                                    executor.spawn_detached(handle_message);
  660                                } else {
  661                                    foreground_message_handlers.push(handle_message);
  662                                }
  663                            } else {
  664                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
  665                            }
  666                        } else {
  667                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
  668                            break;
  669                        }
  670                    }
  671                }
  672            }
  673
  674            drop(foreground_message_handlers);
  675            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
  676            if let Err(error) = connection_lost(session, teardown, executor).await {
  677                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
  678            }
  679
  680            Ok(())
  681        }.instrument(span)
  682    }
 683
 684    pub async fn invite_code_redeemed(
 685        self: &Arc<Self>,
 686        inviter_id: UserId,
 687        invitee_id: UserId,
 688    ) -> Result<()> {
 689        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 690            if let Some(code) = &user.invite_code {
 691                let pool = self.connection_pool.lock();
 692                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 693                for connection_id in pool.user_connection_ids(inviter_id) {
 694                    self.peer.send(
 695                        connection_id,
 696                        proto::UpdateContacts {
 697                            contacts: vec![invitee_contact.clone()],
 698                            ..Default::default()
 699                        },
 700                    )?;
 701                    self.peer.send(
 702                        connection_id,
 703                        proto::UpdateInviteInfo {
 704                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 705                            count: user.invite_count as u32,
 706                        },
 707                    )?;
 708                }
 709            }
 710        }
 711        Ok(())
 712    }
 713
 714    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 715        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 716            if let Some(invite_code) = &user.invite_code {
 717                let pool = self.connection_pool.lock();
 718                for connection_id in pool.user_connection_ids(user_id) {
 719                    self.peer.send(
 720                        connection_id,
 721                        proto::UpdateInviteInfo {
 722                            url: format!(
 723                                "{}{}",
 724                                self.app_state.config.invite_link_prefix, invite_code
 725                            ),
 726                            count: user.invite_count as u32,
 727                        },
 728                    )?;
 729                }
 730            }
 731        }
 732        Ok(())
 733    }
 734
 735    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 736        ServerSnapshot {
 737            connection_pool: ConnectionPoolGuard {
 738                guard: self.connection_pool.lock(),
 739                _not_send: PhantomData,
 740            },
 741            peer: &self.peer,
 742        }
 743    }
 744}
 745
 746impl<'a> Deref for ConnectionPoolGuard<'a> {
 747    type Target = ConnectionPool;
 748
 749    fn deref(&self) -> &Self::Target {
 750        &*self.guard
 751    }
 752}
 753
 754impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 755    fn deref_mut(&mut self) -> &mut Self::Target {
 756        &mut *self.guard
 757    }
 758}
 759
 760impl<'a> Drop for ConnectionPoolGuard<'a> {
 761    fn drop(&mut self) {
 762        #[cfg(test)]
 763        self.check_invariants();
 764    }
 765}
 766
 767fn broadcast<F>(
 768    sender_id: Option<ConnectionId>,
 769    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 770    mut f: F,
 771) where
 772    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 773{
 774    for receiver_id in receiver_ids {
 775        if Some(receiver_id) != sender_id {
 776            if let Err(error) = f(receiver_id) {
 777                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 778            }
 779        }
 780    }
 781}
 782
lazy_static! {
    // Name of the request header that carries the client's RPC protocol version;
    // it is the header name reported by `ProtocolVersion`'s `Header` impl.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 786
/// Typed `x-zed-protocol-version` request header holding the client's protocol version.
pub struct ProtocolVersion(u32);
 788
 789impl Header for ProtocolVersion {
 790    fn name() -> &'static HeaderName {
 791        &ZED_PROTOCOL_VERSION
 792    }
 793
 794    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 795    where
 796        Self: Sized,
 797        I: Iterator<Item = &'i axum::http::HeaderValue>,
 798    {
 799        let version = values
 800            .next()
 801            .ok_or_else(axum::headers::Error::invalid)?
 802            .to_str()
 803            .map_err(|_| axum::headers::Error::invalid())?
 804            .parse()
 805            .map_err(|_| axum::headers::Error::invalid())?;
 806        Ok(Self(version))
 807    }
 808
 809    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 810        values.extend([self.0.to_string().parse().unwrap()]);
 811    }
 812}
 813
 814pub fn routes(server: Arc<Server>) -> Router<Body> {
 815    Router::new()
 816        .route("/rpc", get(handle_websocket_request))
 817        .layer(
 818            ServiceBuilder::new()
 819                .layer(Extension(server.app_state.clone()))
 820                .layer(middleware::from_fn(auth::validate_header)),
 821        )
 822        .route("/metrics", get(handle_metrics))
 823        .layer(Extension(server))
 824}
 825
 826pub async fn handle_websocket_request(
 827    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 828    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 829    Extension(server): Extension<Arc<Server>>,
 830    Extension(user): Extension<User>,
 831    ws: WebSocketUpgrade,
 832) -> axum::response::Response {
 833    if protocol_version != rpc::PROTOCOL_VERSION {
 834        return (
 835            StatusCode::UPGRADE_REQUIRED,
 836            "client must be upgraded".to_string(),
 837        )
 838            .into_response();
 839    }
 840    let socket_address = socket_address.to_string();
 841    ws.on_upgrade(move |socket| {
 842        use util::ResultExt;
 843        let socket = socket
 844            .map_ok(to_tungstenite_message)
 845            .err_into()
 846            .with(|message| async move { Ok(to_axum_message(message)) });
 847        let connection = Connection::new(Box::pin(socket));
 848        async move {
 849            server
 850                .handle_connection(connection, socket_address, user, None, Executor::Production)
 851                .await
 852                .log_err();
 853        }
 854    })
 855}
 856
 857pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 858    let connections = server
 859        .connection_pool
 860        .lock()
 861        .connections()
 862        .filter(|connection| !connection.admin)
 863        .count();
 864
 865    METRIC_CONNECTIONS.set(connections as _);
 866
 867    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 868    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 869
 870    let encoder = prometheus::TextEncoder::new();
 871    let metric_families = prometheus::gather();
 872    let encoded_metrics = encoder
 873        .encode_to_string(&metric_families)
 874        .map_err(|err| anyhow!("{}", err))?;
 875    Ok(encoded_metrics)
 876}
 877
 878#[instrument(err, skip(executor))]
 879async fn connection_lost(
 880    session: Session,
 881    mut teardown: watch::Receiver<()>,
 882    executor: Executor,
 883) -> Result<()> {
 884    session.peer.disconnect(session.connection_id);
 885    session
 886        .connection_pool()
 887        .await
 888        .remove_connection(session.connection_id)?;
 889
 890    session
 891        .db()
 892        .await
 893        .connection_lost(session.connection_id)
 894        .await
 895        .trace_err();
 896
 897    futures::select_biased! {
 898        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 899            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 900            leave_room_for_session(&session).await.trace_err();
 901            leave_channel_buffers_for_session(&session)
 902                .await
 903                .trace_err();
 904
 905            if !session
 906                .connection_pool()
 907                .await
 908                .is_user_online(session.user_id)
 909            {
 910                let db = session.db().await;
 911                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 912                    room_updated(&room, &session.peer);
 913                }
 914            }
 915
 916            update_user_contacts(session.user_id, &session).await?;
 917        }
 918        _ = teardown.changed().fuse() => {}
 919    }
 920
 921    Ok(())
 922}
 923
 924async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 925    response.send(proto::Ack {})?;
 926    Ok(())
 927}
 928
 929async fn create_room(
 930    _request: proto::CreateRoom,
 931    response: Response<proto::CreateRoom>,
 932    session: Session,
 933) -> Result<()> {
 934    let live_kit_room = nanoid::nanoid!(30);
 935
 936    let live_kit_connection_info = {
 937        let live_kit_room = live_kit_room.clone();
 938        let live_kit = session.live_kit_client.as_ref();
 939
 940        util::async_iife!({
 941            let live_kit = live_kit?;
 942
 943            let token = live_kit
 944                .room_token(&live_kit_room, &session.user_id.to_string())
 945                .trace_err()?;
 946
 947            Some(proto::LiveKitConnectionInfo {
 948                server_url: live_kit.url().into(),
 949                token,
 950            })
 951        })
 952    }
 953    .await;
 954
 955    let room = session
 956        .db()
 957        .await
 958        .create_room(
 959            session.user_id,
 960            session.connection_id,
 961            &live_kit_room,
 962            RELEASE_CHANNEL_NAME.as_str(),
 963        )
 964        .await?;
 965
 966    response.send(proto::CreateRoomResponse {
 967        room: Some(room.clone()),
 968        live_kit_connection_info,
 969    })?;
 970
 971    update_user_contacts(session.user_id, &session).await?;
 972    Ok(())
 973}
 974
 975async fn join_room(
 976    request: proto::JoinRoom,
 977    response: Response<proto::JoinRoom>,
 978    session: Session,
 979) -> Result<()> {
 980    let room_id = RoomId::from_proto(request.id);
 981
 982    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 983
 984    if let Some(channel_id) = channel_id {
 985        return join_channel_internal(channel_id, Box::new(response), session).await;
 986    }
 987
 988    let joined_room = {
 989        let room = session
 990            .db()
 991            .await
 992            .join_room(
 993                room_id,
 994                session.user_id,
 995                session.connection_id,
 996                RELEASE_CHANNEL_NAME.as_str(),
 997            )
 998            .await?;
 999        room_updated(&room.room, &session.peer);
1000        room.into_inner()
1001    };
1002
1003    for connection_id in session
1004        .connection_pool()
1005        .await
1006        .user_connection_ids(session.user_id)
1007    {
1008        session
1009            .peer
1010            .send(
1011                connection_id,
1012                proto::CallCanceled {
1013                    room_id: room_id.to_proto(),
1014                },
1015            )
1016            .trace_err();
1017    }
1018
1019    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1020        if let Some(token) = live_kit
1021            .room_token(
1022                &joined_room.room.live_kit_room,
1023                &session.user_id.to_string(),
1024            )
1025            .trace_err()
1026        {
1027            Some(proto::LiveKitConnectionInfo {
1028                server_url: live_kit.url().into(),
1029                token,
1030            })
1031        } else {
1032            None
1033        }
1034    } else {
1035        None
1036    };
1037
1038    response.send(proto::JoinRoomResponse {
1039        room: Some(joined_room.room),
1040        channel_id: None,
1041        live_kit_connection_info,
1042    })?;
1043
1044    update_user_contacts(session.user_id, &session).await?;
1045    Ok(())
1046}
1047
1048async fn rejoin_room(
1049    request: proto::RejoinRoom,
1050    response: Response<proto::RejoinRoom>,
1051    session: Session,
1052) -> Result<()> {
1053    let room;
1054    let channel_id;
1055    let channel_members;
1056    {
1057        let mut rejoined_room = session
1058            .db()
1059            .await
1060            .rejoin_room(request, session.user_id, session.connection_id)
1061            .await?;
1062
1063        response.send(proto::RejoinRoomResponse {
1064            room: Some(rejoined_room.room.clone()),
1065            reshared_projects: rejoined_room
1066                .reshared_projects
1067                .iter()
1068                .map(|project| proto::ResharedProject {
1069                    id: project.id.to_proto(),
1070                    collaborators: project
1071                        .collaborators
1072                        .iter()
1073                        .map(|collaborator| collaborator.to_proto())
1074                        .collect(),
1075                })
1076                .collect(),
1077            rejoined_projects: rejoined_room
1078                .rejoined_projects
1079                .iter()
1080                .map(|rejoined_project| proto::RejoinedProject {
1081                    id: rejoined_project.id.to_proto(),
1082                    worktrees: rejoined_project
1083                        .worktrees
1084                        .iter()
1085                        .map(|worktree| proto::WorktreeMetadata {
1086                            id: worktree.id,
1087                            root_name: worktree.root_name.clone(),
1088                            visible: worktree.visible,
1089                            abs_path: worktree.abs_path.clone(),
1090                        })
1091                        .collect(),
1092                    collaborators: rejoined_project
1093                        .collaborators
1094                        .iter()
1095                        .map(|collaborator| collaborator.to_proto())
1096                        .collect(),
1097                    language_servers: rejoined_project.language_servers.clone(),
1098                })
1099                .collect(),
1100        })?;
1101        room_updated(&rejoined_room.room, &session.peer);
1102
1103        for project in &rejoined_room.reshared_projects {
1104            for collaborator in &project.collaborators {
1105                session
1106                    .peer
1107                    .send(
1108                        collaborator.connection_id,
1109                        proto::UpdateProjectCollaborator {
1110                            project_id: project.id.to_proto(),
1111                            old_peer_id: Some(project.old_connection_id.into()),
1112                            new_peer_id: Some(session.connection_id.into()),
1113                        },
1114                    )
1115                    .trace_err();
1116            }
1117
1118            broadcast(
1119                Some(session.connection_id),
1120                project
1121                    .collaborators
1122                    .iter()
1123                    .map(|collaborator| collaborator.connection_id),
1124                |connection_id| {
1125                    session.peer.forward_send(
1126                        session.connection_id,
1127                        connection_id,
1128                        proto::UpdateProject {
1129                            project_id: project.id.to_proto(),
1130                            worktrees: project.worktrees.clone(),
1131                        },
1132                    )
1133                },
1134            );
1135        }
1136
1137        for project in &rejoined_room.rejoined_projects {
1138            for collaborator in &project.collaborators {
1139                session
1140                    .peer
1141                    .send(
1142                        collaborator.connection_id,
1143                        proto::UpdateProjectCollaborator {
1144                            project_id: project.id.to_proto(),
1145                            old_peer_id: Some(project.old_connection_id.into()),
1146                            new_peer_id: Some(session.connection_id.into()),
1147                        },
1148                    )
1149                    .trace_err();
1150            }
1151        }
1152
1153        for project in &mut rejoined_room.rejoined_projects {
1154            for worktree in mem::take(&mut project.worktrees) {
1155                #[cfg(any(test, feature = "test-support"))]
1156                const MAX_CHUNK_SIZE: usize = 2;
1157                #[cfg(not(any(test, feature = "test-support")))]
1158                const MAX_CHUNK_SIZE: usize = 256;
1159
1160                // Stream this worktree's entries.
1161                let message = proto::UpdateWorktree {
1162                    project_id: project.id.to_proto(),
1163                    worktree_id: worktree.id,
1164                    abs_path: worktree.abs_path.clone(),
1165                    root_name: worktree.root_name,
1166                    updated_entries: worktree.updated_entries,
1167                    removed_entries: worktree.removed_entries,
1168                    scan_id: worktree.scan_id,
1169                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1170                    updated_repositories: worktree.updated_repositories,
1171                    removed_repositories: worktree.removed_repositories,
1172                };
1173                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1174                    session.peer.send(session.connection_id, update.clone())?;
1175                }
1176
1177                // Stream this worktree's diagnostics.
1178                for summary in worktree.diagnostic_summaries {
1179                    session.peer.send(
1180                        session.connection_id,
1181                        proto::UpdateDiagnosticSummary {
1182                            project_id: project.id.to_proto(),
1183                            worktree_id: worktree.id,
1184                            summary: Some(summary),
1185                        },
1186                    )?;
1187                }
1188
1189                for settings_file in worktree.settings_files {
1190                    session.peer.send(
1191                        session.connection_id,
1192                        proto::UpdateWorktreeSettings {
1193                            project_id: project.id.to_proto(),
1194                            worktree_id: worktree.id,
1195                            path: settings_file.path,
1196                            content: Some(settings_file.content),
1197                        },
1198                    )?;
1199                }
1200            }
1201
1202            for language_server in &project.language_servers {
1203                session.peer.send(
1204                    session.connection_id,
1205                    proto::UpdateLanguageServer {
1206                        project_id: project.id.to_proto(),
1207                        language_server_id: language_server.id,
1208                        variant: Some(
1209                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1210                                proto::LspDiskBasedDiagnosticsUpdated {},
1211                            ),
1212                        ),
1213                    },
1214                )?;
1215            }
1216        }
1217
1218        let rejoined_room = rejoined_room.into_inner();
1219
1220        room = rejoined_room.room;
1221        channel_id = rejoined_room.channel_id;
1222        channel_members = rejoined_room.channel_members;
1223    }
1224
1225    if let Some(channel_id) = channel_id {
1226        channel_updated(
1227            channel_id,
1228            &room,
1229            &channel_members,
1230            &session.peer,
1231            &*session.connection_pool().await,
1232        );
1233    }
1234
1235    update_user_contacts(session.user_id, &session).await?;
1236    Ok(())
1237}
1238
1239async fn leave_room(
1240    _: proto::LeaveRoom,
1241    response: Response<proto::LeaveRoom>,
1242    session: Session,
1243) -> Result<()> {
1244    leave_room_for_session(&session).await?;
1245    response.send(proto::Ack {})?;
1246    Ok(())
1247}
1248
1249async fn call(
1250    request: proto::Call,
1251    response: Response<proto::Call>,
1252    session: Session,
1253) -> Result<()> {
1254    let room_id = RoomId::from_proto(request.room_id);
1255    let calling_user_id = session.user_id;
1256    let calling_connection_id = session.connection_id;
1257    let called_user_id = UserId::from_proto(request.called_user_id);
1258    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1259    if !session
1260        .db()
1261        .await
1262        .has_contact(calling_user_id, called_user_id)
1263        .await?
1264    {
1265        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1266    }
1267
1268    let incoming_call = {
1269        let (room, incoming_call) = &mut *session
1270            .db()
1271            .await
1272            .call(
1273                room_id,
1274                calling_user_id,
1275                calling_connection_id,
1276                called_user_id,
1277                initial_project_id,
1278            )
1279            .await?;
1280        room_updated(&room, &session.peer);
1281        mem::take(incoming_call)
1282    };
1283    update_user_contacts(called_user_id, &session).await?;
1284
1285    let mut calls = session
1286        .connection_pool()
1287        .await
1288        .user_connection_ids(called_user_id)
1289        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1290        .collect::<FuturesUnordered<_>>();
1291
1292    while let Some(call_response) = calls.next().await {
1293        match call_response.as_ref() {
1294            Ok(_) => {
1295                response.send(proto::Ack {})?;
1296                return Ok(());
1297            }
1298            Err(_) => {
1299                call_response.trace_err();
1300            }
1301        }
1302    }
1303
1304    {
1305        let room = session
1306            .db()
1307            .await
1308            .call_failed(room_id, called_user_id)
1309            .await?;
1310        room_updated(&room, &session.peer);
1311    }
1312    update_user_contacts(called_user_id, &session).await?;
1313
1314    Err(anyhow!("failed to ring user"))?
1315}
1316
1317async fn cancel_call(
1318    request: proto::CancelCall,
1319    response: Response<proto::CancelCall>,
1320    session: Session,
1321) -> Result<()> {
1322    let called_user_id = UserId::from_proto(request.called_user_id);
1323    let room_id = RoomId::from_proto(request.room_id);
1324    {
1325        let room = session
1326            .db()
1327            .await
1328            .cancel_call(room_id, session.connection_id, called_user_id)
1329            .await?;
1330        room_updated(&room, &session.peer);
1331    }
1332
1333    for connection_id in session
1334        .connection_pool()
1335        .await
1336        .user_connection_ids(called_user_id)
1337    {
1338        session
1339            .peer
1340            .send(
1341                connection_id,
1342                proto::CallCanceled {
1343                    room_id: room_id.to_proto(),
1344                },
1345            )
1346            .trace_err();
1347    }
1348    response.send(proto::Ack {})?;
1349
1350    update_user_contacts(called_user_id, &session).await?;
1351    Ok(())
1352}
1353
1354async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1355    let room_id = RoomId::from_proto(message.room_id);
1356    {
1357        let room = session
1358            .db()
1359            .await
1360            .decline_call(Some(room_id), session.user_id)
1361            .await?
1362            .ok_or_else(|| anyhow!("failed to decline call"))?;
1363        room_updated(&room, &session.peer);
1364    }
1365
1366    for connection_id in session
1367        .connection_pool()
1368        .await
1369        .user_connection_ids(session.user_id)
1370    {
1371        session
1372            .peer
1373            .send(
1374                connection_id,
1375                proto::CallCanceled {
1376                    room_id: room_id.to_proto(),
1377                },
1378            )
1379            .trace_err();
1380    }
1381    update_user_contacts(session.user_id, &session).await?;
1382    Ok(())
1383}
1384
1385async fn update_participant_location(
1386    request: proto::UpdateParticipantLocation,
1387    response: Response<proto::UpdateParticipantLocation>,
1388    session: Session,
1389) -> Result<()> {
1390    let room_id = RoomId::from_proto(request.room_id);
1391    let location = request
1392        .location
1393        .ok_or_else(|| anyhow!("invalid location"))?;
1394
1395    let db = session.db().await;
1396    let room = db
1397        .update_room_participant_location(room_id, session.connection_id, location)
1398        .await?;
1399
1400    room_updated(&room, &session.peer);
1401    response.send(proto::Ack {})?;
1402    Ok(())
1403}
1404
1405async fn share_project(
1406    request: proto::ShareProject,
1407    response: Response<proto::ShareProject>,
1408    session: Session,
1409) -> Result<()> {
1410    let (project_id, room) = &*session
1411        .db()
1412        .await
1413        .share_project(
1414            RoomId::from_proto(request.room_id),
1415            session.connection_id,
1416            &request.worktrees,
1417        )
1418        .await?;
1419    response.send(proto::ShareProjectResponse {
1420        project_id: project_id.to_proto(),
1421    })?;
1422    room_updated(&room, &session.peer);
1423
1424    Ok(())
1425}
1426
1427async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1428    let project_id = ProjectId::from_proto(message.project_id);
1429
1430    let (room, guest_connection_ids) = &*session
1431        .db()
1432        .await
1433        .unshare_project(project_id, session.connection_id)
1434        .await?;
1435
1436    broadcast(
1437        Some(session.connection_id),
1438        guest_connection_ids.iter().copied(),
1439        |conn_id| session.peer.send(conn_id, message.clone()),
1440    );
1441    room_updated(&room, &session.peer);
1442
1443    Ok(())
1444}
1445
1446async fn join_project(
1447    request: proto::JoinProject,
1448    response: Response<proto::JoinProject>,
1449    session: Session,
1450) -> Result<()> {
1451    let project_id = ProjectId::from_proto(request.project_id);
1452    let guest_user_id = session.user_id;
1453
1454    tracing::info!(%project_id, "join project");
1455
1456    let (project, replica_id) = &mut *session
1457        .db()
1458        .await
1459        .join_project(project_id, session.connection_id)
1460        .await?;
1461
1462    let collaborators = project
1463        .collaborators
1464        .iter()
1465        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1466        .map(|collaborator| collaborator.to_proto())
1467        .collect::<Vec<_>>();
1468
1469    let worktrees = project
1470        .worktrees
1471        .iter()
1472        .map(|(id, worktree)| proto::WorktreeMetadata {
1473            id: *id,
1474            root_name: worktree.root_name.clone(),
1475            visible: worktree.visible,
1476            abs_path: worktree.abs_path.clone(),
1477        })
1478        .collect::<Vec<_>>();
1479
1480    for collaborator in &collaborators {
1481        session
1482            .peer
1483            .send(
1484                collaborator.peer_id.unwrap().into(),
1485                proto::AddProjectCollaborator {
1486                    project_id: project_id.to_proto(),
1487                    collaborator: Some(proto::Collaborator {
1488                        peer_id: Some(session.connection_id.into()),
1489                        replica_id: replica_id.0 as u32,
1490                        user_id: guest_user_id.to_proto(),
1491                    }),
1492                },
1493            )
1494            .trace_err();
1495    }
1496
1497    // First, we send the metadata associated with each worktree.
1498    response.send(proto::JoinProjectResponse {
1499        worktrees: worktrees.clone(),
1500        replica_id: replica_id.0 as u32,
1501        collaborators: collaborators.clone(),
1502        language_servers: project.language_servers.clone(),
1503    })?;
1504
1505    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1506        #[cfg(any(test, feature = "test-support"))]
1507        const MAX_CHUNK_SIZE: usize = 2;
1508        #[cfg(not(any(test, feature = "test-support")))]
1509        const MAX_CHUNK_SIZE: usize = 256;
1510
1511        // Stream this worktree's entries.
1512        let message = proto::UpdateWorktree {
1513            project_id: project_id.to_proto(),
1514            worktree_id,
1515            abs_path: worktree.abs_path.clone(),
1516            root_name: worktree.root_name,
1517            updated_entries: worktree.entries,
1518            removed_entries: Default::default(),
1519            scan_id: worktree.scan_id,
1520            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1521            updated_repositories: worktree.repository_entries.into_values().collect(),
1522            removed_repositories: Default::default(),
1523        };
1524        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1525            session.peer.send(session.connection_id, update.clone())?;
1526        }
1527
1528        // Stream this worktree's diagnostics.
1529        for summary in worktree.diagnostic_summaries {
1530            session.peer.send(
1531                session.connection_id,
1532                proto::UpdateDiagnosticSummary {
1533                    project_id: project_id.to_proto(),
1534                    worktree_id: worktree.id,
1535                    summary: Some(summary),
1536                },
1537            )?;
1538        }
1539
1540        for settings_file in worktree.settings_files {
1541            session.peer.send(
1542                session.connection_id,
1543                proto::UpdateWorktreeSettings {
1544                    project_id: project_id.to_proto(),
1545                    worktree_id: worktree.id,
1546                    path: settings_file.path,
1547                    content: Some(settings_file.content),
1548                },
1549            )?;
1550        }
1551    }
1552
1553    for language_server in &project.language_servers {
1554        session.peer.send(
1555            session.connection_id,
1556            proto::UpdateLanguageServer {
1557                project_id: project_id.to_proto(),
1558                language_server_id: language_server.id,
1559                variant: Some(
1560                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1561                        proto::LspDiskBasedDiagnosticsUpdated {},
1562                    ),
1563                ),
1564            },
1565        )?;
1566    }
1567
1568    Ok(())
1569}
1570
1571async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1572    let sender_id = session.connection_id;
1573    let project_id = ProjectId::from_proto(request.project_id);
1574
1575    let (room, project) = &*session
1576        .db()
1577        .await
1578        .leave_project(project_id, sender_id)
1579        .await?;
1580    tracing::info!(
1581        %project_id,
1582        host_user_id = %project.host_user_id,
1583        host_connection_id = %project.host_connection_id,
1584        "leave project"
1585    );
1586
1587    project_left(&project, &session);
1588    room_updated(&room, &session.peer);
1589
1590    Ok(())
1591}
1592
1593async fn update_project(
1594    request: proto::UpdateProject,
1595    response: Response<proto::UpdateProject>,
1596    session: Session,
1597) -> Result<()> {
1598    let project_id = ProjectId::from_proto(request.project_id);
1599    let (room, guest_connection_ids) = &*session
1600        .db()
1601        .await
1602        .update_project(project_id, session.connection_id, &request.worktrees)
1603        .await?;
1604    broadcast(
1605        Some(session.connection_id),
1606        guest_connection_ids.iter().copied(),
1607        |connection_id| {
1608            session
1609                .peer
1610                .forward_send(session.connection_id, connection_id, request.clone())
1611        },
1612    );
1613    room_updated(&room, &session.peer);
1614    response.send(proto::Ack {})?;
1615
1616    Ok(())
1617}
1618
1619async fn update_worktree(
1620    request: proto::UpdateWorktree,
1621    response: Response<proto::UpdateWorktree>,
1622    session: Session,
1623) -> Result<()> {
1624    let guest_connection_ids = session
1625        .db()
1626        .await
1627        .update_worktree(&request, session.connection_id)
1628        .await?;
1629
1630    broadcast(
1631        Some(session.connection_id),
1632        guest_connection_ids.iter().copied(),
1633        |connection_id| {
1634            session
1635                .peer
1636                .forward_send(session.connection_id, connection_id, request.clone())
1637        },
1638    );
1639    response.send(proto::Ack {})?;
1640    Ok(())
1641}
1642
1643async fn update_diagnostic_summary(
1644    message: proto::UpdateDiagnosticSummary,
1645    session: Session,
1646) -> Result<()> {
1647    let guest_connection_ids = session
1648        .db()
1649        .await
1650        .update_diagnostic_summary(&message, session.connection_id)
1651        .await?;
1652
1653    broadcast(
1654        Some(session.connection_id),
1655        guest_connection_ids.iter().copied(),
1656        |connection_id| {
1657            session
1658                .peer
1659                .forward_send(session.connection_id, connection_id, message.clone())
1660        },
1661    );
1662
1663    Ok(())
1664}
1665
1666async fn update_worktree_settings(
1667    message: proto::UpdateWorktreeSettings,
1668    session: Session,
1669) -> Result<()> {
1670    let guest_connection_ids = session
1671        .db()
1672        .await
1673        .update_worktree_settings(&message, session.connection_id)
1674        .await?;
1675
1676    broadcast(
1677        Some(session.connection_id),
1678        guest_connection_ids.iter().copied(),
1679        |connection_id| {
1680            session
1681                .peer
1682                .forward_send(session.connection_id, connection_id, message.clone())
1683        },
1684    );
1685
1686    Ok(())
1687}
1688
1689async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1690    broadcast_project_message(request.project_id, request, session).await
1691}
1692
1693async fn start_language_server(
1694    request: proto::StartLanguageServer,
1695    session: Session,
1696) -> Result<()> {
1697    let guest_connection_ids = session
1698        .db()
1699        .await
1700        .start_language_server(&request, session.connection_id)
1701        .await?;
1702
1703    broadcast(
1704        Some(session.connection_id),
1705        guest_connection_ids.iter().copied(),
1706        |connection_id| {
1707            session
1708                .peer
1709                .forward_send(session.connection_id, connection_id, request.clone())
1710        },
1711    );
1712    Ok(())
1713}
1714
1715async fn update_language_server(
1716    request: proto::UpdateLanguageServer,
1717    session: Session,
1718) -> Result<()> {
1719    session.executor.record_backtrace();
1720    let project_id = ProjectId::from_proto(request.project_id);
1721    let project_connection_ids = session
1722        .db()
1723        .await
1724        .project_connection_ids(project_id, session.connection_id)
1725        .await?;
1726    broadcast(
1727        Some(session.connection_id),
1728        project_connection_ids.iter().copied(),
1729        |connection_id| {
1730            session
1731                .peer
1732                .forward_send(session.connection_id, connection_id, request.clone())
1733        },
1734    );
1735    Ok(())
1736}
1737
1738async fn forward_project_request<T>(
1739    request: T,
1740    response: Response<T>,
1741    session: Session,
1742) -> Result<()>
1743where
1744    T: EntityMessage + RequestMessage,
1745{
1746    session.executor.record_backtrace();
1747    let project_id = ProjectId::from_proto(request.remote_entity_id());
1748    let host_connection_id = {
1749        let collaborators = session
1750            .db()
1751            .await
1752            .project_collaborators(project_id, session.connection_id)
1753            .await?;
1754        collaborators
1755            .iter()
1756            .find(|collaborator| collaborator.is_host)
1757            .ok_or_else(|| anyhow!("host not found"))?
1758            .connection_id
1759    };
1760
1761    let payload = session
1762        .peer
1763        .forward_request(session.connection_id, host_connection_id, request)
1764        .await?;
1765
1766    response.send(payload)?;
1767    Ok(())
1768}
1769
1770async fn create_buffer_for_peer(
1771    request: proto::CreateBufferForPeer,
1772    session: Session,
1773) -> Result<()> {
1774    session.executor.record_backtrace();
1775    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1776    session
1777        .peer
1778        .forward_send(session.connection_id, peer_id.into(), request)?;
1779    Ok(())
1780}
1781
1782async fn update_buffer(
1783    request: proto::UpdateBuffer,
1784    response: Response<proto::UpdateBuffer>,
1785    session: Session,
1786) -> Result<()> {
1787    session.executor.record_backtrace();
1788    let project_id = ProjectId::from_proto(request.project_id);
1789    let mut guest_connection_ids;
1790    let mut host_connection_id = None;
1791    {
1792        let collaborators = session
1793            .db()
1794            .await
1795            .project_collaborators(project_id, session.connection_id)
1796            .await?;
1797        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1798        for collaborator in collaborators.iter() {
1799            if collaborator.is_host {
1800                host_connection_id = Some(collaborator.connection_id);
1801            } else {
1802                guest_connection_ids.push(collaborator.connection_id);
1803            }
1804        }
1805    }
1806    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1807
1808    session.executor.record_backtrace();
1809    broadcast(
1810        Some(session.connection_id),
1811        guest_connection_ids,
1812        |connection_id| {
1813            session
1814                .peer
1815                .forward_send(session.connection_id, connection_id, request.clone())
1816        },
1817    );
1818    if host_connection_id != session.connection_id {
1819        session
1820            .peer
1821            .forward_request(session.connection_id, host_connection_id, request.clone())
1822            .await?;
1823    }
1824
1825    response.send(proto::Ack {})?;
1826    Ok(())
1827}
1828
1829async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1830    let project_id = ProjectId::from_proto(request.project_id);
1831    let project_connection_ids = session
1832        .db()
1833        .await
1834        .project_connection_ids(project_id, session.connection_id)
1835        .await?;
1836
1837    broadcast(
1838        Some(session.connection_id),
1839        project_connection_ids.iter().copied(),
1840        |connection_id| {
1841            session
1842                .peer
1843                .forward_send(session.connection_id, connection_id, request.clone())
1844        },
1845    );
1846    Ok(())
1847}
1848
1849async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1850    let project_id = ProjectId::from_proto(request.project_id);
1851    let project_connection_ids = session
1852        .db()
1853        .await
1854        .project_connection_ids(project_id, session.connection_id)
1855        .await?;
1856    broadcast(
1857        Some(session.connection_id),
1858        project_connection_ids.iter().copied(),
1859        |connection_id| {
1860            session
1861                .peer
1862                .forward_send(session.connection_id, connection_id, request.clone())
1863        },
1864    );
1865    Ok(())
1866}
1867
1868async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1869    broadcast_project_message(request.project_id, request, session).await
1870}
1871
1872async fn broadcast_project_message<T: EnvelopedMessage>(
1873    project_id: u64,
1874    request: T,
1875    session: Session,
1876) -> Result<()> {
1877    let project_id = ProjectId::from_proto(project_id);
1878    let project_connection_ids = session
1879        .db()
1880        .await
1881        .project_connection_ids(project_id, session.connection_id)
1882        .await?;
1883    broadcast(
1884        Some(session.connection_id),
1885        project_connection_ids.iter().copied(),
1886        |connection_id| {
1887            session
1888                .peer
1889                .forward_send(session.connection_id, connection_id, request.clone())
1890        },
1891    );
1892    Ok(())
1893}
1894
1895async fn follow(
1896    request: proto::Follow,
1897    response: Response<proto::Follow>,
1898    session: Session,
1899) -> Result<()> {
1900    let room_id = RoomId::from_proto(request.room_id);
1901    let project_id = request.project_id.map(ProjectId::from_proto);
1902    let leader_id = request
1903        .leader_id
1904        .ok_or_else(|| anyhow!("invalid leader id"))?
1905        .into();
1906    let follower_id = session.connection_id;
1907
1908    session
1909        .db()
1910        .await
1911        .check_room_participants(room_id, leader_id, session.connection_id)
1912        .await?;
1913
1914    let response_payload = session
1915        .peer
1916        .forward_request(session.connection_id, leader_id, request)
1917        .await?;
1918    response.send(response_payload)?;
1919
1920    if let Some(project_id) = project_id {
1921        let room = session
1922            .db()
1923            .await
1924            .follow(room_id, project_id, leader_id, follower_id)
1925            .await?;
1926        room_updated(&room, &session.peer);
1927    }
1928
1929    Ok(())
1930}
1931
1932async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1933    let room_id = RoomId::from_proto(request.room_id);
1934    let project_id = request.project_id.map(ProjectId::from_proto);
1935    let leader_id = request
1936        .leader_id
1937        .ok_or_else(|| anyhow!("invalid leader id"))?
1938        .into();
1939    let follower_id = session.connection_id;
1940
1941    session
1942        .db()
1943        .await
1944        .check_room_participants(room_id, leader_id, session.connection_id)
1945        .await?;
1946
1947    session
1948        .peer
1949        .forward_send(session.connection_id, leader_id, request)?;
1950
1951    if let Some(project_id) = project_id {
1952        let room = session
1953            .db()
1954            .await
1955            .unfollow(room_id, project_id, leader_id, follower_id)
1956            .await?;
1957        room_updated(&room, &session.peer);
1958    }
1959
1960    Ok(())
1961}
1962
1963async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1964    let room_id = RoomId::from_proto(request.room_id);
1965    let database = session.db.lock().await;
1966
1967    let connection_ids = if let Some(project_id) = request.project_id {
1968        let project_id = ProjectId::from_proto(project_id);
1969        database
1970            .project_connection_ids(project_id, session.connection_id)
1971            .await?
1972    } else {
1973        database
1974            .room_connection_ids(room_id, session.connection_id)
1975            .await?
1976    };
1977
1978    // For now, don't send view update messages back to that view's current leader.
1979    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1980        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1981        _ => None,
1982    });
1983
1984    for follower_peer_id in request.follower_ids.iter().copied() {
1985        let follower_connection_id = follower_peer_id.into();
1986        if Some(follower_peer_id) != connection_id_to_omit
1987            && connection_ids.contains(&follower_connection_id)
1988        {
1989            session.peer.forward_send(
1990                session.connection_id,
1991                follower_connection_id,
1992                request.clone(),
1993            )?;
1994        }
1995    }
1996    Ok(())
1997}
1998
1999async fn get_users(
2000    request: proto::GetUsers,
2001    response: Response<proto::GetUsers>,
2002    session: Session,
2003) -> Result<()> {
2004    let user_ids = request
2005        .user_ids
2006        .into_iter()
2007        .map(UserId::from_proto)
2008        .collect();
2009    let users = session
2010        .db()
2011        .await
2012        .get_users_by_ids(user_ids)
2013        .await?
2014        .into_iter()
2015        .map(|user| proto::User {
2016            id: user.id.to_proto(),
2017            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2018            github_login: user.github_login,
2019        })
2020        .collect();
2021    response.send(proto::UsersResponse { users })?;
2022    Ok(())
2023}
2024
2025async fn fuzzy_search_users(
2026    request: proto::FuzzySearchUsers,
2027    response: Response<proto::FuzzySearchUsers>,
2028    session: Session,
2029) -> Result<()> {
2030    let query = request.query;
2031    let users = match query.len() {
2032        0 => vec![],
2033        1 | 2 => session
2034            .db()
2035            .await
2036            .get_user_by_github_login(&query)
2037            .await?
2038            .into_iter()
2039            .collect(),
2040        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2041    };
2042    let users = users
2043        .into_iter()
2044        .filter(|user| user.id != session.user_id)
2045        .map(|user| proto::User {
2046            id: user.id.to_proto(),
2047            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2048            github_login: user.github_login,
2049        })
2050        .collect();
2051    response.send(proto::UsersResponse { users })?;
2052    Ok(())
2053}
2054
2055async fn request_contact(
2056    request: proto::RequestContact,
2057    response: Response<proto::RequestContact>,
2058    session: Session,
2059) -> Result<()> {
2060    let requester_id = session.user_id;
2061    let responder_id = UserId::from_proto(request.responder_id);
2062    if requester_id == responder_id {
2063        return Err(anyhow!("cannot add yourself as a contact"))?;
2064    }
2065
2066    session
2067        .db()
2068        .await
2069        .send_contact_request(requester_id, responder_id)
2070        .await?;
2071
2072    // Update outgoing contact requests of requester
2073    let mut update = proto::UpdateContacts::default();
2074    update.outgoing_requests.push(responder_id.to_proto());
2075    for connection_id in session
2076        .connection_pool()
2077        .await
2078        .user_connection_ids(requester_id)
2079    {
2080        session.peer.send(connection_id, update.clone())?;
2081    }
2082
2083    // Update incoming contact requests of responder
2084    let mut update = proto::UpdateContacts::default();
2085    update
2086        .incoming_requests
2087        .push(proto::IncomingContactRequest {
2088            requester_id: requester_id.to_proto(),
2089            should_notify: true,
2090        });
2091    for connection_id in session
2092        .connection_pool()
2093        .await
2094        .user_connection_ids(responder_id)
2095    {
2096        session.peer.send(connection_id, update.clone())?;
2097    }
2098
2099    response.send(proto::Ack {})?;
2100    Ok(())
2101}
2102
2103async fn respond_to_contact_request(
2104    request: proto::RespondToContactRequest,
2105    response: Response<proto::RespondToContactRequest>,
2106    session: Session,
2107) -> Result<()> {
2108    let responder_id = session.user_id;
2109    let requester_id = UserId::from_proto(request.requester_id);
2110    let db = session.db().await;
2111    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2112        db.dismiss_contact_notification(responder_id, requester_id)
2113            .await?;
2114    } else {
2115        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2116
2117        db.respond_to_contact_request(responder_id, requester_id, accept)
2118            .await?;
2119        let requester_busy = db.is_user_busy(requester_id).await?;
2120        let responder_busy = db.is_user_busy(responder_id).await?;
2121
2122        let pool = session.connection_pool().await;
2123        // Update responder with new contact
2124        let mut update = proto::UpdateContacts::default();
2125        if accept {
2126            update
2127                .contacts
2128                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2129        }
2130        update
2131            .remove_incoming_requests
2132            .push(requester_id.to_proto());
2133        for connection_id in pool.user_connection_ids(responder_id) {
2134            session.peer.send(connection_id, update.clone())?;
2135        }
2136
2137        // Update requester with new contact
2138        let mut update = proto::UpdateContacts::default();
2139        if accept {
2140            update
2141                .contacts
2142                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2143        }
2144        update
2145            .remove_outgoing_requests
2146            .push(responder_id.to_proto());
2147        for connection_id in pool.user_connection_ids(requester_id) {
2148            session.peer.send(connection_id, update.clone())?;
2149        }
2150    }
2151
2152    response.send(proto::Ack {})?;
2153    Ok(())
2154}
2155
2156async fn remove_contact(
2157    request: proto::RemoveContact,
2158    response: Response<proto::RemoveContact>,
2159    session: Session,
2160) -> Result<()> {
2161    let requester_id = session.user_id;
2162    let responder_id = UserId::from_proto(request.user_id);
2163    let db = session.db().await;
2164    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2165
2166    let pool = session.connection_pool().await;
2167    // Update outgoing contact requests of requester
2168    let mut update = proto::UpdateContacts::default();
2169    if contact_accepted {
2170        update.remove_contacts.push(responder_id.to_proto());
2171    } else {
2172        update
2173            .remove_outgoing_requests
2174            .push(responder_id.to_proto());
2175    }
2176    for connection_id in pool.user_connection_ids(requester_id) {
2177        session.peer.send(connection_id, update.clone())?;
2178    }
2179
2180    // Update incoming contact requests of responder
2181    let mut update = proto::UpdateContacts::default();
2182    if contact_accepted {
2183        update.remove_contacts.push(requester_id.to_proto());
2184    } else {
2185        update
2186            .remove_incoming_requests
2187            .push(requester_id.to_proto());
2188    }
2189    for connection_id in pool.user_connection_ids(responder_id) {
2190        session.peer.send(connection_id, update.clone())?;
2191    }
2192
2193    response.send(proto::Ack {})?;
2194    Ok(())
2195}
2196
2197async fn create_channel(
2198    request: proto::CreateChannel,
2199    response: Response<proto::CreateChannel>,
2200    session: Session,
2201) -> Result<()> {
2202    let db = session.db().await;
2203
2204    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2205    let id = db
2206        .create_channel(&request.name, parent_id, session.user_id)
2207        .await?;
2208
2209    response.send(proto::CreateChannelResponse {
2210        channel: Some(proto::Channel {
2211            id: id.to_proto(),
2212            name: request.name,
2213            visibility: proto::ChannelVisibility::Members as i32,
2214            role: proto::ChannelRole::Admin.into(),
2215        }),
2216        parent_id: request.parent_id,
2217    })?;
2218
2219    let Some(parent_id) = parent_id else {
2220        return Ok(());
2221    };
2222
2223    let updates = db
2224        .participants_to_notify_for_channel_change(parent_id, session.user_id)
2225        .await?;
2226
2227    let connection_pool = session.connection_pool().await;
2228    for (user_id, channels) in updates {
2229        let update = build_initial_channels_update(channels, vec![]);
2230        for connection_id in connection_pool.user_connection_ids(user_id) {
2231            if user_id == session.user_id {
2232                continue;
2233            }
2234            session.peer.send(connection_id, update.clone())?;
2235        }
2236    }
2237
2238    Ok(())
2239}
2240
2241async fn delete_channel(
2242    request: proto::DeleteChannel,
2243    response: Response<proto::DeleteChannel>,
2244    session: Session,
2245) -> Result<()> {
2246    let db = session.db().await;
2247
2248    let channel_id = request.channel_id;
2249    let (removed_channels, member_ids) = db
2250        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2251        .await?;
2252    response.send(proto::Ack {})?;
2253
2254    // Notify members of removed channels
2255    let mut update = proto::UpdateChannels::default();
2256    update
2257        .delete_channels
2258        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2259
2260    let connection_pool = session.connection_pool().await;
2261    for member_id in member_ids {
2262        for connection_id in connection_pool.user_connection_ids(member_id) {
2263            session.peer.send(connection_id, update.clone())?;
2264        }
2265    }
2266
2267    Ok(())
2268}
2269
2270async fn invite_channel_member(
2271    request: proto::InviteChannelMember,
2272    response: Response<proto::InviteChannelMember>,
2273    session: Session,
2274) -> Result<()> {
2275    let db = session.db().await;
2276    let channel_id = ChannelId::from_proto(request.channel_id);
2277    let invitee_id = UserId::from_proto(request.user_id);
2278    db.invite_channel_member(
2279        channel_id,
2280        invitee_id,
2281        session.user_id,
2282        request.role().into(),
2283    )
2284    .await?;
2285
2286    let channel = db.get_channel(channel_id, session.user_id).await?;
2287
2288    let mut update = proto::UpdateChannels::default();
2289    update.channel_invitations.push(proto::Channel {
2290        id: channel.id.to_proto(),
2291        visibility: channel.visibility.into(),
2292        name: channel.name,
2293        role: request.role().into(),
2294    });
2295    for connection_id in session
2296        .connection_pool()
2297        .await
2298        .user_connection_ids(invitee_id)
2299    {
2300        session.peer.send(connection_id, update.clone())?;
2301    }
2302
2303    response.send(proto::Ack {})?;
2304    Ok(())
2305}
2306
2307async fn remove_channel_member(
2308    request: proto::RemoveChannelMember,
2309    response: Response<proto::RemoveChannelMember>,
2310    session: Session,
2311) -> Result<()> {
2312    let db = session.db().await;
2313    let channel_id = ChannelId::from_proto(request.channel_id);
2314    let member_id = UserId::from_proto(request.user_id);
2315
2316    db.remove_channel_member(channel_id, member_id, session.user_id)
2317        .await?;
2318
2319    let mut update = proto::UpdateChannels::default();
2320    update.delete_channels.push(channel_id.to_proto());
2321
2322    for connection_id in session
2323        .connection_pool()
2324        .await
2325        .user_connection_ids(member_id)
2326    {
2327        session.peer.send(connection_id, update.clone())?;
2328    }
2329
2330    response.send(proto::Ack {})?;
2331    Ok(())
2332}
2333
2334async fn set_channel_visibility(
2335    request: proto::SetChannelVisibility,
2336    response: Response<proto::SetChannelVisibility>,
2337    session: Session,
2338) -> Result<()> {
2339    let db = session.db().await;
2340    let channel_id = ChannelId::from_proto(request.channel_id);
2341    let visibility = request.visibility().into();
2342
2343    let previous_members = db
2344        .get_channel_participant_details(channel_id, session.user_id)
2345        .await?;
2346
2347    db.set_channel_visibility(channel_id, visibility, session.user_id)
2348        .await?;
2349
2350    let mut updates: HashMap<UserId, ChannelsForUser> = db
2351        .participants_to_notify_for_channel_change(channel_id, session.user_id)
2352        .await?
2353        .into_iter()
2354        .collect();
2355
2356    let mut participants_who_lost_access: HashSet<UserId> = HashSet::default();
2357    match visibility {
2358        ChannelVisibility::Members => {
2359            for member in previous_members {
2360                if ChannelRole::from(member.role()).can_only_see_public_descendants() {
2361                    participants_who_lost_access.insert(UserId::from_proto(member.user_id));
2362                }
2363            }
2364        }
2365        ChannelVisibility::Public => {
2366            if let Some(public_parent_id) = db.public_parent_channel_id(channel_id).await? {
2367                let parent_updates = db
2368                    .participants_to_notify_for_channel_change(public_parent_id, session.user_id)
2369                    .await?;
2370
2371                for (user_id, channels) in parent_updates {
2372                    updates.insert(user_id, channels);
2373                }
2374            }
2375        }
2376    }
2377
2378    let connection_pool = session.connection_pool().await;
2379    for (user_id, channels) in updates {
2380        let update = build_initial_channels_update(channels, vec![]);
2381        for connection_id in connection_pool.user_connection_ids(user_id) {
2382            session.peer.send(connection_id, update.clone())?;
2383        }
2384    }
2385    for user_id in participants_who_lost_access {
2386        let update = proto::UpdateChannels {
2387            delete_channels: vec![channel_id.to_proto()],
2388            ..Default::default()
2389        };
2390        for connection_id in connection_pool.user_connection_ids(user_id) {
2391            session.peer.send(connection_id, update.clone())?;
2392        }
2393    }
2394
2395    response.send(proto::Ack {})?;
2396    Ok(())
2397}
2398
2399async fn set_channel_member_role(
2400    request: proto::SetChannelMemberRole,
2401    response: Response<proto::SetChannelMemberRole>,
2402    session: Session,
2403) -> Result<()> {
2404    let db = session.db().await;
2405    let channel_id = ChannelId::from_proto(request.channel_id);
2406    let member_id = UserId::from_proto(request.user_id);
2407    let channel_member = db
2408        .set_channel_member_role(
2409            channel_id,
2410            session.user_id,
2411            member_id,
2412            request.role().into(),
2413        )
2414        .await?;
2415
2416    let mut update = proto::UpdateChannels::default();
2417    if channel_member.accepted {
2418        let channels = db.get_channel_for_user(channel_id, member_id).await?;
2419        update = build_initial_channels_update(channels, vec![]);
2420    } else {
2421        let channel = db.get_channel(channel_id, session.user_id).await?;
2422        update.channel_invitations.push(proto::Channel {
2423            id: channel_id.to_proto(),
2424            visibility: channel.visibility.into(),
2425            name: channel.name,
2426            role: request.role().into(),
2427        });
2428    }
2429
2430    for connection_id in session
2431        .connection_pool()
2432        .await
2433        .user_connection_ids(member_id)
2434    {
2435        session.peer.send(connection_id, update.clone())?;
2436    }
2437
2438    response.send(proto::Ack {})?;
2439    Ok(())
2440}
2441
2442async fn rename_channel(
2443    request: proto::RenameChannel,
2444    response: Response<proto::RenameChannel>,
2445    session: Session,
2446) -> Result<()> {
2447    let db = session.db().await;
2448    let channel_id = ChannelId::from_proto(request.channel_id);
2449    let channel = db
2450        .rename_channel(channel_id, session.user_id, &request.name)
2451        .await?;
2452
2453    response.send(proto::RenameChannelResponse {
2454        channel: Some(proto::Channel {
2455            id: channel.id.to_proto(),
2456            name: channel.name.clone(),
2457            visibility: channel.visibility.into(),
2458            role: proto::ChannelRole::Admin.into(),
2459        }),
2460    })?;
2461
2462    let members = db
2463        .get_channel_participant_details(channel_id, session.user_id)
2464        .await?;
2465
2466    let connection_pool = session.connection_pool().await;
2467    for member in members {
2468        for connection_id in connection_pool.user_connection_ids(UserId::from_proto(member.user_id))
2469        {
2470            let mut update = proto::UpdateChannels::default();
2471            update.channels.push(proto::Channel {
2472                id: channel.id.to_proto(),
2473                name: channel.name.clone(),
2474                visibility: channel.visibility.into(),
2475                role: member.role.into(),
2476            });
2477
2478            session.peer.send(connection_id, update.clone())?;
2479        }
2480    }
2481
2482    Ok(())
2483}
2484
2485// TODO: Implement in terms of symlinks
2486// Current behavior of this is more like 'Move root channel'
2487async fn link_channel(
2488    request: proto::LinkChannel,
2489    response: Response<proto::LinkChannel>,
2490    session: Session,
2491) -> Result<()> {
2492    let db = session.db().await;
2493    let channel_id = ChannelId::from_proto(request.channel_id);
2494    let to = ChannelId::from_proto(request.to);
2495
2496    // TODO: Remove this restriction once we have symlinks
2497    db.assert_root_channel(channel_id).await?;
2498
2499    db.link_channel(session.user_id, channel_id, to).await?;
2500
2501    let member_updates = db
2502        .participants_to_notify_for_channel_change(to, session.user_id)
2503        .await?;
2504
2505    dbg!(&member_updates);
2506
2507    let connection_pool = session.connection_pool().await;
2508
2509    for (member_id, channels) in member_updates {
2510        let update = build_initial_channels_update(channels, vec![]);
2511        for connection_id in connection_pool.user_connection_ids(member_id) {
2512            session.peer.send(connection_id, update.clone())?;
2513        }
2514    }
2515
2516    response.send(Ack {})?;
2517
2518    Ok(())
2519}
2520
2521// TODO: Implement in terms of symlinks
2522async fn unlink_channel(
2523    _request: proto::UnlinkChannel,
2524    _response: Response<proto::UnlinkChannel>,
2525    _session: Session,
2526) -> Result<()> {
2527    Err(anyhow!("unimplemented").into())
2528}
2529
2530async fn move_channel(
2531    request: proto::MoveChannel,
2532    response: Response<proto::MoveChannel>,
2533    session: Session,
2534) -> Result<()> {
2535    let db = session.db().await;
2536    let channel_id = ChannelId::from_proto(request.channel_id);
2537    let from_parent = ChannelId::from_proto(request.from);
2538    let to = ChannelId::from_proto(request.to);
2539
2540    let previous_participants = db
2541        .get_channel_participant_details(channel_id, session.user_id)
2542        .await?;
2543
2544    debug_assert_eq!(db.parent_channel_id(channel_id).await?, Some(from_parent));
2545
2546    let channels_to_send = db
2547        .move_channel(session.user_id, channel_id, from_parent, to)
2548        .await?;
2549
2550    if channels_to_send.is_empty() {
2551        response.send(Ack {})?;
2552        return Ok(());
2553    }
2554
2555    let updates = db
2556        .participants_to_notify_for_channel_change(to, session.user_id)
2557        .await?;
2558
2559    let mut participants_who_lost_access: HashSet<UserId> = HashSet::default();
2560    let mut channels_to_delete = db.get_channel_descendant_ids(channel_id).await?;
2561    channels_to_delete.insert(channel_id);
2562
2563    for previous_participant in previous_participants.iter() {
2564        let user_id = UserId::from_proto(previous_participant.user_id);
2565        if previous_participant.kind() == proto::channel_member::Kind::AncestorMember {
2566            participants_who_lost_access.insert(user_id);
2567        }
2568    }
2569
2570    let connection_pool = session.connection_pool().await;
2571    for (user_id, channels) in updates {
2572        let mut update = build_initial_channels_update(channels, vec![]);
2573        update.delete_channels = channels_to_delete
2574            .iter()
2575            .map(|channel_id| channel_id.to_proto())
2576            .collect();
2577        participants_who_lost_access.remove(&user_id);
2578        for connection_id in connection_pool.user_connection_ids(user_id) {
2579            session.peer.send(connection_id, update.clone())?;
2580        }
2581    }
2582
2583    for user_id in participants_who_lost_access {
2584        let update = proto::UpdateChannels {
2585            delete_channels: channels_to_delete
2586                .iter()
2587                .map(|channel_id| channel_id.to_proto())
2588                .collect(),
2589            ..Default::default()
2590        };
2591        for connection_id in connection_pool.user_connection_ids(user_id) {
2592            session.peer.send(connection_id, update.clone())?;
2593        }
2594    }
2595
2596    response.send(Ack {})?;
2597
2598    Ok(())
2599}
2600
2601async fn get_channel_members(
2602    request: proto::GetChannelMembers,
2603    response: Response<proto::GetChannelMembers>,
2604    session: Session,
2605) -> Result<()> {
2606    let db = session.db().await;
2607    let channel_id = ChannelId::from_proto(request.channel_id);
2608    let members = db
2609        .get_channel_participant_details(channel_id, session.user_id)
2610        .await?;
2611    response.send(proto::GetChannelMembersResponse { members })?;
2612    Ok(())
2613}
2614
2615async fn respond_to_channel_invite(
2616    request: proto::RespondToChannelInvite,
2617    response: Response<proto::RespondToChannelInvite>,
2618    session: Session,
2619) -> Result<()> {
2620    let db = session.db().await;
2621    let channel_id = ChannelId::from_proto(request.channel_id);
2622    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2623        .await?;
2624
2625    if request.accept {
2626        channel_membership_updated(db, channel_id, &session).await?;
2627    } else {
2628        let mut update = proto::UpdateChannels::default();
2629        update
2630            .remove_channel_invitations
2631            .push(channel_id.to_proto());
2632        session.peer.send(session.connection_id, update)?;
2633    }
2634    response.send(proto::Ack {})?;
2635
2636    Ok(())
2637}
2638
2639async fn channel_membership_updated(
2640    db: tokio::sync::MutexGuard<'_, DbHandle>,
2641    channel_id: ChannelId,
2642    session: &Session,
2643) -> Result<(), crate::Error> {
2644    let mut update = proto::UpdateChannels::default();
2645    update
2646        .remove_channel_invitations
2647        .push(channel_id.to_proto());
2648
2649    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2650    update.channels.extend(
2651        result
2652            .channels
2653            .channels
2654            .into_iter()
2655            .map(|channel| proto::Channel {
2656                id: channel.id.to_proto(),
2657                visibility: channel.visibility.into(),
2658                role: channel.role.into(),
2659                name: channel.name,
2660            }),
2661    );
2662    update.unseen_channel_messages = result.channel_messages;
2663    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2664    update.insert_edge = result.channels.edges;
2665    update
2666        .channel_participants
2667        .extend(
2668            result
2669                .channel_participants
2670                .into_iter()
2671                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2672                    channel_id: channel_id.to_proto(),
2673                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2674                }),
2675        );
2676    session.peer.send(session.connection_id, update)?;
2677    Ok(())
2678}
2679
2680async fn join_channel(
2681    request: proto::JoinChannel,
2682    response: Response<proto::JoinChannel>,
2683    session: Session,
2684) -> Result<()> {
2685    let channel_id = ChannelId::from_proto(request.channel_id);
2686    join_channel_internal(channel_id, Box::new(response), session).await
2687}
2688
2689trait JoinChannelInternalResponse {
2690    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2691}
2692impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2693    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2694        Response::<proto::JoinChannel>::send(self, result)
2695    }
2696}
2697impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2698    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2699        Response::<proto::JoinRoom>::send(self, result)
2700    }
2701}
2702
/// Shared implementation for joining the call room of channel `channel_id`.
///
/// It leaves any room the session is currently in and joins the channel's room
/// in the database. It then replies through `response`, including LiveKit
/// connection info when a LiveKit client is configured and a token can be
/// minted. Finally it notifies room participants, the channel's members and the
/// user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard acquired inside this block is released by the end of
    // the block (moved into `channel_membership_updated` or dropped), before
    // the connection pool is locked for `channel_updated` below.
    let joined_room = {
        // Called before `session.db()` is locked here, because
        // `leave_room_for_session` acquires the database guard itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // A token-minting failure is logged by `trace_err` and yields `None`
        // (no LiveKit info) instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // NOTE(review): `joined_channel` appears to be `Some` when joining also
        // changed the user's channel membership; confirm against
        // `db.join_channel`. The guard is handed over here and released inside.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2761
2762async fn join_channel_buffer(
2763    request: proto::JoinChannelBuffer,
2764    response: Response<proto::JoinChannelBuffer>,
2765    session: Session,
2766) -> Result<()> {
2767    let db = session.db().await;
2768    let channel_id = ChannelId::from_proto(request.channel_id);
2769
2770    let open_response = db
2771        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2772        .await?;
2773
2774    let collaborators = open_response.collaborators.clone();
2775    response.send(open_response)?;
2776
2777    let update = UpdateChannelBufferCollaborators {
2778        channel_id: channel_id.to_proto(),
2779        collaborators: collaborators.clone(),
2780    };
2781    channel_buffer_updated(
2782        session.connection_id,
2783        collaborators
2784            .iter()
2785            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2786        &update,
2787        &session.peer,
2788    );
2789
2790    Ok(())
2791}
2792
2793async fn update_channel_buffer(
2794    request: proto::UpdateChannelBuffer,
2795    session: Session,
2796) -> Result<()> {
2797    let db = session.db().await;
2798    let channel_id = ChannelId::from_proto(request.channel_id);
2799
2800    let (collaborators, non_collaborators, epoch, version) = db
2801        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2802        .await?;
2803
2804    channel_buffer_updated(
2805        session.connection_id,
2806        collaborators,
2807        &proto::UpdateChannelBuffer {
2808            channel_id: channel_id.to_proto(),
2809            operations: request.operations,
2810        },
2811        &session.peer,
2812    );
2813
2814    let pool = &*session.connection_pool().await;
2815
2816    broadcast(
2817        None,
2818        non_collaborators
2819            .iter()
2820            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2821        |peer_id| {
2822            session.peer.send(
2823                peer_id.into(),
2824                proto::UpdateChannels {
2825                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2826                        channel_id: channel_id.to_proto(),
2827                        epoch: epoch as u64,
2828                        version: version.clone(),
2829                    }],
2830                    ..Default::default()
2831                },
2832            )
2833        },
2834    );
2835
2836    Ok(())
2837}
2838
2839async fn rejoin_channel_buffers(
2840    request: proto::RejoinChannelBuffers,
2841    response: Response<proto::RejoinChannelBuffers>,
2842    session: Session,
2843) -> Result<()> {
2844    let db = session.db().await;
2845    let buffers = db
2846        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2847        .await?;
2848
2849    for rejoined_buffer in &buffers {
2850        let collaborators_to_notify = rejoined_buffer
2851            .buffer
2852            .collaborators
2853            .iter()
2854            .filter_map(|c| Some(c.peer_id?.into()));
2855        channel_buffer_updated(
2856            session.connection_id,
2857            collaborators_to_notify,
2858            &proto::UpdateChannelBufferCollaborators {
2859                channel_id: rejoined_buffer.buffer.channel_id,
2860                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2861            },
2862            &session.peer,
2863        );
2864    }
2865
2866    response.send(proto::RejoinChannelBuffersResponse {
2867        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2868    })?;
2869
2870    Ok(())
2871}
2872
2873async fn leave_channel_buffer(
2874    request: proto::LeaveChannelBuffer,
2875    response: Response<proto::LeaveChannelBuffer>,
2876    session: Session,
2877) -> Result<()> {
2878    let db = session.db().await;
2879    let channel_id = ChannelId::from_proto(request.channel_id);
2880
2881    let left_buffer = db
2882        .leave_channel_buffer(channel_id, session.connection_id)
2883        .await?;
2884
2885    response.send(Ack {})?;
2886
2887    channel_buffer_updated(
2888        session.connection_id,
2889        left_buffer.connections,
2890        &proto::UpdateChannelBufferCollaborators {
2891            channel_id: channel_id.to_proto(),
2892            collaborators: left_buffer.collaborators,
2893        },
2894        &session.peer,
2895    );
2896
2897    Ok(())
2898}
2899
2900fn channel_buffer_updated<T: EnvelopedMessage>(
2901    sender_id: ConnectionId,
2902    collaborators: impl IntoIterator<Item = ConnectionId>,
2903    message: &T,
2904    peer: &Peer,
2905) {
2906    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2907        peer.send(peer_id.into(), message.clone())
2908    });
2909}
2910
2911async fn send_channel_message(
2912    request: proto::SendChannelMessage,
2913    response: Response<proto::SendChannelMessage>,
2914    session: Session,
2915) -> Result<()> {
2916    // Validate the message body.
2917    let body = request.body.trim().to_string();
2918    if body.len() > MAX_MESSAGE_LEN {
2919        return Err(anyhow!("message is too long"))?;
2920    }
2921    if body.is_empty() {
2922        return Err(anyhow!("message can't be blank"))?;
2923    }
2924
2925    let timestamp = OffsetDateTime::now_utc();
2926    let nonce = request
2927        .nonce
2928        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2929
2930    let channel_id = ChannelId::from_proto(request.channel_id);
2931    let (message_id, connection_ids, non_participants) = session
2932        .db()
2933        .await
2934        .create_channel_message(
2935            channel_id,
2936            session.user_id,
2937            &body,
2938            timestamp,
2939            nonce.clone().into(),
2940        )
2941        .await?;
2942    let message = proto::ChannelMessage {
2943        sender_id: session.user_id.to_proto(),
2944        id: message_id.to_proto(),
2945        body,
2946        timestamp: timestamp.unix_timestamp() as u64,
2947        nonce: Some(nonce),
2948    };
2949    broadcast(Some(session.connection_id), connection_ids, |connection| {
2950        session.peer.send(
2951            connection,
2952            proto::ChannelMessageSent {
2953                channel_id: channel_id.to_proto(),
2954                message: Some(message.clone()),
2955            },
2956        )
2957    });
2958    response.send(proto::SendChannelMessageResponse {
2959        message: Some(message),
2960    })?;
2961
2962    let pool = &*session.connection_pool().await;
2963    broadcast(
2964        None,
2965        non_participants
2966            .iter()
2967            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2968        |peer_id| {
2969            session.peer.send(
2970                peer_id.into(),
2971                proto::UpdateChannels {
2972                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2973                        channel_id: channel_id.to_proto(),
2974                        message_id: message_id.to_proto(),
2975                    }],
2976                    ..Default::default()
2977                },
2978            )
2979        },
2980    );
2981
2982    Ok(())
2983}
2984
2985async fn remove_channel_message(
2986    request: proto::RemoveChannelMessage,
2987    response: Response<proto::RemoveChannelMessage>,
2988    session: Session,
2989) -> Result<()> {
2990    let channel_id = ChannelId::from_proto(request.channel_id);
2991    let message_id = MessageId::from_proto(request.message_id);
2992    let connection_ids = session
2993        .db()
2994        .await
2995        .remove_channel_message(channel_id, message_id, session.user_id)
2996        .await?;
2997    broadcast(Some(session.connection_id), connection_ids, |connection| {
2998        session.peer.send(connection, request.clone())
2999    });
3000    response.send(proto::Ack {})?;
3001    Ok(())
3002}
3003
3004async fn acknowledge_channel_message(
3005    request: proto::AckChannelMessage,
3006    session: Session,
3007) -> Result<()> {
3008    let channel_id = ChannelId::from_proto(request.channel_id);
3009    let message_id = MessageId::from_proto(request.message_id);
3010    session
3011        .db()
3012        .await
3013        .observe_channel_message(channel_id, session.user_id, message_id)
3014        .await?;
3015    Ok(())
3016}
3017
3018async fn acknowledge_buffer_version(
3019    request: proto::AckBufferOperation,
3020    session: Session,
3021) -> Result<()> {
3022    let buffer_id = BufferId::from_proto(request.buffer_id);
3023    session
3024        .db()
3025        .await
3026        .observe_buffer_version(
3027            buffer_id,
3028            session.user_id,
3029            request.epoch as i32,
3030            &request.version,
3031        )
3032        .await?;
3033    Ok(())
3034}
3035
3036async fn join_channel_chat(
3037    request: proto::JoinChannelChat,
3038    response: Response<proto::JoinChannelChat>,
3039    session: Session,
3040) -> Result<()> {
3041    let channel_id = ChannelId::from_proto(request.channel_id);
3042
3043    let db = session.db().await;
3044    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3045        .await?;
3046    let messages = db
3047        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3048        .await?;
3049    response.send(proto::JoinChannelChatResponse {
3050        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3051        messages,
3052    })?;
3053    Ok(())
3054}
3055
3056async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    session
3059        .db()
3060        .await
3061        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3062        .await?;
3063    Ok(())
3064}
3065
3066async fn get_channel_messages(
3067    request: proto::GetChannelMessages,
3068    response: Response<proto::GetChannelMessages>,
3069    session: Session,
3070) -> Result<()> {
3071    let channel_id = ChannelId::from_proto(request.channel_id);
3072    let messages = session
3073        .db()
3074        .await
3075        .get_channel_messages(
3076            channel_id,
3077            session.user_id,
3078            MESSAGE_COUNT_PER_PAGE,
3079            Some(MessageId::from_proto(request.before_message_id)),
3080        )
3081        .await?;
3082    response.send(proto::GetChannelMessagesResponse {
3083        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3084        messages,
3085    })?;
3086    Ok(())
3087}
3088
3089async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3090    let project_id = ProjectId::from_proto(request.project_id);
3091    let project_connection_ids = session
3092        .db()
3093        .await
3094        .project_connection_ids(project_id, session.connection_id)
3095        .await?;
3096    broadcast(
3097        Some(session.connection_id),
3098        project_connection_ids.iter().copied(),
3099        |connection_id| {
3100            session
3101                .peer
3102                .forward_send(session.connection_id, connection_id, request.clone())
3103        },
3104    );
3105    Ok(())
3106}
3107
3108async fn get_private_user_info(
3109    _request: proto::GetPrivateUserInfo,
3110    response: Response<proto::GetPrivateUserInfo>,
3111    session: Session,
3112) -> Result<()> {
3113    let db = session.db().await;
3114
3115    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3116    let user = db
3117        .get_user_by_id(session.user_id)
3118        .await?
3119        .ok_or_else(|| anyhow!("user not found"))?;
3120    let flags = db.get_user_flags(session.user_id).await?;
3121
3122    response.send(proto::GetPrivateUserInfoResponse {
3123        metrics_id,
3124        staff: user.admin,
3125        flags,
3126    })?;
3127    Ok(())
3128}
3129
3130fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3131    match message {
3132        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3133        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3134        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3135        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3136        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3137            code: frame.code.into(),
3138            reason: frame.reason,
3139        })),
3140    }
3141}
3142
3143fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3144    match message {
3145        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3146        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3147        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3148        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3149        AxumMessage::Close(frame) => {
3150            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3151                code: frame.code.into(),
3152                reason: frame.reason,
3153            }))
3154        }
3155    }
3156}
3157
3158fn build_initial_channels_update(
3159    channels: ChannelsForUser,
3160    channel_invites: Vec<db::Channel>,
3161) -> proto::UpdateChannels {
3162    let mut update = proto::UpdateChannels::default();
3163
3164    for channel in channels.channels.channels {
3165        update.channels.push(proto::Channel {
3166            id: channel.id.to_proto(),
3167            name: channel.name,
3168            visibility: channel.visibility.into(),
3169            role: channel.role.into(),
3170        });
3171    }
3172
3173    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3174    update.unseen_channel_messages = channels.channel_messages;
3175    update.insert_edge = channels.channels.edges;
3176
3177    for (channel_id, participants) in channels.channel_participants {
3178        update
3179            .channel_participants
3180            .push(proto::ChannelParticipants {
3181                channel_id: channel_id.to_proto(),
3182                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3183            });
3184    }
3185
3186    for channel in channel_invites {
3187        update.channel_invitations.push(proto::Channel {
3188            id: channel.id.to_proto(),
3189            name: channel.name,
3190            visibility: channel.visibility.into(),
3191            role: channel.role.into(),
3192        });
3193    }
3194
3195    update
3196}
3197
3198fn build_initial_contacts_update(
3199    contacts: Vec<db::Contact>,
3200    pool: &ConnectionPool,
3201) -> proto::UpdateContacts {
3202    let mut update = proto::UpdateContacts::default();
3203
3204    for contact in contacts {
3205        match contact {
3206            db::Contact::Accepted {
3207                user_id,
3208                should_notify,
3209                busy,
3210            } => {
3211                update
3212                    .contacts
3213                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3214            }
3215            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3216            db::Contact::Incoming {
3217                user_id,
3218                should_notify,
3219            } => update
3220                .incoming_requests
3221                .push(proto::IncomingContactRequest {
3222                    requester_id: user_id.to_proto(),
3223                    should_notify,
3224                }),
3225        }
3226    }
3227
3228    update
3229}
3230
3231fn contact_for_user(
3232    user_id: UserId,
3233    should_notify: bool,
3234    busy: bool,
3235    pool: &ConnectionPool,
3236) -> proto::Contact {
3237    proto::Contact {
3238        user_id: user_id.to_proto(),
3239        online: pool.is_user_online(user_id),
3240        busy,
3241        should_notify,
3242    }
3243}
3244
3245fn room_updated(room: &proto::Room, peer: &Peer) {
3246    broadcast(
3247        None,
3248        room.participants
3249            .iter()
3250            .filter_map(|participant| Some(participant.peer_id?.into())),
3251        |peer_id| {
3252            peer.send(
3253                peer_id.into(),
3254                proto::RoomUpdated {
3255                    room: Some(room.clone()),
3256                },
3257            )
3258        },
3259    );
3260}
3261
3262fn channel_updated(
3263    channel_id: ChannelId,
3264    room: &proto::Room,
3265    channel_members: &[UserId],
3266    peer: &Peer,
3267    pool: &ConnectionPool,
3268) {
3269    let participants = room
3270        .participants
3271        .iter()
3272        .map(|p| p.user_id)
3273        .collect::<Vec<_>>();
3274
3275    broadcast(
3276        None,
3277        channel_members
3278            .iter()
3279            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3280        |peer_id| {
3281            peer.send(
3282                peer_id.into(),
3283                proto::UpdateChannels {
3284                    channel_participants: vec![proto::ChannelParticipants {
3285                        channel_id: channel_id.to_proto(),
3286                        participant_user_ids: participants.clone(),
3287                    }],
3288                    ..Default::default()
3289                },
3290            )
3291        },
3292    );
3293}
3294
3295async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3296    let db = session.db().await;
3297
3298    let contacts = db.get_contacts(user_id).await?;
3299    let busy = db.is_user_busy(user_id).await?;
3300
3301    let pool = session.connection_pool().await;
3302    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3303    for contact in contacts {
3304        if let db::Contact::Accepted {
3305            user_id: contact_user_id,
3306            ..
3307        } = contact
3308        {
3309            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3310                session
3311                    .peer
3312                    .send(
3313                        contact_conn_id,
3314                        proto::UpdateContacts {
3315                            contacts: vec![updated_contact.clone()],
3316                            remove_contacts: Default::default(),
3317                            incoming_requests: Default::default(),
3318                            remove_incoming_requests: Default::default(),
3319                            outgoing_requests: Default::default(),
3320                            remove_outgoing_requests: Default::default(),
3321                        },
3322                    )
3323                    .trace_err();
3324            }
3325        }
3326    }
3327    Ok(())
3328}
3329
/// Removes this connection from the room it is in and performs the follow-up
/// notifications. If the database reports no room, this returns `Ok(())`
/// immediately.
///
/// The follow-ups are: project collaborators (`project_left`), room
/// participants (`room_updated`), channel members when the room belongs to a
/// channel (`channel_updated`), and users whose pending calls were canceled
/// (`CallCanceled`). It also refreshes contacts for the leaving user and every
/// canceled callee, and removes the participant from LiveKit, deleting the
/// LiveKit room when the database marks the room as deleted.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned only in the `Some` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the database guard from `session.db().await` is a
    // temporary in the `if let` scrutinee. Before Rust 2024 it lives until the
    // end of the whole `if let`/`else`, so the db lock is held while
    // `project_left` and `room_updated` run below. Confirm this is intended.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out before the room itself is moved, so `room.live_kit_room`
        // below is left at its default value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before
    // `update_user_contacts`, which acquires the db and pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit failures are logged via `trace_err` and do not fail the leave.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3406
3407async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3408    let left_channel_buffers = session
3409        .db()
3410        .await
3411        .leave_channel_buffers(session.connection_id)
3412        .await?;
3413
3414    for left_buffer in left_channel_buffers {
3415        channel_buffer_updated(
3416            session.connection_id,
3417            left_buffer.connections,
3418            &proto::UpdateChannelBufferCollaborators {
3419                channel_id: left_buffer.channel_id.to_proto(),
3420                collaborators: left_buffer.collaborators,
3421            },
3422            &session.peer,
3423        );
3424    }
3425
3426    Ok(())
3427}
3428
3429fn project_left(project: &db::LeftProject, session: &Session) {
3430    for connection_id in &project.connection_ids {
3431        if project.host_user_id == session.user_id {
3432            session
3433                .peer
3434                .send(
3435                    *connection_id,
3436                    proto::UnshareProject {
3437                        project_id: project.id.to_proto(),
3438                    },
3439                )
3440                .trace_err();
3441        } else {
3442            session
3443                .peer
3444                .send(
3445                    *connection_id,
3446                    proto::RemoveProjectCollaborator {
3447                        project_id: project.id.to_proto(),
3448                        peer_id: Some(session.connection_id.into()),
3449                    },
3450                )
3451                .trace_err();
3452        }
3453    }
3454}
3455
3456pub trait ResultExt {
3457    type Ok;
3458
3459    fn trace_err(self) -> Option<Self::Ok>;
3460}
3461
3462impl<T, E> ResultExt for Result<T, E>
3463where
3464    E: std::fmt::Debug,
3465{
3466    type Ok = T;
3467
3468    fn trace_err(self) -> Option<T> {
3469        match self {
3470            Ok(value) => Some(value),
3471            Err(error) => {
3472                tracing::error!("{:?}", error);
3473                None
3474            }
3475        }
3476    }
3477}