rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId,
   7        ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73
// Prometheus gauges, created lazily on first access through `lazy_static!`.
// `unwrap` panics if `register_int_gauge!` returns an error.
// NOTE(review): neither gauge is read or updated in this chunk; presumably they are
// incremented/decremented elsewhere in this file.
lazy_static! {
    /// Gauge "connections": number of connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge "shared_projects": number of open projects with one or more guests.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  83
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
/// Per-connection context passed to every message handler.
///
/// `handle_connection` builds one per connection and clones it for each dispatched message,
/// so the fields are ids and shared handles (mostly `Arc`s).
#[derive(Clone)]
struct Session {
    /// User that opened this connection.
    user_id: UserId,
    /// This connection's id within the [`Peer`].
    connection_id: ConnectionId,
    /// Database handle behind an async mutex that `handle_connection` creates once per
    /// connection; handlers acquire it via [`Session::db`], so database access is serialized
    /// among this connection's handlers.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer shared with the [`Server`], used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the [`Server`]; acquire via [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional LiveKit API client, cloned from `AppState`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
/// The collaboration RPC server.
///
/// Owns the [`Peer`], the pool of live connections, and the table of message handlers.
pub struct Server {
    /// This server's id; overwritten by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections, shared with every connection's `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each `handle_connection` loop subscribes and returns when it
    /// fires.
    teardown: watch::Sender<()>,
}
 161
/// Lock guard for the shared [`ConnectionPool`] that is deliberately `!Send`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Handler futures must be `Send`
/// (see the `Fut` bound in `add_handler`), so holding this guard across an `.await` inside a
/// handler is rejected at compile time.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 166
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target. NOTE(review): the `Deref` impl for
    // `ConnectionPoolGuard` is not in this chunk; presumably it targets `ConnectionPool`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
    /// Creates a server with the given `id` and registers every RPC message handler.
    ///
    /// Request handlers (`add_request_handler`) receive a `Response` and must reply.
    /// Message handlers (`add_message_handler`) receive only the payload. Registering two
    /// handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `id.0` is narrowed with `as u32`; confirm server ids fit in 32 bits.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection gets a receiver via `subscribe()`.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Requests handled generically by `forward_project_request` (defined outside
            // this chunk; presumably relayed to the project's host).
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 286
    /// Spawns a detached background task that cleans up after stale servers, then returns.
    ///
    /// The task first waits [`CLEANUP_TIMEOUT`]. It then asks the database
    /// (`stale_server_resource_ids`, scoped by `zed_environment` and this server's id) for
    /// channel and room ids to clean up, and processes them as follows:
    /// - Each channel buffer: stale collaborators are cleared and the refreshed collaborator
    ///   list is sent to the remaining connections.
    /// - Each room: stale participants are cleared and the updated room is broadcast, plus
    ///   the channel if the room belongs to one. Users whose calls were canceled are
    ///   notified, and refreshed contact/busy state is pushed to affected users' accepted
    ///   contacts. The LiveKit room is deleted if no participants remain.
    ///
    /// Finally the task calls `delete_stale_servers`. Errors inside the task are logged via
    /// `trace_err` and skipped rather than aborting the cleanup.
    ///
    /// Always returns `Ok(())`; the cleanup outcome is not reported to the caller.
    pub async fn start(&self) -> Result<()> {
        // Copy or clone everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These remain at their defaults if refreshing the room fails, in
                        // which case no calls are canceled, no contacts are updated and the
                        // LiveKit room is not deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                // The pool guard is a temporary released at the end of this
                                // statement.
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were canceled
                            // get their contact entries refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the
                        // `.await`s that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user unless both queries succeeded; the pool is locked
                            // only after the database calls have completed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the refreshed entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a client is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 438
 439    pub fn teardown(&self) {
 440        self.peer.teardown();
 441        self.connection_pool.lock().reset();
 442        let _ = self.teardown.send(());
 443    }
 444
 445    #[cfg(test)]
 446    pub fn reset(&self, id: ServerId) {
 447        self.teardown();
 448        *self.id.lock() = id;
 449        self.peer.reset(id.0 as u32);
 450    }
 451
 452    #[cfg(test)]
 453    pub fn id(&self) -> ServerId {
 454        *self.id.lock()
 455    }
 456
 457    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 458    where
 459        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 460        Fut: 'static + Send + Future<Output = Result<()>>,
 461        M: EnvelopedMessage,
 462    {
 463        let prev_handler = self.handlers.insert(
 464            TypeId::of::<M>(),
 465            Box::new(move |envelope, session| {
 466                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 467                let span = info_span!(
 468                    "handle message",
 469                    payload_type = envelope.payload_type_name()
 470                );
 471                span.in_scope(|| {
 472                    tracing::info!(
 473                        payload_type = envelope.payload_type_name(),
 474                        "message received"
 475                    );
 476                });
 477                let start_time = Instant::now();
 478                let future = (handler)(*envelope, session);
 479                async move {
 480                    let result = future.await;
 481                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 482                    match result {
 483                        Err(error) => {
 484                            tracing::error!(%error, ?duration_ms, "error handling message")
 485                        }
 486                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 487                    }
 488                }
 489                .instrument(span)
 490                .boxed()
 491            }),
 492        );
 493        if prev_handler.is_some() {
 494            panic!("registered a handler for the same message twice");
 495        }
 496        self
 497    }
 498
 499    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 500    where
 501        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 502        Fut: 'static + Send + Future<Output = Result<()>>,
 503        M: EnvelopedMessage,
 504    {
 505        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 506        self
 507    }
 508
 509    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 512        Fut: Send + Future<Output = Result<()>>,
 513        M: RequestMessage,
 514    {
 515        let handler = Arc::new(handler);
 516        self.add_handler(move |envelope, session| {
 517            let receipt = envelope.receipt();
 518            let handler = handler.clone();
 519            async move {
 520                let peer = session.peer.clone();
 521                let responded = Arc::new(AtomicBool::default());
 522                let response = Response {
 523                    peer: peer.clone(),
 524                    responded: responded.clone(),
 525                    receipt,
 526                };
 527                match (handler)(envelope.payload, response, session).await {
 528                    Ok(()) => {
 529                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 530                            Ok(())
 531                        } else {
 532                            Err(anyhow!("handler did not send a response"))?
 533                        }
 534                    }
 535                    Err(error) => {
 536                        peer.respond_with_error(
 537                            receipt,
 538                            proto::Error {
 539                                message: error.to_string(),
 540                            },
 541                        )?;
 542                        Err(error)
 543                    }
 544                }
 545            }
 546        })
 547    }
 548
    /// Drives a single client connection from handshake until it closes.
    ///
    /// Setup, in order:
    /// 1. Registers `connection` with the [`Peer`] and sends `Hello`.
    /// 2. Reports the new `ConnectionId` through `send_connection_id`, if given.
    /// 3. On the user's first connection, sends `ShowContacts` and records that the user has
    ///    connected once.
    /// 4. Adds the connection to the pool and sends the initial contacts and channels state,
    ///    plus any pending incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers until the connection
    /// closes, its I/O fails, or the server is torn down. On close or I/O failure it calls
    /// `connection_lost` to sign the user out. On teardown it returns without doing so.
    ///
    /// NOTE(review): every `?` after `pool.add_connection` and before the message loop
    /// (the initial sends, `incoming_call_for_user`, `update_user_contacts`) returns without
    /// reaching `connection_lost`. After such a failure the connection may therefore stay in
    /// the pool — confirm whether it is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future starts so a later `teardown()` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may have gone away; that is not an error for this connection.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch the initial state concurrently; any failure aborts the connection.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // Each connection gets its own database mutex, so its handlers take turns with
            // the database via `Session::db`.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection hold a permit at once. A permit is acquired
            // before the next message is read, so reading pauses while all are in use.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, I/O,
                // in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // `teardown()` already reset the pool; exit without signing out.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run on a detached task, outside
                                    // this loop's foreground handler set.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop (cancel) any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 683
 684    pub async fn invite_code_redeemed(
 685        self: &Arc<Self>,
 686        inviter_id: UserId,
 687        invitee_id: UserId,
 688    ) -> Result<()> {
 689        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 690            if let Some(code) = &user.invite_code {
 691                let pool = self.connection_pool.lock();
 692                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 693                for connection_id in pool.user_connection_ids(inviter_id) {
 694                    self.peer.send(
 695                        connection_id,
 696                        proto::UpdateContacts {
 697                            contacts: vec![invitee_contact.clone()],
 698                            ..Default::default()
 699                        },
 700                    )?;
 701                    self.peer.send(
 702                        connection_id,
 703                        proto::UpdateInviteInfo {
 704                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 705                            count: user.invite_count as u32,
 706                        },
 707                    )?;
 708                }
 709            }
 710        }
 711        Ok(())
 712    }
 713
 714    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 715        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 716            if let Some(invite_code) = &user.invite_code {
 717                let pool = self.connection_pool.lock();
 718                for connection_id in pool.user_connection_ids(user_id) {
 719                    self.peer.send(
 720                        connection_id,
 721                        proto::UpdateInviteInfo {
 722                            url: format!(
 723                                "{}{}",
 724                                self.app_state.config.invite_link_prefix, invite_code
 725                            ),
 726                            count: user.invite_count as u32,
 727                        },
 728                    )?;
 729                }
 730            }
 731        }
 732        Ok(())
 733    }
 734
 735    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 736        ServerSnapshot {
 737            connection_pool: ConnectionPoolGuard {
 738                guard: self.connection_pool.lock(),
 739                _not_send: PhantomData,
 740            },
 741            peer: &self.peer,
 742        }
 743    }
 744}
 745
 746impl<'a> Deref for ConnectionPoolGuard<'a> {
 747    type Target = ConnectionPool;
 748
 749    fn deref(&self) -> &Self::Target {
 750        &*self.guard
 751    }
 752}
 753
 754impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 755    fn deref_mut(&mut self) -> &mut Self::Target {
 756        &mut *self.guard
 757    }
 758}
 759
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants()` every time a guard is
    /// dropped, so each locked section is validated as it ends. In non-test
    /// builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 766
 767fn broadcast<F>(
 768    sender_id: Option<ConnectionId>,
 769    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 770    mut f: F,
 771) where
 772    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 773{
 774    for receiver_id in receiver_ids {
 775        if Some(receiver_id) != sender_id {
 776            if let Err(error) = f(receiver_id) {
 777                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 778            }
 779        }
 780    }
 781}
 782
// Name of the request header that carries the client's RPC protocol version.
// `ProtocolVersion` decodes it, and `handle_websocket_request` checks it
// against `rpc::PROTOCOL_VERSION`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 786
 787pub struct ProtocolVersion(u32);
 788
 789impl Header for ProtocolVersion {
 790    fn name() -> &'static HeaderName {
 791        &ZED_PROTOCOL_VERSION
 792    }
 793
 794    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 795    where
 796        Self: Sized,
 797        I: Iterator<Item = &'i axum::http::HeaderValue>,
 798    {
 799        let version = values
 800            .next()
 801            .ok_or_else(axum::headers::Error::invalid)?
 802            .to_str()
 803            .map_err(|_| axum::headers::Error::invalid())?
 804            .parse()
 805            .map_err(|_| axum::headers::Error::invalid())?;
 806        Ok(Self(version))
 807    }
 808
 809    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 810        values.extend([self.0.to_string().parse().unwrap()]);
 811    }
 812}
 813
/// Builds the router for the collaboration server's RPC endpoints.
///
/// `/rpc` is the websocket endpoint. It is wrapped in a layer stack that
/// supplies the server's `app_state` as an extension and runs the
/// `auth::validate_header` middleware. `Router::layer` only applies to routes
/// registered before it is called, so the `/metrics` route added afterwards is
/// NOT behind that middleware. The final layer provides the `Server` itself as
/// an extension to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer above, so it is not authenticated.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 825
/// Upgrades an `/rpc` request to a websocket and hands the connection to
/// `Server::handle_connection`.
///
/// Requests whose `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION` are rejected with `426 Upgrade Required` before any
/// upgrade happens. The `User` extension is presumably inserted by
/// `auth::validate_header` (see `routes`); confirm there.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt the axum websocket to tungstenite messages, which `Connection`
        // presumably expects. Incoming messages are mapped with
        // `to_tungstenite_message`; outgoing ones are mapped back with
        // `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Errors ending the connection are logged, not returned.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 856
 857pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 858    let connections = server
 859        .connection_pool
 860        .lock()
 861        .connections()
 862        .filter(|connection| !connection.admin)
 863        .count();
 864
 865    METRIC_CONNECTIONS.set(connections as _);
 866
 867    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 868    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 869
 870    let encoder = prometheus::TextEncoder::new();
 871    let metric_families = prometheus::gather();
 872    let encoded_metrics = encoder
 873        .encode_to_string(&metric_families)
 874        .map_err(|err| anyhow!("{}", err))?;
 875    Ok(encoded_metrics)
 876}
 877
/// Cleans up after a connection whose message loop has ended.
///
/// Immediately:
/// - the peer is disconnected;
/// - the connection is removed from the pool;
/// - the database is told the connection was lost.
///
/// The heavier cleanup runs only if `RECONNECT_TIMEOUT` elapses before
/// `teardown` changes. That cleanup leaves the room and channel buffers,
/// declines any pending call when the user has no other online connection, and
/// refreshes the user's contacts. Presumably the delay gives the client a
/// window to reconnect and rejoin (see `rejoin_room`); confirm.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database error here is traced, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written. If `teardown`
    // changes first, the cleanup in the timeout branch is skipped entirely.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when none of the user's other
            // connections are still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 923
 924async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 925    response.send(proto::Ack {})?;
 926    Ok(())
 927}
 928
 929async fn create_room(
 930    _request: proto::CreateRoom,
 931    response: Response<proto::CreateRoom>,
 932    session: Session,
 933) -> Result<()> {
 934    let live_kit_room = nanoid::nanoid!(30);
 935
 936    let live_kit_connection_info = {
 937        let live_kit_room = live_kit_room.clone();
 938        let live_kit = session.live_kit_client.as_ref();
 939
 940        util::async_iife!({
 941            let live_kit = live_kit?;
 942
 943            let token = live_kit
 944                .room_token(&live_kit_room, &session.user_id.to_string())
 945                .trace_err()?;
 946
 947            Some(proto::LiveKitConnectionInfo {
 948                server_url: live_kit.url().into(),
 949                token,
 950            })
 951        })
 952    }
 953    .await;
 954
 955    let room = session
 956        .db()
 957        .await
 958        .create_room(
 959            session.user_id,
 960            session.connection_id,
 961            &live_kit_room,
 962            RELEASE_CHANNEL_NAME.as_str(),
 963        )
 964        .await?;
 965
 966    response.send(proto::CreateRoomResponse {
 967        room: Some(room.clone()),
 968        live_kit_connection_info,
 969    })?;
 970
 971    update_user_contacts(session.user_id, &session).await?;
 972    Ok(())
 973}
 974
 975async fn join_room(
 976    request: proto::JoinRoom,
 977    response: Response<proto::JoinRoom>,
 978    session: Session,
 979) -> Result<()> {
 980    let room_id = RoomId::from_proto(request.id);
 981
 982    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 983
 984    if let Some(channel_id) = channel_id {
 985        return join_channel_internal(channel_id, Box::new(response), session).await;
 986    }
 987
 988    let joined_room = {
 989        let room = session
 990            .db()
 991            .await
 992            .join_room(
 993                room_id,
 994                session.user_id,
 995                session.connection_id,
 996                RELEASE_CHANNEL_NAME.as_str(),
 997            )
 998            .await?;
 999        room_updated(&room.room, &session.peer);
1000        room.into_inner()
1001    };
1002
1003    for connection_id in session
1004        .connection_pool()
1005        .await
1006        .user_connection_ids(session.user_id)
1007    {
1008        session
1009            .peer
1010            .send(
1011                connection_id,
1012                proto::CallCanceled {
1013                    room_id: room_id.to_proto(),
1014                },
1015            )
1016            .trace_err();
1017    }
1018
1019    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1020        if let Some(token) = live_kit
1021            .room_token(
1022                &joined_room.room.live_kit_room,
1023                &session.user_id.to_string(),
1024            )
1025            .trace_err()
1026        {
1027            Some(proto::LiveKitConnectionInfo {
1028                server_url: live_kit.url().into(),
1029                token,
1030            })
1031        } else {
1032            None
1033        }
1034    } else {
1035        None
1036    };
1037
1038    response.send(proto::JoinRoomResponse {
1039        room: Some(joined_room.room),
1040        channel_id: None,
1041        live_kit_connection_info,
1042    })?;
1043
1044    update_user_contacts(session.user_id, &session).await?;
1045    Ok(())
1046}
1047
1048async fn rejoin_room(
1049    request: proto::RejoinRoom,
1050    response: Response<proto::RejoinRoom>,
1051    session: Session,
1052) -> Result<()> {
1053    let room;
1054    let channel_id;
1055    let channel_members;
1056    {
1057        let mut rejoined_room = session
1058            .db()
1059            .await
1060            .rejoin_room(request, session.user_id, session.connection_id)
1061            .await?;
1062
1063        response.send(proto::RejoinRoomResponse {
1064            room: Some(rejoined_room.room.clone()),
1065            reshared_projects: rejoined_room
1066                .reshared_projects
1067                .iter()
1068                .map(|project| proto::ResharedProject {
1069                    id: project.id.to_proto(),
1070                    collaborators: project
1071                        .collaborators
1072                        .iter()
1073                        .map(|collaborator| collaborator.to_proto())
1074                        .collect(),
1075                })
1076                .collect(),
1077            rejoined_projects: rejoined_room
1078                .rejoined_projects
1079                .iter()
1080                .map(|rejoined_project| proto::RejoinedProject {
1081                    id: rejoined_project.id.to_proto(),
1082                    worktrees: rejoined_project
1083                        .worktrees
1084                        .iter()
1085                        .map(|worktree| proto::WorktreeMetadata {
1086                            id: worktree.id,
1087                            root_name: worktree.root_name.clone(),
1088                            visible: worktree.visible,
1089                            abs_path: worktree.abs_path.clone(),
1090                        })
1091                        .collect(),
1092                    collaborators: rejoined_project
1093                        .collaborators
1094                        .iter()
1095                        .map(|collaborator| collaborator.to_proto())
1096                        .collect(),
1097                    language_servers: rejoined_project.language_servers.clone(),
1098                })
1099                .collect(),
1100        })?;
1101        room_updated(&rejoined_room.room, &session.peer);
1102
1103        for project in &rejoined_room.reshared_projects {
1104            for collaborator in &project.collaborators {
1105                session
1106                    .peer
1107                    .send(
1108                        collaborator.connection_id,
1109                        proto::UpdateProjectCollaborator {
1110                            project_id: project.id.to_proto(),
1111                            old_peer_id: Some(project.old_connection_id.into()),
1112                            new_peer_id: Some(session.connection_id.into()),
1113                        },
1114                    )
1115                    .trace_err();
1116            }
1117
1118            broadcast(
1119                Some(session.connection_id),
1120                project
1121                    .collaborators
1122                    .iter()
1123                    .map(|collaborator| collaborator.connection_id),
1124                |connection_id| {
1125                    session.peer.forward_send(
1126                        session.connection_id,
1127                        connection_id,
1128                        proto::UpdateProject {
1129                            project_id: project.id.to_proto(),
1130                            worktrees: project.worktrees.clone(),
1131                        },
1132                    )
1133                },
1134            );
1135        }
1136
1137        for project in &rejoined_room.rejoined_projects {
1138            for collaborator in &project.collaborators {
1139                session
1140                    .peer
1141                    .send(
1142                        collaborator.connection_id,
1143                        proto::UpdateProjectCollaborator {
1144                            project_id: project.id.to_proto(),
1145                            old_peer_id: Some(project.old_connection_id.into()),
1146                            new_peer_id: Some(session.connection_id.into()),
1147                        },
1148                    )
1149                    .trace_err();
1150            }
1151        }
1152
1153        for project in &mut rejoined_room.rejoined_projects {
1154            for worktree in mem::take(&mut project.worktrees) {
1155                #[cfg(any(test, feature = "test-support"))]
1156                const MAX_CHUNK_SIZE: usize = 2;
1157                #[cfg(not(any(test, feature = "test-support")))]
1158                const MAX_CHUNK_SIZE: usize = 256;
1159
1160                // Stream this worktree's entries.
1161                let message = proto::UpdateWorktree {
1162                    project_id: project.id.to_proto(),
1163                    worktree_id: worktree.id,
1164                    abs_path: worktree.abs_path.clone(),
1165                    root_name: worktree.root_name,
1166                    updated_entries: worktree.updated_entries,
1167                    removed_entries: worktree.removed_entries,
1168                    scan_id: worktree.scan_id,
1169                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1170                    updated_repositories: worktree.updated_repositories,
1171                    removed_repositories: worktree.removed_repositories,
1172                };
1173                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1174                    session.peer.send(session.connection_id, update.clone())?;
1175                }
1176
1177                // Stream this worktree's diagnostics.
1178                for summary in worktree.diagnostic_summaries {
1179                    session.peer.send(
1180                        session.connection_id,
1181                        proto::UpdateDiagnosticSummary {
1182                            project_id: project.id.to_proto(),
1183                            worktree_id: worktree.id,
1184                            summary: Some(summary),
1185                        },
1186                    )?;
1187                }
1188
1189                for settings_file in worktree.settings_files {
1190                    session.peer.send(
1191                        session.connection_id,
1192                        proto::UpdateWorktreeSettings {
1193                            project_id: project.id.to_proto(),
1194                            worktree_id: worktree.id,
1195                            path: settings_file.path,
1196                            content: Some(settings_file.content),
1197                        },
1198                    )?;
1199                }
1200            }
1201
1202            for language_server in &project.language_servers {
1203                session.peer.send(
1204                    session.connection_id,
1205                    proto::UpdateLanguageServer {
1206                        project_id: project.id.to_proto(),
1207                        language_server_id: language_server.id,
1208                        variant: Some(
1209                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1210                                proto::LspDiskBasedDiagnosticsUpdated {},
1211                            ),
1212                        ),
1213                    },
1214                )?;
1215            }
1216        }
1217
1218        let rejoined_room = rejoined_room.into_inner();
1219
1220        room = rejoined_room.room;
1221        channel_id = rejoined_room.channel_id;
1222        channel_members = rejoined_room.channel_members;
1223    }
1224
1225    if let Some(channel_id) = channel_id {
1226        channel_updated(
1227            channel_id,
1228            &room,
1229            &channel_members,
1230            &session.peer,
1231            &*session.connection_pool().await,
1232        );
1233    }
1234
1235    update_user_contacts(session.user_id, &session).await?;
1236    Ok(())
1237}
1238
1239async fn leave_room(
1240    _: proto::LeaveRoom,
1241    response: Response<proto::LeaveRoom>,
1242    session: Session,
1243) -> Result<()> {
1244    leave_room_for_session(&session).await?;
1245    response.send(proto::Ack {})?;
1246    Ok(())
1247}
1248
1249async fn call(
1250    request: proto::Call,
1251    response: Response<proto::Call>,
1252    session: Session,
1253) -> Result<()> {
1254    let room_id = RoomId::from_proto(request.room_id);
1255    let calling_user_id = session.user_id;
1256    let calling_connection_id = session.connection_id;
1257    let called_user_id = UserId::from_proto(request.called_user_id);
1258    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1259    if !session
1260        .db()
1261        .await
1262        .has_contact(calling_user_id, called_user_id)
1263        .await?
1264    {
1265        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1266    }
1267
1268    let incoming_call = {
1269        let (room, incoming_call) = &mut *session
1270            .db()
1271            .await
1272            .call(
1273                room_id,
1274                calling_user_id,
1275                calling_connection_id,
1276                called_user_id,
1277                initial_project_id,
1278            )
1279            .await?;
1280        room_updated(&room, &session.peer);
1281        mem::take(incoming_call)
1282    };
1283    update_user_contacts(called_user_id, &session).await?;
1284
1285    let mut calls = session
1286        .connection_pool()
1287        .await
1288        .user_connection_ids(called_user_id)
1289        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1290        .collect::<FuturesUnordered<_>>();
1291
1292    while let Some(call_response) = calls.next().await {
1293        match call_response.as_ref() {
1294            Ok(_) => {
1295                response.send(proto::Ack {})?;
1296                return Ok(());
1297            }
1298            Err(_) => {
1299                call_response.trace_err();
1300            }
1301        }
1302    }
1303
1304    {
1305        let room = session
1306            .db()
1307            .await
1308            .call_failed(room_id, called_user_id)
1309            .await?;
1310        room_updated(&room, &session.peer);
1311    }
1312    update_user_contacts(called_user_id, &session).await?;
1313
1314    Err(anyhow!("failed to ring user"))?
1315}
1316
1317async fn cancel_call(
1318    request: proto::CancelCall,
1319    response: Response<proto::CancelCall>,
1320    session: Session,
1321) -> Result<()> {
1322    let called_user_id = UserId::from_proto(request.called_user_id);
1323    let room_id = RoomId::from_proto(request.room_id);
1324    {
1325        let room = session
1326            .db()
1327            .await
1328            .cancel_call(room_id, session.connection_id, called_user_id)
1329            .await?;
1330        room_updated(&room, &session.peer);
1331    }
1332
1333    for connection_id in session
1334        .connection_pool()
1335        .await
1336        .user_connection_ids(called_user_id)
1337    {
1338        session
1339            .peer
1340            .send(
1341                connection_id,
1342                proto::CallCanceled {
1343                    room_id: room_id.to_proto(),
1344                },
1345            )
1346            .trace_err();
1347    }
1348    response.send(proto::Ack {})?;
1349
1350    update_user_contacts(called_user_id, &session).await?;
1351    Ok(())
1352}
1353
1354async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1355    let room_id = RoomId::from_proto(message.room_id);
1356    {
1357        let room = session
1358            .db()
1359            .await
1360            .decline_call(Some(room_id), session.user_id)
1361            .await?
1362            .ok_or_else(|| anyhow!("failed to decline call"))?;
1363        room_updated(&room, &session.peer);
1364    }
1365
1366    for connection_id in session
1367        .connection_pool()
1368        .await
1369        .user_connection_ids(session.user_id)
1370    {
1371        session
1372            .peer
1373            .send(
1374                connection_id,
1375                proto::CallCanceled {
1376                    room_id: room_id.to_proto(),
1377                },
1378            )
1379            .trace_err();
1380    }
1381    update_user_contacts(session.user_id, &session).await?;
1382    Ok(())
1383}
1384
1385async fn update_participant_location(
1386    request: proto::UpdateParticipantLocation,
1387    response: Response<proto::UpdateParticipantLocation>,
1388    session: Session,
1389) -> Result<()> {
1390    let room_id = RoomId::from_proto(request.room_id);
1391    let location = request
1392        .location
1393        .ok_or_else(|| anyhow!("invalid location"))?;
1394
1395    let db = session.db().await;
1396    let room = db
1397        .update_room_participant_location(room_id, session.connection_id, location)
1398        .await?;
1399
1400    room_updated(&room, &session.peer);
1401    response.send(proto::Ack {})?;
1402    Ok(())
1403}
1404
1405async fn share_project(
1406    request: proto::ShareProject,
1407    response: Response<proto::ShareProject>,
1408    session: Session,
1409) -> Result<()> {
1410    let (project_id, room) = &*session
1411        .db()
1412        .await
1413        .share_project(
1414            RoomId::from_proto(request.room_id),
1415            session.connection_id,
1416            &request.worktrees,
1417        )
1418        .await?;
1419    response.send(proto::ShareProjectResponse {
1420        project_id: project_id.to_proto(),
1421    })?;
1422    room_updated(&room, &session.peer);
1423
1424    Ok(())
1425}
1426
1427async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1428    let project_id = ProjectId::from_proto(message.project_id);
1429
1430    let (room, guest_connection_ids) = &*session
1431        .db()
1432        .await
1433        .unshare_project(project_id, session.connection_id)
1434        .await?;
1435
1436    broadcast(
1437        Some(session.connection_id),
1438        guest_connection_ids.iter().copied(),
1439        |conn_id| session.peer.send(conn_id, message.clone()),
1440    );
1441    room_updated(&room, &session.peer);
1442
1443    Ok(())
1444}
1445
1446async fn join_project(
1447    request: proto::JoinProject,
1448    response: Response<proto::JoinProject>,
1449    session: Session,
1450) -> Result<()> {
1451    let project_id = ProjectId::from_proto(request.project_id);
1452    let guest_user_id = session.user_id;
1453
1454    tracing::info!(%project_id, "join project");
1455
1456    let (project, replica_id) = &mut *session
1457        .db()
1458        .await
1459        .join_project(project_id, session.connection_id)
1460        .await?;
1461
1462    let collaborators = project
1463        .collaborators
1464        .iter()
1465        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1466        .map(|collaborator| collaborator.to_proto())
1467        .collect::<Vec<_>>();
1468
1469    let worktrees = project
1470        .worktrees
1471        .iter()
1472        .map(|(id, worktree)| proto::WorktreeMetadata {
1473            id: *id,
1474            root_name: worktree.root_name.clone(),
1475            visible: worktree.visible,
1476            abs_path: worktree.abs_path.clone(),
1477        })
1478        .collect::<Vec<_>>();
1479
1480    for collaborator in &collaborators {
1481        session
1482            .peer
1483            .send(
1484                collaborator.peer_id.unwrap().into(),
1485                proto::AddProjectCollaborator {
1486                    project_id: project_id.to_proto(),
1487                    collaborator: Some(proto::Collaborator {
1488                        peer_id: Some(session.connection_id.into()),
1489                        replica_id: replica_id.0 as u32,
1490                        user_id: guest_user_id.to_proto(),
1491                    }),
1492                },
1493            )
1494            .trace_err();
1495    }
1496
1497    // First, we send the metadata associated with each worktree.
1498    response.send(proto::JoinProjectResponse {
1499        worktrees: worktrees.clone(),
1500        replica_id: replica_id.0 as u32,
1501        collaborators: collaborators.clone(),
1502        language_servers: project.language_servers.clone(),
1503    })?;
1504
1505    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1506        #[cfg(any(test, feature = "test-support"))]
1507        const MAX_CHUNK_SIZE: usize = 2;
1508        #[cfg(not(any(test, feature = "test-support")))]
1509        const MAX_CHUNK_SIZE: usize = 256;
1510
1511        // Stream this worktree's entries.
1512        let message = proto::UpdateWorktree {
1513            project_id: project_id.to_proto(),
1514            worktree_id,
1515            abs_path: worktree.abs_path.clone(),
1516            root_name: worktree.root_name,
1517            updated_entries: worktree.entries,
1518            removed_entries: Default::default(),
1519            scan_id: worktree.scan_id,
1520            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1521            updated_repositories: worktree.repository_entries.into_values().collect(),
1522            removed_repositories: Default::default(),
1523        };
1524        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1525            session.peer.send(session.connection_id, update.clone())?;
1526        }
1527
1528        // Stream this worktree's diagnostics.
1529        for summary in worktree.diagnostic_summaries {
1530            session.peer.send(
1531                session.connection_id,
1532                proto::UpdateDiagnosticSummary {
1533                    project_id: project_id.to_proto(),
1534                    worktree_id: worktree.id,
1535                    summary: Some(summary),
1536                },
1537            )?;
1538        }
1539
1540        for settings_file in worktree.settings_files {
1541            session.peer.send(
1542                session.connection_id,
1543                proto::UpdateWorktreeSettings {
1544                    project_id: project_id.to_proto(),
1545                    worktree_id: worktree.id,
1546                    path: settings_file.path,
1547                    content: Some(settings_file.content),
1548                },
1549            )?;
1550        }
1551    }
1552
1553    for language_server in &project.language_servers {
1554        session.peer.send(
1555            session.connection_id,
1556            proto::UpdateLanguageServer {
1557                project_id: project_id.to_proto(),
1558                language_server_id: language_server.id,
1559                variant: Some(
1560                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1561                        proto::LspDiskBasedDiagnosticsUpdated {},
1562                    ),
1563                ),
1564            },
1565        )?;
1566    }
1567
1568    Ok(())
1569}
1570
1571async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1572    let sender_id = session.connection_id;
1573    let project_id = ProjectId::from_proto(request.project_id);
1574
1575    let (room, project) = &*session
1576        .db()
1577        .await
1578        .leave_project(project_id, sender_id)
1579        .await?;
1580    tracing::info!(
1581        %project_id,
1582        host_user_id = %project.host_user_id,
1583        host_connection_id = %project.host_connection_id,
1584        "leave project"
1585    );
1586
1587    project_left(&project, &session);
1588    room_updated(&room, &session.peer);
1589
1590    Ok(())
1591}
1592
1593async fn update_project(
1594    request: proto::UpdateProject,
1595    response: Response<proto::UpdateProject>,
1596    session: Session,
1597) -> Result<()> {
1598    let project_id = ProjectId::from_proto(request.project_id);
1599    let (room, guest_connection_ids) = &*session
1600        .db()
1601        .await
1602        .update_project(project_id, session.connection_id, &request.worktrees)
1603        .await?;
1604    broadcast(
1605        Some(session.connection_id),
1606        guest_connection_ids.iter().copied(),
1607        |connection_id| {
1608            session
1609                .peer
1610                .forward_send(session.connection_id, connection_id, request.clone())
1611        },
1612    );
1613    room_updated(&room, &session.peer);
1614    response.send(proto::Ack {})?;
1615
1616    Ok(())
1617}
1618
1619async fn update_worktree(
1620    request: proto::UpdateWorktree,
1621    response: Response<proto::UpdateWorktree>,
1622    session: Session,
1623) -> Result<()> {
1624    let guest_connection_ids = session
1625        .db()
1626        .await
1627        .update_worktree(&request, session.connection_id)
1628        .await?;
1629
1630    broadcast(
1631        Some(session.connection_id),
1632        guest_connection_ids.iter().copied(),
1633        |connection_id| {
1634            session
1635                .peer
1636                .forward_send(session.connection_id, connection_id, request.clone())
1637        },
1638    );
1639    response.send(proto::Ack {})?;
1640    Ok(())
1641}
1642
1643async fn update_diagnostic_summary(
1644    message: proto::UpdateDiagnosticSummary,
1645    session: Session,
1646) -> Result<()> {
1647    let guest_connection_ids = session
1648        .db()
1649        .await
1650        .update_diagnostic_summary(&message, session.connection_id)
1651        .await?;
1652
1653    broadcast(
1654        Some(session.connection_id),
1655        guest_connection_ids.iter().copied(),
1656        |connection_id| {
1657            session
1658                .peer
1659                .forward_send(session.connection_id, connection_id, message.clone())
1660        },
1661    );
1662
1663    Ok(())
1664}
1665
1666async fn update_worktree_settings(
1667    message: proto::UpdateWorktreeSettings,
1668    session: Session,
1669) -> Result<()> {
1670    let guest_connection_ids = session
1671        .db()
1672        .await
1673        .update_worktree_settings(&message, session.connection_id)
1674        .await?;
1675
1676    broadcast(
1677        Some(session.connection_id),
1678        guest_connection_ids.iter().copied(),
1679        |connection_id| {
1680            session
1681                .peer
1682                .forward_send(session.connection_id, connection_id, message.clone())
1683        },
1684    );
1685
1686    Ok(())
1687}
1688
1689async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1690    broadcast_project_message(request.project_id, request, session).await
1691}
1692
1693async fn start_language_server(
1694    request: proto::StartLanguageServer,
1695    session: Session,
1696) -> Result<()> {
1697    let guest_connection_ids = session
1698        .db()
1699        .await
1700        .start_language_server(&request, session.connection_id)
1701        .await?;
1702
1703    broadcast(
1704        Some(session.connection_id),
1705        guest_connection_ids.iter().copied(),
1706        |connection_id| {
1707            session
1708                .peer
1709                .forward_send(session.connection_id, connection_id, request.clone())
1710        },
1711    );
1712    Ok(())
1713}
1714
1715async fn update_language_server(
1716    request: proto::UpdateLanguageServer,
1717    session: Session,
1718) -> Result<()> {
1719    session.executor.record_backtrace();
1720    let project_id = ProjectId::from_proto(request.project_id);
1721    let project_connection_ids = session
1722        .db()
1723        .await
1724        .project_connection_ids(project_id, session.connection_id)
1725        .await?;
1726    broadcast(
1727        Some(session.connection_id),
1728        project_connection_ids.iter().copied(),
1729        |connection_id| {
1730            session
1731                .peer
1732                .forward_send(session.connection_id, connection_id, request.clone())
1733        },
1734    );
1735    Ok(())
1736}
1737
1738async fn forward_project_request<T>(
1739    request: T,
1740    response: Response<T>,
1741    session: Session,
1742) -> Result<()>
1743where
1744    T: EntityMessage + RequestMessage,
1745{
1746    session.executor.record_backtrace();
1747    let project_id = ProjectId::from_proto(request.remote_entity_id());
1748    let host_connection_id = {
1749        let collaborators = session
1750            .db()
1751            .await
1752            .project_collaborators(project_id, session.connection_id)
1753            .await?;
1754        collaborators
1755            .iter()
1756            .find(|collaborator| collaborator.is_host)
1757            .ok_or_else(|| anyhow!("host not found"))?
1758            .connection_id
1759    };
1760
1761    let payload = session
1762        .peer
1763        .forward_request(session.connection_id, host_connection_id, request)
1764        .await?;
1765
1766    response.send(payload)?;
1767    Ok(())
1768}
1769
1770async fn create_buffer_for_peer(
1771    request: proto::CreateBufferForPeer,
1772    session: Session,
1773) -> Result<()> {
1774    session.executor.record_backtrace();
1775    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1776    session
1777        .peer
1778        .forward_send(session.connection_id, peer_id.into(), request)?;
1779    Ok(())
1780}
1781
1782async fn update_buffer(
1783    request: proto::UpdateBuffer,
1784    response: Response<proto::UpdateBuffer>,
1785    session: Session,
1786) -> Result<()> {
1787    session.executor.record_backtrace();
1788    let project_id = ProjectId::from_proto(request.project_id);
1789    let mut guest_connection_ids;
1790    let mut host_connection_id = None;
1791    {
1792        let collaborators = session
1793            .db()
1794            .await
1795            .project_collaborators(project_id, session.connection_id)
1796            .await?;
1797        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1798        for collaborator in collaborators.iter() {
1799            if collaborator.is_host {
1800                host_connection_id = Some(collaborator.connection_id);
1801            } else {
1802                guest_connection_ids.push(collaborator.connection_id);
1803            }
1804        }
1805    }
1806    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1807
1808    session.executor.record_backtrace();
1809    broadcast(
1810        Some(session.connection_id),
1811        guest_connection_ids,
1812        |connection_id| {
1813            session
1814                .peer
1815                .forward_send(session.connection_id, connection_id, request.clone())
1816        },
1817    );
1818    if host_connection_id != session.connection_id {
1819        session
1820            .peer
1821            .forward_request(session.connection_id, host_connection_id, request.clone())
1822            .await?;
1823    }
1824
1825    response.send(proto::Ack {})?;
1826    Ok(())
1827}
1828
1829async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1830    let project_id = ProjectId::from_proto(request.project_id);
1831    let project_connection_ids = session
1832        .db()
1833        .await
1834        .project_connection_ids(project_id, session.connection_id)
1835        .await?;
1836
1837    broadcast(
1838        Some(session.connection_id),
1839        project_connection_ids.iter().copied(),
1840        |connection_id| {
1841            session
1842                .peer
1843                .forward_send(session.connection_id, connection_id, request.clone())
1844        },
1845    );
1846    Ok(())
1847}
1848
1849async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1850    let project_id = ProjectId::from_proto(request.project_id);
1851    let project_connection_ids = session
1852        .db()
1853        .await
1854        .project_connection_ids(project_id, session.connection_id)
1855        .await?;
1856    broadcast(
1857        Some(session.connection_id),
1858        project_connection_ids.iter().copied(),
1859        |connection_id| {
1860            session
1861                .peer
1862                .forward_send(session.connection_id, connection_id, request.clone())
1863        },
1864    );
1865    Ok(())
1866}
1867
1868async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1869    broadcast_project_message(request.project_id, request, session).await
1870}
1871
1872async fn broadcast_project_message<T: EnvelopedMessage>(
1873    project_id: u64,
1874    request: T,
1875    session: Session,
1876) -> Result<()> {
1877    let project_id = ProjectId::from_proto(project_id);
1878    let project_connection_ids = session
1879        .db()
1880        .await
1881        .project_connection_ids(project_id, session.connection_id)
1882        .await?;
1883    broadcast(
1884        Some(session.connection_id),
1885        project_connection_ids.iter().copied(),
1886        |connection_id| {
1887            session
1888                .peer
1889                .forward_send(session.connection_id, connection_id, request.clone())
1890        },
1891    );
1892    Ok(())
1893}
1894
1895async fn follow(
1896    request: proto::Follow,
1897    response: Response<proto::Follow>,
1898    session: Session,
1899) -> Result<()> {
1900    let room_id = RoomId::from_proto(request.room_id);
1901    let project_id = request.project_id.map(ProjectId::from_proto);
1902    let leader_id = request
1903        .leader_id
1904        .ok_or_else(|| anyhow!("invalid leader id"))?
1905        .into();
1906    let follower_id = session.connection_id;
1907
1908    session
1909        .db()
1910        .await
1911        .check_room_participants(room_id, leader_id, session.connection_id)
1912        .await?;
1913
1914    let response_payload = session
1915        .peer
1916        .forward_request(session.connection_id, leader_id, request)
1917        .await?;
1918    response.send(response_payload)?;
1919
1920    if let Some(project_id) = project_id {
1921        let room = session
1922            .db()
1923            .await
1924            .follow(room_id, project_id, leader_id, follower_id)
1925            .await?;
1926        room_updated(&room, &session.peer);
1927    }
1928
1929    Ok(())
1930}
1931
1932async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1933    let room_id = RoomId::from_proto(request.room_id);
1934    let project_id = request.project_id.map(ProjectId::from_proto);
1935    let leader_id = request
1936        .leader_id
1937        .ok_or_else(|| anyhow!("invalid leader id"))?
1938        .into();
1939    let follower_id = session.connection_id;
1940
1941    session
1942        .db()
1943        .await
1944        .check_room_participants(room_id, leader_id, session.connection_id)
1945        .await?;
1946
1947    session
1948        .peer
1949        .forward_send(session.connection_id, leader_id, request)?;
1950
1951    if let Some(project_id) = project_id {
1952        let room = session
1953            .db()
1954            .await
1955            .unfollow(room_id, project_id, leader_id, follower_id)
1956            .await?;
1957        room_updated(&room, &session.peer);
1958    }
1959
1960    Ok(())
1961}
1962
1963async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1964    let room_id = RoomId::from_proto(request.room_id);
1965    let database = session.db.lock().await;
1966
1967    let connection_ids = if let Some(project_id) = request.project_id {
1968        let project_id = ProjectId::from_proto(project_id);
1969        database
1970            .project_connection_ids(project_id, session.connection_id)
1971            .await?
1972    } else {
1973        database
1974            .room_connection_ids(room_id, session.connection_id)
1975            .await?
1976    };
1977
1978    // For now, don't send view update messages back to that view's current leader.
1979    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1980        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1981        _ => None,
1982    });
1983
1984    for follower_peer_id in request.follower_ids.iter().copied() {
1985        let follower_connection_id = follower_peer_id.into();
1986        if Some(follower_peer_id) != connection_id_to_omit
1987            && connection_ids.contains(&follower_connection_id)
1988        {
1989            session.peer.forward_send(
1990                session.connection_id,
1991                follower_connection_id,
1992                request.clone(),
1993            )?;
1994        }
1995    }
1996    Ok(())
1997}
1998
1999async fn get_users(
2000    request: proto::GetUsers,
2001    response: Response<proto::GetUsers>,
2002    session: Session,
2003) -> Result<()> {
2004    let user_ids = request
2005        .user_ids
2006        .into_iter()
2007        .map(UserId::from_proto)
2008        .collect();
2009    let users = session
2010        .db()
2011        .await
2012        .get_users_by_ids(user_ids)
2013        .await?
2014        .into_iter()
2015        .map(|user| proto::User {
2016            id: user.id.to_proto(),
2017            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2018            github_login: user.github_login,
2019        })
2020        .collect();
2021    response.send(proto::UsersResponse { users })?;
2022    Ok(())
2023}
2024
2025async fn fuzzy_search_users(
2026    request: proto::FuzzySearchUsers,
2027    response: Response<proto::FuzzySearchUsers>,
2028    session: Session,
2029) -> Result<()> {
2030    let query = request.query;
2031    let users = match query.len() {
2032        0 => vec![],
2033        1 | 2 => session
2034            .db()
2035            .await
2036            .get_user_by_github_login(&query)
2037            .await?
2038            .into_iter()
2039            .collect(),
2040        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2041    };
2042    let users = users
2043        .into_iter()
2044        .filter(|user| user.id != session.user_id)
2045        .map(|user| proto::User {
2046            id: user.id.to_proto(),
2047            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2048            github_login: user.github_login,
2049        })
2050        .collect();
2051    response.send(proto::UsersResponse { users })?;
2052    Ok(())
2053}
2054
2055async fn request_contact(
2056    request: proto::RequestContact,
2057    response: Response<proto::RequestContact>,
2058    session: Session,
2059) -> Result<()> {
2060    let requester_id = session.user_id;
2061    let responder_id = UserId::from_proto(request.responder_id);
2062    if requester_id == responder_id {
2063        return Err(anyhow!("cannot add yourself as a contact"))?;
2064    }
2065
2066    session
2067        .db()
2068        .await
2069        .send_contact_request(requester_id, responder_id)
2070        .await?;
2071
2072    // Update outgoing contact requests of requester
2073    let mut update = proto::UpdateContacts::default();
2074    update.outgoing_requests.push(responder_id.to_proto());
2075    for connection_id in session
2076        .connection_pool()
2077        .await
2078        .user_connection_ids(requester_id)
2079    {
2080        session.peer.send(connection_id, update.clone())?;
2081    }
2082
2083    // Update incoming contact requests of responder
2084    let mut update = proto::UpdateContacts::default();
2085    update
2086        .incoming_requests
2087        .push(proto::IncomingContactRequest {
2088            requester_id: requester_id.to_proto(),
2089            should_notify: true,
2090        });
2091    for connection_id in session
2092        .connection_pool()
2093        .await
2094        .user_connection_ids(responder_id)
2095    {
2096        session.peer.send(connection_id, update.clone())?;
2097    }
2098
2099    response.send(proto::Ack {})?;
2100    Ok(())
2101}
2102
2103async fn respond_to_contact_request(
2104    request: proto::RespondToContactRequest,
2105    response: Response<proto::RespondToContactRequest>,
2106    session: Session,
2107) -> Result<()> {
2108    let responder_id = session.user_id;
2109    let requester_id = UserId::from_proto(request.requester_id);
2110    let db = session.db().await;
2111    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2112        db.dismiss_contact_notification(responder_id, requester_id)
2113            .await?;
2114    } else {
2115        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2116
2117        db.respond_to_contact_request(responder_id, requester_id, accept)
2118            .await?;
2119        let requester_busy = db.is_user_busy(requester_id).await?;
2120        let responder_busy = db.is_user_busy(responder_id).await?;
2121
2122        let pool = session.connection_pool().await;
2123        // Update responder with new contact
2124        let mut update = proto::UpdateContacts::default();
2125        if accept {
2126            update
2127                .contacts
2128                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2129        }
2130        update
2131            .remove_incoming_requests
2132            .push(requester_id.to_proto());
2133        for connection_id in pool.user_connection_ids(responder_id) {
2134            session.peer.send(connection_id, update.clone())?;
2135        }
2136
2137        // Update requester with new contact
2138        let mut update = proto::UpdateContacts::default();
2139        if accept {
2140            update
2141                .contacts
2142                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2143        }
2144        update
2145            .remove_outgoing_requests
2146            .push(responder_id.to_proto());
2147        for connection_id in pool.user_connection_ids(requester_id) {
2148            session.peer.send(connection_id, update.clone())?;
2149        }
2150    }
2151
2152    response.send(proto::Ack {})?;
2153    Ok(())
2154}
2155
2156async fn remove_contact(
2157    request: proto::RemoveContact,
2158    response: Response<proto::RemoveContact>,
2159    session: Session,
2160) -> Result<()> {
2161    let requester_id = session.user_id;
2162    let responder_id = UserId::from_proto(request.user_id);
2163    let db = session.db().await;
2164    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2165
2166    let pool = session.connection_pool().await;
2167    // Update outgoing contact requests of requester
2168    let mut update = proto::UpdateContacts::default();
2169    if contact_accepted {
2170        update.remove_contacts.push(responder_id.to_proto());
2171    } else {
2172        update
2173            .remove_outgoing_requests
2174            .push(responder_id.to_proto());
2175    }
2176    for connection_id in pool.user_connection_ids(requester_id) {
2177        session.peer.send(connection_id, update.clone())?;
2178    }
2179
2180    // Update incoming contact requests of responder
2181    let mut update = proto::UpdateContacts::default();
2182    if contact_accepted {
2183        update.remove_contacts.push(requester_id.to_proto());
2184    } else {
2185        update
2186            .remove_incoming_requests
2187            .push(requester_id.to_proto());
2188    }
2189    for connection_id in pool.user_connection_ids(responder_id) {
2190        session.peer.send(connection_id, update.clone())?;
2191    }
2192
2193    response.send(proto::Ack {})?;
2194    Ok(())
2195}
2196
2197async fn create_channel(
2198    request: proto::CreateChannel,
2199    response: Response<proto::CreateChannel>,
2200    session: Session,
2201) -> Result<()> {
2202    let db = session.db().await;
2203
2204    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2205    let id = db
2206        .create_channel(&request.name, parent_id, session.user_id)
2207        .await?;
2208
2209    let channel = proto::Channel {
2210        id: id.to_proto(),
2211        name: request.name,
2212        visibility: proto::ChannelVisibility::Members as i32,
2213    };
2214
2215    response.send(proto::CreateChannelResponse {
2216        channel: Some(channel.clone()),
2217        parent_id: request.parent_id,
2218    })?;
2219
2220    let Some(parent_id) = parent_id else {
2221        return Ok(());
2222    };
2223
2224    let update = proto::UpdateChannels {
2225        channels: vec![channel],
2226        insert_edge: vec![ChannelEdge {
2227            parent_id: parent_id.to_proto(),
2228            channel_id: id.to_proto(),
2229        }],
2230        ..Default::default()
2231    };
2232
2233    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2234
2235    let connection_pool = session.connection_pool().await;
2236    for user_id in user_ids_to_notify {
2237        for connection_id in connection_pool.user_connection_ids(user_id) {
2238            if user_id == session.user_id {
2239                continue;
2240            }
2241            session.peer.send(connection_id, update.clone())?;
2242        }
2243    }
2244
2245    Ok(())
2246}
2247
2248async fn delete_channel(
2249    request: proto::DeleteChannel,
2250    response: Response<proto::DeleteChannel>,
2251    session: Session,
2252) -> Result<()> {
2253    let db = session.db().await;
2254
2255    let channel_id = request.channel_id;
2256    let (removed_channels, member_ids) = db
2257        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2258        .await?;
2259    response.send(proto::Ack {})?;
2260
2261    // Notify members of removed channels
2262    let mut update = proto::UpdateChannels::default();
2263    update
2264        .delete_channels
2265        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2266
2267    let connection_pool = session.connection_pool().await;
2268    for member_id in member_ids {
2269        for connection_id in connection_pool.user_connection_ids(member_id) {
2270            session.peer.send(connection_id, update.clone())?;
2271        }
2272    }
2273
2274    Ok(())
2275}
2276
2277async fn invite_channel_member(
2278    request: proto::InviteChannelMember,
2279    response: Response<proto::InviteChannelMember>,
2280    session: Session,
2281) -> Result<()> {
2282    let db = session.db().await;
2283    let channel_id = ChannelId::from_proto(request.channel_id);
2284    let invitee_id = UserId::from_proto(request.user_id);
2285    db.invite_channel_member(
2286        channel_id,
2287        invitee_id,
2288        session.user_id,
2289        request.role().into(),
2290    )
2291    .await?;
2292
2293    let channel = db.get_channel(channel_id, session.user_id).await?;
2294
2295    let mut update = proto::UpdateChannels::default();
2296    update.channel_invitations.push(proto::Channel {
2297        id: channel.id.to_proto(),
2298        visibility: channel.visibility.into(),
2299        name: channel.name,
2300    });
2301    for connection_id in session
2302        .connection_pool()
2303        .await
2304        .user_connection_ids(invitee_id)
2305    {
2306        session.peer.send(connection_id, update.clone())?;
2307    }
2308
2309    response.send(proto::Ack {})?;
2310    Ok(())
2311}
2312
2313async fn remove_channel_member(
2314    request: proto::RemoveChannelMember,
2315    response: Response<proto::RemoveChannelMember>,
2316    session: Session,
2317) -> Result<()> {
2318    let db = session.db().await;
2319    let channel_id = ChannelId::from_proto(request.channel_id);
2320    let member_id = UserId::from_proto(request.user_id);
2321
2322    db.remove_channel_member(channel_id, member_id, session.user_id)
2323        .await?;
2324
2325    let mut update = proto::UpdateChannels::default();
2326    update.delete_channels.push(channel_id.to_proto());
2327
2328    for connection_id in session
2329        .connection_pool()
2330        .await
2331        .user_connection_ids(member_id)
2332    {
2333        session.peer.send(connection_id, update.clone())?;
2334    }
2335
2336    response.send(proto::Ack {})?;
2337    Ok(())
2338}
2339
2340async fn set_channel_visibility(
2341    request: proto::SetChannelVisibility,
2342    response: Response<proto::SetChannelVisibility>,
2343    session: Session,
2344) -> Result<()> {
2345    let db = session.db().await;
2346    let channel_id = ChannelId::from_proto(request.channel_id);
2347    let visibility = request.visibility().into();
2348
2349    let channel = db
2350        .set_channel_visibility(channel_id, visibility, session.user_id)
2351        .await?;
2352
2353    let mut update = proto::UpdateChannels::default();
2354    update.channels.push(proto::Channel {
2355        id: channel.id.to_proto(),
2356        name: channel.name,
2357        visibility: channel.visibility.into(),
2358    });
2359
2360    let member_ids = db.get_channel_members(channel_id).await?;
2361
2362    let connection_pool = session.connection_pool().await;
2363    for member_id in member_ids {
2364        for connection_id in connection_pool.user_connection_ids(member_id) {
2365            session.peer.send(connection_id, update.clone())?;
2366        }
2367    }
2368
2369    response.send(proto::Ack {})?;
2370    Ok(())
2371}
2372
2373async fn set_channel_member_role(
2374    request: proto::SetChannelMemberRole,
2375    response: Response<proto::SetChannelMemberRole>,
2376    session: Session,
2377) -> Result<()> {
2378    let db = session.db().await;
2379    let channel_id = ChannelId::from_proto(request.channel_id);
2380    let member_id = UserId::from_proto(request.user_id);
2381    let channel_member = db
2382        .set_channel_member_role(
2383            channel_id,
2384            session.user_id,
2385            member_id,
2386            request.role().into(),
2387        )
2388        .await?;
2389
2390    let channel = db.get_channel(channel_id, session.user_id).await?;
2391
2392    let mut update = proto::UpdateChannels::default();
2393    if channel_member.accepted {
2394        update.channel_permissions.push(proto::ChannelPermission {
2395            channel_id: channel.id.to_proto(),
2396            role: request.role,
2397        });
2398    }
2399
2400    for connection_id in session
2401        .connection_pool()
2402        .await
2403        .user_connection_ids(member_id)
2404    {
2405        session.peer.send(connection_id, update.clone())?;
2406    }
2407
2408    response.send(proto::Ack {})?;
2409    Ok(())
2410}
2411
2412async fn rename_channel(
2413    request: proto::RenameChannel,
2414    response: Response<proto::RenameChannel>,
2415    session: Session,
2416) -> Result<()> {
2417    let db = session.db().await;
2418    let channel_id = ChannelId::from_proto(request.channel_id);
2419    let channel = db
2420        .rename_channel(channel_id, session.user_id, &request.name)
2421        .await?;
2422
2423    let channel = proto::Channel {
2424        id: channel.id.to_proto(),
2425        name: channel.name,
2426        visibility: channel.visibility.into(),
2427    };
2428    response.send(proto::RenameChannelResponse {
2429        channel: Some(channel.clone()),
2430    })?;
2431    let mut update = proto::UpdateChannels::default();
2432    update.channels.push(channel);
2433
2434    let member_ids = db.get_channel_members(channel_id).await?;
2435
2436    let connection_pool = session.connection_pool().await;
2437    for member_id in member_ids {
2438        for connection_id in connection_pool.user_connection_ids(member_id) {
2439            session.peer.send(connection_id, update.clone())?;
2440        }
2441    }
2442
2443    Ok(())
2444}
2445
2446async fn link_channel(
2447    request: proto::LinkChannel,
2448    response: Response<proto::LinkChannel>,
2449    session: Session,
2450) -> Result<()> {
2451    let db = session.db().await;
2452    let channel_id = ChannelId::from_proto(request.channel_id);
2453    let to = ChannelId::from_proto(request.to);
2454    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2455
2456    let members = db.get_channel_members(to).await?;
2457    let connection_pool = session.connection_pool().await;
2458    let update = proto::UpdateChannels {
2459        channels: channels_to_send
2460            .channels
2461            .into_iter()
2462            .map(|channel| proto::Channel {
2463                id: channel.id.to_proto(),
2464                visibility: channel.visibility.into(),
2465                name: channel.name,
2466            })
2467            .collect(),
2468        insert_edge: channels_to_send.edges,
2469        ..Default::default()
2470    };
2471    for member_id in members {
2472        for connection_id in connection_pool.user_connection_ids(member_id) {
2473            session.peer.send(connection_id, update.clone())?;
2474        }
2475    }
2476
2477    response.send(Ack {})?;
2478
2479    Ok(())
2480}
2481
2482async fn unlink_channel(
2483    request: proto::UnlinkChannel,
2484    response: Response<proto::UnlinkChannel>,
2485    session: Session,
2486) -> Result<()> {
2487    let db = session.db().await;
2488    let channel_id = ChannelId::from_proto(request.channel_id);
2489    let from = ChannelId::from_proto(request.from);
2490
2491    db.unlink_channel(session.user_id, channel_id, from).await?;
2492
2493    let members = db.get_channel_members(from).await?;
2494
2495    let update = proto::UpdateChannels {
2496        delete_edge: vec![proto::ChannelEdge {
2497            channel_id: channel_id.to_proto(),
2498            parent_id: from.to_proto(),
2499        }],
2500        ..Default::default()
2501    };
2502    let connection_pool = session.connection_pool().await;
2503    for member_id in members {
2504        for connection_id in connection_pool.user_connection_ids(member_id) {
2505            session.peer.send(connection_id, update.clone())?;
2506        }
2507    }
2508
2509    response.send(Ack {})?;
2510
2511    Ok(())
2512}
2513
2514async fn move_channel(
2515    request: proto::MoveChannel,
2516    response: Response<proto::MoveChannel>,
2517    session: Session,
2518) -> Result<()> {
2519    let db = session.db().await;
2520    let channel_id = ChannelId::from_proto(request.channel_id);
2521    let from_parent = ChannelId::from_proto(request.from);
2522    let to = ChannelId::from_proto(request.to);
2523
2524    let channels_to_send = db
2525        .move_channel(session.user_id, channel_id, from_parent, to)
2526        .await?;
2527
2528    if channels_to_send.is_empty() {
2529        response.send(Ack {})?;
2530        return Ok(());
2531    }
2532
2533    let members_from = db.get_channel_members(from_parent).await?;
2534    let members_to = db.get_channel_members(to).await?;
2535
2536    let update = proto::UpdateChannels {
2537        delete_edge: vec![proto::ChannelEdge {
2538            channel_id: channel_id.to_proto(),
2539            parent_id: from_parent.to_proto(),
2540        }],
2541        ..Default::default()
2542    };
2543    let connection_pool = session.connection_pool().await;
2544    for member_id in members_from {
2545        for connection_id in connection_pool.user_connection_ids(member_id) {
2546            session.peer.send(connection_id, update.clone())?;
2547        }
2548    }
2549
2550    let update = proto::UpdateChannels {
2551        channels: channels_to_send
2552            .channels
2553            .into_iter()
2554            .map(|channel| proto::Channel {
2555                id: channel.id.to_proto(),
2556                visibility: channel.visibility.into(),
2557                name: channel.name,
2558            })
2559            .collect(),
2560        insert_edge: channels_to_send.edges,
2561        ..Default::default()
2562    };
2563    for member_id in members_to {
2564        for connection_id in connection_pool.user_connection_ids(member_id) {
2565            session.peer.send(connection_id, update.clone())?;
2566        }
2567    }
2568
2569    response.send(Ack {})?;
2570
2571    Ok(())
2572}
2573
2574async fn get_channel_members(
2575    request: proto::GetChannelMembers,
2576    response: Response<proto::GetChannelMembers>,
2577    session: Session,
2578) -> Result<()> {
2579    let db = session.db().await;
2580    let channel_id = ChannelId::from_proto(request.channel_id);
2581    let members = db
2582        .get_channel_participant_details(channel_id, session.user_id)
2583        .await?;
2584    response.send(proto::GetChannelMembersResponse { members })?;
2585    Ok(())
2586}
2587
2588async fn respond_to_channel_invite(
2589    request: proto::RespondToChannelInvite,
2590    response: Response<proto::RespondToChannelInvite>,
2591    session: Session,
2592) -> Result<()> {
2593    let db = session.db().await;
2594    let channel_id = ChannelId::from_proto(request.channel_id);
2595    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2596        .await?;
2597
2598    if request.accept {
2599        channel_membership_updated(db, channel_id, &session).await?;
2600    } else {
2601        let mut update = proto::UpdateChannels::default();
2602        update
2603            .remove_channel_invitations
2604            .push(channel_id.to_proto());
2605        session.peer.send(session.connection_id, update)?;
2606    }
2607    response.send(proto::Ack {})?;
2608
2609    Ok(())
2610}
2611
2612async fn channel_membership_updated(
2613    db: tokio::sync::MutexGuard<'_, DbHandle>,
2614    channel_id: ChannelId,
2615    session: &Session,
2616) -> Result<(), crate::Error> {
2617    let mut update = proto::UpdateChannels::default();
2618    update
2619        .remove_channel_invitations
2620        .push(channel_id.to_proto());
2621
2622    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2623    update.channels.extend(
2624        result
2625            .channels
2626            .channels
2627            .into_iter()
2628            .map(|channel| proto::Channel {
2629                id: channel.id.to_proto(),
2630                visibility: channel.visibility.into(),
2631                name: channel.name,
2632            }),
2633    );
2634    update.unseen_channel_messages = result.channel_messages;
2635    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2636    update.insert_edge = result.channels.edges;
2637    update
2638        .channel_participants
2639        .extend(
2640            result
2641                .channel_participants
2642                .into_iter()
2643                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2644                    channel_id: channel_id.to_proto(),
2645                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2646                }),
2647        );
2648    update
2649        .channel_permissions
2650        .extend(
2651            result
2652                .channels_with_admin_privileges
2653                .into_iter()
2654                .map(|channel_id| proto::ChannelPermission {
2655                    channel_id: channel_id.to_proto(),
2656                    role: proto::ChannelRole::Admin.into(),
2657                }),
2658        );
2659    session.peer.send(session.connection_id, update)?;
2660    Ok(())
2661}
2662
2663async fn join_channel(
2664    request: proto::JoinChannel,
2665    response: Response<proto::JoinChannel>,
2666    session: Session,
2667) -> Result<()> {
2668    let channel_id = ChannelId::from_proto(request.channel_id);
2669    join_channel_internal(channel_id, Box::new(response), session).await
2670}
2671
2672trait JoinChannelInternalResponse {
2673    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2674}
2675impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2676    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2677        Response::<proto::JoinChannel>::send(self, result)
2678    }
2679}
2680impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2681    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2682        Response::<proto::JoinRoom>::send(self, result)
2683    }
2684}
2685
/// Joins this session's connection to the room of `channel_id` and notifies
/// everyone affected.
///
/// Order of operations:
/// 1. leave any room this session is currently in (`leave_room_for_session`);
/// 2. join the channel's room via `db.join_channel`, passing the release
///    channel name;
/// 3. reply through `response` with the room and its channel id. When a
///    LiveKit client is configured and a token can be created, the reply also
///    carries LiveKit connection info;
/// 4. if the database returned a joined channel, push the user's updated
///    channel data to this connection (`channel_membership_updated`);
/// 5. broadcast the room to its participants (`room_updated`);
/// 6. send the channel's participant list to the channel members
///    (`channel_updated`), then refresh this user's contacts.
///
/// `response` is generic over `JoinChannelInternalResponse`, which is
/// implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so either request can be answered here.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard `db` exists only inside this block. It is either
    // consumed by `channel_membership_updated` or dropped at the end of the
    // block, so it is released before the connection pool is acquired below.
    let joined_room = {
        // `leave_room_for_session` takes the session's database guard itself.
        // That guard is a `tokio::sync::MutexGuard` (see
        // `channel_membership_updated`'s signature), which is not reentrant,
        // so this call must happen before `db` is bound.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // A token error is logged by `trace_err` and yields `None`. The join
            // still succeeds, just without LiveKit info.
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        // The reply goes out before the membership update and the room broadcast.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Moves `db` into the callee when a channel was newly joined.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2744
2745async fn join_channel_buffer(
2746    request: proto::JoinChannelBuffer,
2747    response: Response<proto::JoinChannelBuffer>,
2748    session: Session,
2749) -> Result<()> {
2750    let db = session.db().await;
2751    let channel_id = ChannelId::from_proto(request.channel_id);
2752
2753    let open_response = db
2754        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2755        .await?;
2756
2757    let collaborators = open_response.collaborators.clone();
2758    response.send(open_response)?;
2759
2760    let update = UpdateChannelBufferCollaborators {
2761        channel_id: channel_id.to_proto(),
2762        collaborators: collaborators.clone(),
2763    };
2764    channel_buffer_updated(
2765        session.connection_id,
2766        collaborators
2767            .iter()
2768            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2769        &update,
2770        &session.peer,
2771    );
2772
2773    Ok(())
2774}
2775
2776async fn update_channel_buffer(
2777    request: proto::UpdateChannelBuffer,
2778    session: Session,
2779) -> Result<()> {
2780    let db = session.db().await;
2781    let channel_id = ChannelId::from_proto(request.channel_id);
2782
2783    let (collaborators, non_collaborators, epoch, version) = db
2784        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2785        .await?;
2786
2787    channel_buffer_updated(
2788        session.connection_id,
2789        collaborators,
2790        &proto::UpdateChannelBuffer {
2791            channel_id: channel_id.to_proto(),
2792            operations: request.operations,
2793        },
2794        &session.peer,
2795    );
2796
2797    let pool = &*session.connection_pool().await;
2798
2799    broadcast(
2800        None,
2801        non_collaborators
2802            .iter()
2803            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2804        |peer_id| {
2805            session.peer.send(
2806                peer_id.into(),
2807                proto::UpdateChannels {
2808                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2809                        channel_id: channel_id.to_proto(),
2810                        epoch: epoch as u64,
2811                        version: version.clone(),
2812                    }],
2813                    ..Default::default()
2814                },
2815            )
2816        },
2817    );
2818
2819    Ok(())
2820}
2821
2822async fn rejoin_channel_buffers(
2823    request: proto::RejoinChannelBuffers,
2824    response: Response<proto::RejoinChannelBuffers>,
2825    session: Session,
2826) -> Result<()> {
2827    let db = session.db().await;
2828    let buffers = db
2829        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2830        .await?;
2831
2832    for rejoined_buffer in &buffers {
2833        let collaborators_to_notify = rejoined_buffer
2834            .buffer
2835            .collaborators
2836            .iter()
2837            .filter_map(|c| Some(c.peer_id?.into()));
2838        channel_buffer_updated(
2839            session.connection_id,
2840            collaborators_to_notify,
2841            &proto::UpdateChannelBufferCollaborators {
2842                channel_id: rejoined_buffer.buffer.channel_id,
2843                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2844            },
2845            &session.peer,
2846        );
2847    }
2848
2849    response.send(proto::RejoinChannelBuffersResponse {
2850        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2851    })?;
2852
2853    Ok(())
2854}
2855
2856async fn leave_channel_buffer(
2857    request: proto::LeaveChannelBuffer,
2858    response: Response<proto::LeaveChannelBuffer>,
2859    session: Session,
2860) -> Result<()> {
2861    let db = session.db().await;
2862    let channel_id = ChannelId::from_proto(request.channel_id);
2863
2864    let left_buffer = db
2865        .leave_channel_buffer(channel_id, session.connection_id)
2866        .await?;
2867
2868    response.send(Ack {})?;
2869
2870    channel_buffer_updated(
2871        session.connection_id,
2872        left_buffer.connections,
2873        &proto::UpdateChannelBufferCollaborators {
2874            channel_id: channel_id.to_proto(),
2875            collaborators: left_buffer.collaborators,
2876        },
2877        &session.peer,
2878    );
2879
2880    Ok(())
2881}
2882
2883fn channel_buffer_updated<T: EnvelopedMessage>(
2884    sender_id: ConnectionId,
2885    collaborators: impl IntoIterator<Item = ConnectionId>,
2886    message: &T,
2887    peer: &Peer,
2888) {
2889    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2890        peer.send(peer_id.into(), message.clone())
2891    });
2892}
2893
2894async fn send_channel_message(
2895    request: proto::SendChannelMessage,
2896    response: Response<proto::SendChannelMessage>,
2897    session: Session,
2898) -> Result<()> {
2899    // Validate the message body.
2900    let body = request.body.trim().to_string();
2901    if body.len() > MAX_MESSAGE_LEN {
2902        return Err(anyhow!("message is too long"))?;
2903    }
2904    if body.is_empty() {
2905        return Err(anyhow!("message can't be blank"))?;
2906    }
2907
2908    let timestamp = OffsetDateTime::now_utc();
2909    let nonce = request
2910        .nonce
2911        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2912
2913    let channel_id = ChannelId::from_proto(request.channel_id);
2914    let (message_id, connection_ids, non_participants) = session
2915        .db()
2916        .await
2917        .create_channel_message(
2918            channel_id,
2919            session.user_id,
2920            &body,
2921            timestamp,
2922            nonce.clone().into(),
2923        )
2924        .await?;
2925    let message = proto::ChannelMessage {
2926        sender_id: session.user_id.to_proto(),
2927        id: message_id.to_proto(),
2928        body,
2929        timestamp: timestamp.unix_timestamp() as u64,
2930        nonce: Some(nonce),
2931    };
2932    broadcast(Some(session.connection_id), connection_ids, |connection| {
2933        session.peer.send(
2934            connection,
2935            proto::ChannelMessageSent {
2936                channel_id: channel_id.to_proto(),
2937                message: Some(message.clone()),
2938            },
2939        )
2940    });
2941    response.send(proto::SendChannelMessageResponse {
2942        message: Some(message),
2943    })?;
2944
2945    let pool = &*session.connection_pool().await;
2946    broadcast(
2947        None,
2948        non_participants
2949            .iter()
2950            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2951        |peer_id| {
2952            session.peer.send(
2953                peer_id.into(),
2954                proto::UpdateChannels {
2955                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2956                        channel_id: channel_id.to_proto(),
2957                        message_id: message_id.to_proto(),
2958                    }],
2959                    ..Default::default()
2960                },
2961            )
2962        },
2963    );
2964
2965    Ok(())
2966}
2967
2968async fn remove_channel_message(
2969    request: proto::RemoveChannelMessage,
2970    response: Response<proto::RemoveChannelMessage>,
2971    session: Session,
2972) -> Result<()> {
2973    let channel_id = ChannelId::from_proto(request.channel_id);
2974    let message_id = MessageId::from_proto(request.message_id);
2975    let connection_ids = session
2976        .db()
2977        .await
2978        .remove_channel_message(channel_id, message_id, session.user_id)
2979        .await?;
2980    broadcast(Some(session.connection_id), connection_ids, |connection| {
2981        session.peer.send(connection, request.clone())
2982    });
2983    response.send(proto::Ack {})?;
2984    Ok(())
2985}
2986
2987async fn acknowledge_channel_message(
2988    request: proto::AckChannelMessage,
2989    session: Session,
2990) -> Result<()> {
2991    let channel_id = ChannelId::from_proto(request.channel_id);
2992    let message_id = MessageId::from_proto(request.message_id);
2993    session
2994        .db()
2995        .await
2996        .observe_channel_message(channel_id, session.user_id, message_id)
2997        .await?;
2998    Ok(())
2999}
3000
3001async fn acknowledge_buffer_version(
3002    request: proto::AckBufferOperation,
3003    session: Session,
3004) -> Result<()> {
3005    let buffer_id = BufferId::from_proto(request.buffer_id);
3006    session
3007        .db()
3008        .await
3009        .observe_buffer_version(
3010            buffer_id,
3011            session.user_id,
3012            request.epoch as i32,
3013            &request.version,
3014        )
3015        .await?;
3016    Ok(())
3017}
3018
3019async fn join_channel_chat(
3020    request: proto::JoinChannelChat,
3021    response: Response<proto::JoinChannelChat>,
3022    session: Session,
3023) -> Result<()> {
3024    let channel_id = ChannelId::from_proto(request.channel_id);
3025
3026    let db = session.db().await;
3027    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3028        .await?;
3029    let messages = db
3030        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3031        .await?;
3032    response.send(proto::JoinChannelChatResponse {
3033        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3034        messages,
3035    })?;
3036    Ok(())
3037}
3038
3039async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3040    let channel_id = ChannelId::from_proto(request.channel_id);
3041    session
3042        .db()
3043        .await
3044        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3045        .await?;
3046    Ok(())
3047}
3048
3049async fn get_channel_messages(
3050    request: proto::GetChannelMessages,
3051    response: Response<proto::GetChannelMessages>,
3052    session: Session,
3053) -> Result<()> {
3054    let channel_id = ChannelId::from_proto(request.channel_id);
3055    let messages = session
3056        .db()
3057        .await
3058        .get_channel_messages(
3059            channel_id,
3060            session.user_id,
3061            MESSAGE_COUNT_PER_PAGE,
3062            Some(MessageId::from_proto(request.before_message_id)),
3063        )
3064        .await?;
3065    response.send(proto::GetChannelMessagesResponse {
3066        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3067        messages,
3068    })?;
3069    Ok(())
3070}
3071
3072async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3073    let project_id = ProjectId::from_proto(request.project_id);
3074    let project_connection_ids = session
3075        .db()
3076        .await
3077        .project_connection_ids(project_id, session.connection_id)
3078        .await?;
3079    broadcast(
3080        Some(session.connection_id),
3081        project_connection_ids.iter().copied(),
3082        |connection_id| {
3083            session
3084                .peer
3085                .forward_send(session.connection_id, connection_id, request.clone())
3086        },
3087    );
3088    Ok(())
3089}
3090
3091async fn get_private_user_info(
3092    _request: proto::GetPrivateUserInfo,
3093    response: Response<proto::GetPrivateUserInfo>,
3094    session: Session,
3095) -> Result<()> {
3096    let db = session.db().await;
3097
3098    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3099    let user = db
3100        .get_user_by_id(session.user_id)
3101        .await?
3102        .ok_or_else(|| anyhow!("user not found"))?;
3103    let flags = db.get_user_flags(session.user_id).await?;
3104
3105    response.send(proto::GetPrivateUserInfoResponse {
3106        metrics_id,
3107        staff: user.admin,
3108        flags,
3109    })?;
3110    Ok(())
3111}
3112
3113fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3114    match message {
3115        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3116        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3117        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3118        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3119        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3120            code: frame.code.into(),
3121            reason: frame.reason,
3122        })),
3123    }
3124}
3125
3126fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3127    match message {
3128        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3129        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3130        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3131        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3132        AxumMessage::Close(frame) => {
3133            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3134                code: frame.code.into(),
3135                reason: frame.reason,
3136            }))
3137        }
3138    }
3139}
3140
3141fn build_initial_channels_update(
3142    channels: ChannelsForUser,
3143    channel_invites: Vec<db::Channel>,
3144) -> proto::UpdateChannels {
3145    let mut update = proto::UpdateChannels::default();
3146
3147    for channel in channels.channels.channels {
3148        update.channels.push(proto::Channel {
3149            id: channel.id.to_proto(),
3150            name: channel.name,
3151            visibility: channel.visibility.into(),
3152        });
3153    }
3154
3155    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3156    update.unseen_channel_messages = channels.channel_messages;
3157    update.insert_edge = channels.channels.edges;
3158
3159    for (channel_id, participants) in channels.channel_participants {
3160        update
3161            .channel_participants
3162            .push(proto::ChannelParticipants {
3163                channel_id: channel_id.to_proto(),
3164                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3165            });
3166    }
3167
3168    update
3169        .channel_permissions
3170        .extend(
3171            channels
3172                .channels_with_admin_privileges
3173                .into_iter()
3174                .map(|id| proto::ChannelPermission {
3175                    channel_id: id.to_proto(),
3176                    role: proto::ChannelRole::Admin.into(),
3177                }),
3178        );
3179
3180    for channel in channel_invites {
3181        update.channel_invitations.push(proto::Channel {
3182            id: channel.id.to_proto(),
3183            name: channel.name,
3184            // TODO: Visibility
3185            visibility: ChannelVisibility::Public.into(),
3186        });
3187    }
3188
3189    update
3190}
3191
3192fn build_initial_contacts_update(
3193    contacts: Vec<db::Contact>,
3194    pool: &ConnectionPool,
3195) -> proto::UpdateContacts {
3196    let mut update = proto::UpdateContacts::default();
3197
3198    for contact in contacts {
3199        match contact {
3200            db::Contact::Accepted {
3201                user_id,
3202                should_notify,
3203                busy,
3204            } => {
3205                update
3206                    .contacts
3207                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3208            }
3209            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3210            db::Contact::Incoming {
3211                user_id,
3212                should_notify,
3213            } => update
3214                .incoming_requests
3215                .push(proto::IncomingContactRequest {
3216                    requester_id: user_id.to_proto(),
3217                    should_notify,
3218                }),
3219        }
3220    }
3221
3222    update
3223}
3224
3225fn contact_for_user(
3226    user_id: UserId,
3227    should_notify: bool,
3228    busy: bool,
3229    pool: &ConnectionPool,
3230) -> proto::Contact {
3231    proto::Contact {
3232        user_id: user_id.to_proto(),
3233        online: pool.is_user_online(user_id),
3234        busy,
3235        should_notify,
3236    }
3237}
3238
3239fn room_updated(room: &proto::Room, peer: &Peer) {
3240    broadcast(
3241        None,
3242        room.participants
3243            .iter()
3244            .filter_map(|participant| Some(participant.peer_id?.into())),
3245        |peer_id| {
3246            peer.send(
3247                peer_id.into(),
3248                proto::RoomUpdated {
3249                    room: Some(room.clone()),
3250                },
3251            )
3252        },
3253    );
3254}
3255
3256fn channel_updated(
3257    channel_id: ChannelId,
3258    room: &proto::Room,
3259    channel_members: &[UserId],
3260    peer: &Peer,
3261    pool: &ConnectionPool,
3262) {
3263    let participants = room
3264        .participants
3265        .iter()
3266        .map(|p| p.user_id)
3267        .collect::<Vec<_>>();
3268
3269    broadcast(
3270        None,
3271        channel_members
3272            .iter()
3273            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3274        |peer_id| {
3275            peer.send(
3276                peer_id.into(),
3277                proto::UpdateChannels {
3278                    channel_participants: vec![proto::ChannelParticipants {
3279                        channel_id: channel_id.to_proto(),
3280                        participant_user_ids: participants.clone(),
3281                    }],
3282                    ..Default::default()
3283                },
3284            )
3285        },
3286    );
3287}
3288
3289async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3290    let db = session.db().await;
3291
3292    let contacts = db.get_contacts(user_id).await?;
3293    let busy = db.is_user_busy(user_id).await?;
3294
3295    let pool = session.connection_pool().await;
3296    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3297    for contact in contacts {
3298        if let db::Contact::Accepted {
3299            user_id: contact_user_id,
3300            ..
3301        } = contact
3302        {
3303            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3304                session
3305                    .peer
3306                    .send(
3307                        contact_conn_id,
3308                        proto::UpdateContacts {
3309                            contacts: vec![updated_contact.clone()],
3310                            remove_contacts: Default::default(),
3311                            incoming_requests: Default::default(),
3312                            remove_incoming_requests: Default::default(),
3313                            outgoing_requests: Default::default(),
3314                            remove_outgoing_requests: Default::default(),
3315                        },
3316                    )
3317                    .trace_err();
3318            }
3319        }
3320    }
3321    Ok(())
3322}
3323
/// Removes this session's connection from whatever room it is in, if any, and
/// notifies everyone affected.
///
/// When the database reports a left room, this:
/// - notifies collaborators on each project the user left (`project_left`);
/// - broadcasts the updated room to its participants (`room_updated`);
/// - if the room belongs to a channel, sends the new participant list to the
///   channel's members (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes contact status for this user and for those callees;
/// - removes the user from the LiveKit room, and deletes that room when
///   `left_room.deleted` is set.
///
/// Returns `Ok(())` without doing anything else if the connection was not in a room.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // A set, so each affected user's contacts are refreshed only once.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so they remain
    // available after it ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // The database guard created in this scrutinee is a temporary. It stays
    // alive for the body of the `if let`, so `project_left` and `room_updated`
    // run while it is held.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // the room itself is taken. The `room` broadcast by `room_updated` below
        // therefore carries a default (empty) `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool handle is dropped before
    // `update_user_contacts` acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit failures are logged by `trace_err` and do not fail the leave.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3400
3401async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3402    let left_channel_buffers = session
3403        .db()
3404        .await
3405        .leave_channel_buffers(session.connection_id)
3406        .await?;
3407
3408    for left_buffer in left_channel_buffers {
3409        channel_buffer_updated(
3410            session.connection_id,
3411            left_buffer.connections,
3412            &proto::UpdateChannelBufferCollaborators {
3413                channel_id: left_buffer.channel_id.to_proto(),
3414                collaborators: left_buffer.collaborators,
3415            },
3416            &session.peer,
3417        );
3418    }
3419
3420    Ok(())
3421}
3422
3423fn project_left(project: &db::LeftProject, session: &Session) {
3424    for connection_id in &project.connection_ids {
3425        if project.host_user_id == session.user_id {
3426            session
3427                .peer
3428                .send(
3429                    *connection_id,
3430                    proto::UnshareProject {
3431                        project_id: project.id.to_proto(),
3432                    },
3433                )
3434                .trace_err();
3435        } else {
3436            session
3437                .peer
3438                .send(
3439                    *connection_id,
3440                    proto::RemoveProjectCollaborator {
3441                        project_id: project.id.to_proto(),
3442                        peer_id: Some(session.connection_id.into()),
3443                    },
3444                )
3445                .trace_err();
3446        }
3447    }
3448}
3449
3450pub trait ResultExt {
3451    type Ok;
3452
3453    fn trace_err(self) -> Option<Self::Ok>;
3454}
3455
3456impl<T, E> ResultExt for Result<T, E>
3457where
3458    E: std::fmt::Debug,
3459{
3460    type Ok = T;
3461
3462    fn trace_err(self) -> Option<T> {
3463        match self {
3464            Ok(value) => Some(value),
3465            Err(error) => {
3466                tracing::error!("{:?}", error);
3467                None
3468            }
3469        }
3470    }
3471}