rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
   7        UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66
  67pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  68pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  69
  70const MESSAGE_COUNT_PER_PAGE: usize = 100;
  71const MAX_MESSAGE_LEN: usize = 1024;
  72
  73lazy_static! {
  74    static ref METRIC_CONNECTIONS: IntGauge =
  75        register_int_gauge!("connections", "number of connections").unwrap();
  76    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  77        "shared_projects",
  78        "number of open projects with one or more guests"
  79    )
  80    .unwrap();
  81}
  82
  83type MessageHandler =
  84    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100#[derive(Clone)]
 101struct Session {
 102    user_id: UserId,
 103    connection_id: ConnectionId,
 104    db: Arc<tokio::sync::Mutex<DbHandle>>,
 105    peer: Arc<Peer>,
 106    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 107    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 108    executor: Executor,
 109}
 110
 111impl Session {
 112    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        let guard = self.db.lock().await;
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        guard
 119    }
 120
 121    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 122        #[cfg(test)]
 123        tokio::task::yield_now().await;
 124        let guard = self.connection_pool.lock();
 125        ConnectionPoolGuard {
 126            guard,
 127            _not_send: PhantomData,
 128        }
 129    }
 130}
 131
 132impl fmt::Debug for Session {
 133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 134        f.debug_struct("Session")
 135            .field("user_id", &self.user_id)
 136            .field("connection_id", &self.connection_id)
 137            .finish()
 138    }
 139}
 140
 141struct DbHandle(Arc<Database>);
 142
 143impl Deref for DbHandle {
 144    type Target = Database;
 145
 146    fn deref(&self) -> &Self::Target {
 147        self.0.as_ref()
 148    }
 149}
 150
 151pub struct Server {
 152    id: parking_lot::Mutex<ServerId>,
 153    peer: Arc<Peer>,
 154    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 155    app_state: Arc<AppState>,
 156    executor: Executor,
 157    handlers: HashMap<TypeId, MessageHandler>,
 158    teardown: watch::Sender<()>,
 159}
 160
 161pub(crate) struct ConnectionPoolGuard<'a> {
 162    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 163    _not_send: PhantomData<Rc<()>>,
 164}
 165
 166#[derive(Serialize)]
 167pub struct ServerSnapshot<'a> {
 168    peer: &'a Peer,
 169    #[serde(serialize_with = "serialize_deref")]
 170    connection_pool: ConnectionPoolGuard<'a>,
 171}
 172
 173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 174where
 175    S: Serializer,
 176    T: Deref<Target = U>,
 177    U: Serialize,
 178{
 179    Serialize::serialize(value.deref(), serializer)
 180}
 181
 182impl Server {
 183    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 184        let mut server = Self {
 185            id: parking_lot::Mutex::new(id),
 186            peer: Peer::new(id.0 as u32),
 187            app_state,
 188            executor,
 189            connection_pool: Default::default(),
 190            handlers: Default::default(),
 191            teardown: watch::channel(()).0,
 192        };
 193
 194        server
 195            .add_request_handler(ping)
 196            .add_request_handler(create_room)
 197            .add_request_handler(join_room)
 198            .add_request_handler(rejoin_room)
 199            .add_request_handler(leave_room)
 200            .add_request_handler(call)
 201            .add_request_handler(cancel_call)
 202            .add_message_handler(decline_call)
 203            .add_request_handler(update_participant_location)
 204            .add_request_handler(share_project)
 205            .add_message_handler(unshare_project)
 206            .add_request_handler(join_project)
 207            .add_message_handler(leave_project)
 208            .add_request_handler(update_project)
 209            .add_request_handler(update_worktree)
 210            .add_message_handler(start_language_server)
 211            .add_message_handler(update_language_server)
 212            .add_message_handler(update_diagnostic_summary)
 213            .add_message_handler(update_worktree_settings)
 214            .add_message_handler(refresh_inlay_hints)
 215            .add_request_handler(forward_project_request::<proto::GetHover>)
 216            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 217            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetReferences>)
 219            .add_request_handler(forward_project_request::<proto::SearchProject>)
 220            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 221            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 222            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 225            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 226            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 227            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 228            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 229            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 230            .add_request_handler(forward_project_request::<proto::PerformRename>)
 231            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 232            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 233            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 234            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 235            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 236            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 240            .add_request_handler(forward_project_request::<proto::InlayHints>)
 241            .add_message_handler(create_buffer_for_peer)
 242            .add_request_handler(update_buffer)
 243            .add_message_handler(update_buffer_file)
 244            .add_message_handler(buffer_reloaded)
 245            .add_message_handler(buffer_saved)
 246            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 247            .add_request_handler(get_users)
 248            .add_request_handler(fuzzy_search_users)
 249            .add_request_handler(request_contact)
 250            .add_request_handler(remove_contact)
 251            .add_request_handler(respond_to_contact_request)
 252            .add_request_handler(create_channel)
 253            .add_request_handler(delete_channel)
 254            .add_request_handler(invite_channel_member)
 255            .add_request_handler(remove_channel_member)
 256            .add_request_handler(set_channel_member_admin)
 257            .add_request_handler(rename_channel)
 258            .add_request_handler(join_channel_buffer)
 259            .add_request_handler(leave_channel_buffer)
 260            .add_message_handler(update_channel_buffer)
 261            .add_request_handler(rejoin_channel_buffers)
 262            .add_request_handler(get_channel_members)
 263            .add_request_handler(respond_to_channel_invite)
 264            .add_request_handler(join_channel)
 265            .add_request_handler(join_channel_chat)
 266            .add_message_handler(leave_channel_chat)
 267            .add_request_handler(send_channel_message)
 268            .add_request_handler(remove_channel_message)
 269            .add_request_handler(get_channel_messages)
 270            .add_request_handler(link_channel)
 271            .add_request_handler(unlink_channel)
 272            .add_request_handler(move_channel)
 273            .add_request_handler(follow)
 274            .add_message_handler(unfollow)
 275            .add_message_handler(update_followers)
 276            .add_message_handler(update_diff_base)
 277            .add_request_handler(get_private_user_info)
 278            .add_message_handler(acknowledge_channel_message);
 279
 280        Arc::new(server)
 281    }
 282
 283    pub async fn start(&self) -> Result<()> {
 284        let server_id = *self.id.lock();
 285        let app_state = self.app_state.clone();
 286        let peer = self.peer.clone();
 287        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 288        let pool = self.connection_pool.clone();
 289        let live_kit_client = self.app_state.live_kit_client.clone();
 290
 291        let span = info_span!("start server");
 292        self.executor.spawn_detached(
 293            async move {
 294                tracing::info!("waiting for cleanup timeout");
 295                timeout.await;
 296                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 297                if let Some((room_ids, channel_ids)) = app_state
 298                    .db
 299                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 300                    .await
 301                    .trace_err()
 302                {
 303                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 304                    tracing::info!(
 305                        stale_channel_buffer_count = channel_ids.len(),
 306                        "retrieved stale channel buffers"
 307                    );
 308
 309                    for channel_id in channel_ids {
 310                        if let Some(refreshed_channel_buffer) = app_state
 311                            .db
 312                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 313                            .await
 314                            .trace_err()
 315                        {
 316                            for connection_id in refreshed_channel_buffer.connection_ids {
 317                                peer.send(
 318                                    connection_id,
 319                                    proto::UpdateChannelBufferCollaborators {
 320                                        channel_id: channel_id.to_proto(),
 321                                        collaborators: refreshed_channel_buffer
 322                                            .collaborators
 323                                            .clone(),
 324                                    },
 325                                )
 326                                .trace_err();
 327                            }
 328                        }
 329                    }
 330
 331                    for room_id in room_ids {
 332                        let mut contacts_to_update = HashSet::default();
 333                        let mut canceled_calls_to_user_ids = Vec::new();
 334                        let mut live_kit_room = String::new();
 335                        let mut delete_live_kit_room = false;
 336
 337                        if let Some(mut refreshed_room) = app_state
 338                            .db
 339                            .clear_stale_room_participants(room_id, server_id)
 340                            .await
 341                            .trace_err()
 342                        {
 343                            tracing::info!(
 344                                room_id = room_id.0,
 345                                new_participant_count = refreshed_room.room.participants.len(),
 346                                "refreshed room"
 347                            );
 348                            room_updated(&refreshed_room.room, &peer);
 349                            if let Some(channel_id) = refreshed_room.channel_id {
 350                                channel_updated(
 351                                    channel_id,
 352                                    &refreshed_room.room,
 353                                    &refreshed_room.channel_members,
 354                                    &peer,
 355                                    &*pool.lock(),
 356                                );
 357                            }
 358                            contacts_to_update
 359                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 360                            contacts_to_update
 361                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 362                            canceled_calls_to_user_ids =
 363                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 364                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 365                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 366                        }
 367
 368                        {
 369                            let pool = pool.lock();
 370                            for canceled_user_id in canceled_calls_to_user_ids {
 371                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 372                                    peer.send(
 373                                        connection_id,
 374                                        proto::CallCanceled {
 375                                            room_id: room_id.to_proto(),
 376                                        },
 377                                    )
 378                                    .trace_err();
 379                                }
 380                            }
 381                        }
 382
 383                        for user_id in contacts_to_update {
 384                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 385                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 386                            if let Some((busy, contacts)) = busy.zip(contacts) {
 387                                let pool = pool.lock();
 388                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 389                                for contact in contacts {
 390                                    if let db::Contact::Accepted {
 391                                        user_id: contact_user_id,
 392                                        ..
 393                                    } = contact
 394                                    {
 395                                        for contact_conn_id in
 396                                            pool.user_connection_ids(contact_user_id)
 397                                        {
 398                                            peer.send(
 399                                                contact_conn_id,
 400                                                proto::UpdateContacts {
 401                                                    contacts: vec![updated_contact.clone()],
 402                                                    remove_contacts: Default::default(),
 403                                                    incoming_requests: Default::default(),
 404                                                    remove_incoming_requests: Default::default(),
 405                                                    outgoing_requests: Default::default(),
 406                                                    remove_outgoing_requests: Default::default(),
 407                                                },
 408                                            )
 409                                            .trace_err();
 410                                        }
 411                                    }
 412                                }
 413                            }
 414                        }
 415
 416                        if let Some(live_kit) = live_kit_client.as_ref() {
 417                            if delete_live_kit_room {
 418                                live_kit.delete_room(live_kit_room).await.trace_err();
 419                            }
 420                        }
 421                    }
 422                }
 423
 424                app_state
 425                    .db
 426                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 427                    .await
 428                    .trace_err();
 429            }
 430            .instrument(span),
 431        );
 432        Ok(())
 433    }
 434
 435    pub fn teardown(&self) {
 436        self.peer.teardown();
 437        self.connection_pool.lock().reset();
 438        let _ = self.teardown.send(());
 439    }
 440
 441    #[cfg(test)]
 442    pub fn reset(&self, id: ServerId) {
 443        self.teardown();
 444        *self.id.lock() = id;
 445        self.peer.reset(id.0 as u32);
 446    }
 447
 448    #[cfg(test)]
 449    pub fn id(&self) -> ServerId {
 450        *self.id.lock()
 451    }
 452
 453    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 454    where
 455        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 456        Fut: 'static + Send + Future<Output = Result<()>>,
 457        M: EnvelopedMessage,
 458    {
 459        let prev_handler = self.handlers.insert(
 460            TypeId::of::<M>(),
 461            Box::new(move |envelope, session| {
 462                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 463                let span = info_span!(
 464                    "handle message",
 465                    payload_type = envelope.payload_type_name()
 466                );
 467                span.in_scope(|| {
 468                    tracing::info!(
 469                        payload_type = envelope.payload_type_name(),
 470                        "message received"
 471                    );
 472                });
 473                let start_time = Instant::now();
 474                let future = (handler)(*envelope, session);
 475                async move {
 476                    let result = future.await;
 477                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 478                    match result {
 479                        Err(error) => {
 480                            tracing::error!(%error, ?duration_ms, "error handling message")
 481                        }
 482                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 483                    }
 484                }
 485                .instrument(span)
 486                .boxed()
 487            }),
 488        );
 489        if prev_handler.is_some() {
 490            panic!("registered a handler for the same message twice");
 491        }
 492        self
 493    }
 494
 495    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 496    where
 497        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 498        Fut: 'static + Send + Future<Output = Result<()>>,
 499        M: EnvelopedMessage,
 500    {
 501        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 502        self
 503    }
 504
 505    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 506    where
 507        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 508        Fut: Send + Future<Output = Result<()>>,
 509        M: RequestMessage,
 510    {
 511        let handler = Arc::new(handler);
 512        self.add_handler(move |envelope, session| {
 513            let receipt = envelope.receipt();
 514            let handler = handler.clone();
 515            async move {
 516                let peer = session.peer.clone();
 517                let responded = Arc::new(AtomicBool::default());
 518                let response = Response {
 519                    peer: peer.clone(),
 520                    responded: responded.clone(),
 521                    receipt,
 522                };
 523                match (handler)(envelope.payload, response, session).await {
 524                    Ok(()) => {
 525                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 526                            Ok(())
 527                        } else {
 528                            Err(anyhow!("handler did not send a response"))?
 529                        }
 530                    }
 531                    Err(error) => {
 532                        peer.respond_with_error(
 533                            receipt,
 534                            proto::Error {
 535                                message: error.to_string(),
 536                            },
 537                        )?;
 538                        Err(error)
 539                    }
 540                }
 541            }
 542        })
 543    }
 544
 545    pub fn handle_connection(
 546        self: &Arc<Self>,
 547        connection: Connection,
 548        address: String,
 549        user: User,
 550        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 551        executor: Executor,
 552    ) -> impl Future<Output = Result<()>> {
 553        let this = self.clone();
 554        let user_id = user.id;
 555        let login = user.github_login;
 556        let span = info_span!("handle connection", %user_id, %login, %address);
 557        let mut teardown = self.teardown.subscribe();
 558        async move {
 559            let (connection_id, handle_io, mut incoming_rx) = this
 560                .peer
 561                .add_connection(connection, {
 562                    let executor = executor.clone();
 563                    move |duration| executor.sleep(duration)
 564                });
 565
 566            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 567            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 568            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 569
 570            if let Some(send_connection_id) = send_connection_id.take() {
 571                let _ = send_connection_id.send(connection_id);
 572            }
 573
 574            if !user.connected_once {
 575                this.peer.send(connection_id, proto::ShowContacts {})?;
 576                this.app_state.db.set_user_connected_once(user_id, true).await?;
 577            }
 578
 579            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 580                this.app_state.db.get_contacts(user_id),
 581                this.app_state.db.get_channels_for_user(user_id),
 582                this.app_state.db.get_channel_invites_for_user(user_id)
 583            ).await?;
 584
 585            {
 586                let mut pool = this.connection_pool.lock();
 587                pool.add_connection(connection_id, user_id, user.admin);
 588                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 589                this.peer.send(connection_id, build_initial_channels_update(
 590                    channels_for_user,
 591                    channel_invites
 592                ))?;
 593            }
 594
 595            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 596                this.peer.send(connection_id, incoming_call)?;
 597            }
 598
 599            let session = Session {
 600                user_id,
 601                connection_id,
 602                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 603                peer: this.peer.clone(),
 604                connection_pool: this.connection_pool.clone(),
 605                live_kit_client: this.app_state.live_kit_client.clone(),
 606                executor: executor.clone(),
 607            };
 608            update_user_contacts(user_id, &session).await?;
 609
 610            let handle_io = handle_io.fuse();
 611            futures::pin_mut!(handle_io);
 612
 613            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 614            // This prevents deadlocks when e.g., client A performs a request to client B and
 615            // client B performs a request to client A. If both clients stop processing further
 616            // messages until their respective request completes, they won't have a chance to
 617            // respond to the other client's request and cause a deadlock.
 618            //
 619            // This arrangement ensures we will attempt to process earlier messages first, but fall
 620            // back to processing messages arrived later in the spirit of making progress.
 621            let mut foreground_message_handlers = FuturesUnordered::new();
 622            let concurrent_handlers = Arc::new(Semaphore::new(256));
 623            loop {
 624                let next_message = async {
 625                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 626                    let message = incoming_rx.next().await;
 627                    (permit, message)
 628                }.fuse();
 629                futures::pin_mut!(next_message);
 630                futures::select_biased! {
 631                    _ = teardown.changed().fuse() => return Ok(()),
 632                    result = handle_io => {
 633                        if let Err(error) = result {
 634                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 635                        }
 636                        break;
 637                    }
 638                    _ = foreground_message_handlers.next() => {}
 639                    next_message = next_message => {
 640                        let (permit, message) = next_message;
 641                        if let Some(message) = message {
 642                            let type_name = message.payload_type_name();
 643                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 644                            let span_enter = span.enter();
 645                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 646                                let is_background = message.is_background();
 647                                let handle_message = (handler)(message, session.clone());
 648                                drop(span_enter);
 649
 650                                let handle_message = async move {
 651                                    handle_message.await;
 652                                    drop(permit);
 653                                }.instrument(span);
 654                                if is_background {
 655                                    executor.spawn_detached(handle_message);
 656                                } else {
 657                                    foreground_message_handlers.push(handle_message);
 658                                }
 659                            } else {
 660                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 661                            }
 662                        } else {
 663                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 664                            break;
 665                        }
 666                    }
 667                }
 668            }
 669
 670            drop(foreground_message_handlers);
 671            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 672            if let Err(error) = connection_lost(session, teardown, executor).await {
 673                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 674            }
 675
 676            Ok(())
 677        }.instrument(span)
 678    }
 679
 680    pub async fn invite_code_redeemed(
 681        self: &Arc<Self>,
 682        inviter_id: UserId,
 683        invitee_id: UserId,
 684    ) -> Result<()> {
 685        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 686            if let Some(code) = &user.invite_code {
 687                let pool = self.connection_pool.lock();
 688                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 689                for connection_id in pool.user_connection_ids(inviter_id) {
 690                    self.peer.send(
 691                        connection_id,
 692                        proto::UpdateContacts {
 693                            contacts: vec![invitee_contact.clone()],
 694                            ..Default::default()
 695                        },
 696                    )?;
 697                    self.peer.send(
 698                        connection_id,
 699                        proto::UpdateInviteInfo {
 700                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 701                            count: user.invite_count as u32,
 702                        },
 703                    )?;
 704                }
 705            }
 706        }
 707        Ok(())
 708    }
 709
 710    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 711        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 712            if let Some(invite_code) = &user.invite_code {
 713                let pool = self.connection_pool.lock();
 714                for connection_id in pool.user_connection_ids(user_id) {
 715                    self.peer.send(
 716                        connection_id,
 717                        proto::UpdateInviteInfo {
 718                            url: format!(
 719                                "{}{}",
 720                                self.app_state.config.invite_link_prefix, invite_code
 721                            ),
 722                            count: user.invite_count as u32,
 723                        },
 724                    )?;
 725                }
 726            }
 727        }
 728        Ok(())
 729    }
 730
 731    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 732        ServerSnapshot {
 733            connection_pool: ConnectionPoolGuard {
 734                guard: self.connection_pool.lock(),
 735                _not_send: PhantomData,
 736            },
 737            peer: &self.peer,
 738        }
 739    }
 740}
 741
 742impl<'a> Deref for ConnectionPoolGuard<'a> {
 743    type Target = ConnectionPool;
 744
 745    fn deref(&self) -> &Self::Target {
 746        &*self.guard
 747    }
 748}
 749
 750impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 751    fn deref_mut(&mut self) -> &mut Self::Target {
 752        &mut *self.guard
 753    }
 754}
 755
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// Calls `check_invariants` when the guard is released, but only in test builds;
    /// in all other builds the body is empty.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 762
 763fn broadcast<F>(
 764    sender_id: Option<ConnectionId>,
 765    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 766    mut f: F,
 767) where
 768    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 769{
 770    for receiver_id in receiver_ids {
 771        if Some(receiver_id) != sender_id {
 772            if let Err(error) = f(receiver_id) {
 773                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 774            }
 775        }
 776    }
 777}
 778
lazy_static! {
    // Name of the HTTP header that `ProtocolVersion` reports from `Header::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 782
/// Typed header wrapping the numeric value of the `x-zed-protocol-version` header.
pub struct ProtocolVersion(u32);
 784
 785impl Header for ProtocolVersion {
 786    fn name() -> &'static HeaderName {
 787        &ZED_PROTOCOL_VERSION
 788    }
 789
 790    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 791    where
 792        Self: Sized,
 793        I: Iterator<Item = &'i axum::http::HeaderValue>,
 794    {
 795        let version = values
 796            .next()
 797            .ok_or_else(axum::headers::Error::invalid)?
 798            .to_str()
 799            .map_err(|_| axum::headers::Error::invalid())?
 800            .parse()
 801            .map_err(|_| axum::headers::Error::invalid())?;
 802        Ok(Self(version))
 803    }
 804
 805    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 806        values.extend([self.0.to_string().parse().unwrap()]);
 807    }
 808}
 809
 810pub fn routes(server: Arc<Server>) -> Router<Body> {
 811    Router::new()
 812        .route("/rpc", get(handle_websocket_request))
 813        .layer(
 814            ServiceBuilder::new()
 815                .layer(Extension(server.app_state.clone()))
 816                .layer(middleware::from_fn(auth::validate_header)),
 817        )
 818        .route("/metrics", get(handle_metrics))
 819        .layer(Extension(server))
 820}
 821
/// Handles the `/rpc` route: checks the client's protocol version and, if it matches,
/// upgrades the request to a WebSocket served by `handle_connection` on the server.
///
/// Responds with `426 Upgrade Required` and the body `client must be upgraded` when
/// the decoded `ProtocolVersion` header differs from `rpc::PROTOCOL_VERSION`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    // The peer address is handed to `handle_connection` in string form.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Convert incoming axum messages to tungstenite messages and outgoing
        // tungstenite messages back to axum messages before wrapping the socket in a
        // `Connection`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // An error from the connection is passed to `log_err` rather than
            // propagated; the upgrade callback has no caller to report to.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 852
 853pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 854    let connections = server
 855        .connection_pool
 856        .lock()
 857        .connections()
 858        .filter(|connection| !connection.admin)
 859        .count();
 860
 861    METRIC_CONNECTIONS.set(connections as _);
 862
 863    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 864    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 865
 866    let encoder = prometheus::TextEncoder::new();
 867    let metric_families = prometheus::gather();
 868    let encoded_metrics = encoder
 869        .encode_to_string(&metric_families)
 870        .map_err(|err| anyhow!("{}", err))?;
 871    Ok(encoded_metrics)
 872}
 873
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer, removed from the connection pool and
/// reported to the database immediately. The function then waits for whichever comes
/// first: `RECONNECT_TIMEOUT` elapsing or `teardown` changing. Only in the timeout case
/// are the session's room and channel buffers left, a pending call declined (if the
/// user has no other connection online) and the user's contacts updated.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool or if the final contact
/// update fails. The other steps pass their results to `trace_err` instead of
/// returning them.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the arms in order, so the timeout arm is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline on the user's behalf when none of their connections remain.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 919
 920async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 921    response.send(proto::Ack {})?;
 922    Ok(())
 923}
 924
/// Handles a `CreateRoom` request: creates a room owned by the requesting user and
/// replies with the room plus, when available, LiveKit connection info.
///
/// # Errors
///
/// Fails if the database cannot create the room, if the response cannot be sent, or if
/// the contact update fails. LiveKit failures do not fail the request.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    // Random 30-character id used as the LiveKit room name.
    let live_kit_room = nanoid::nanoid!(30);

    // LiveKit setup is best-effort: the result is `None` when no client is configured
    // or when creating the room or minting the token fails (each `?` below operates on
    // an `Option`).
    let live_kit_connection_info = {
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        util::async_iife!({
            let live_kit = live_kit?;

            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
 970
 971async fn join_room(
 972    request: proto::JoinRoom,
 973    response: Response<proto::JoinRoom>,
 974    session: Session,
 975) -> Result<()> {
 976    let room_id = RoomId::from_proto(request.id);
 977    let joined_room = {
 978        let room = session
 979            .db()
 980            .await
 981            .join_room(room_id, session.user_id, session.connection_id)
 982            .await?;
 983        room_updated(&room.room, &session.peer);
 984        room.into_inner()
 985    };
 986
 987    if let Some(channel_id) = joined_room.channel_id {
 988        channel_updated(
 989            channel_id,
 990            &joined_room.room,
 991            &joined_room.channel_members,
 992            &session.peer,
 993            &*session.connection_pool().await,
 994        )
 995    }
 996
 997    for connection_id in session
 998        .connection_pool()
 999        .await
1000        .user_connection_ids(session.user_id)
1001    {
1002        session
1003            .peer
1004            .send(
1005                connection_id,
1006                proto::CallCanceled {
1007                    room_id: room_id.to_proto(),
1008                },
1009            )
1010            .trace_err();
1011    }
1012
1013    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1014        if let Some(token) = live_kit
1015            .room_token(
1016                &joined_room.room.live_kit_room,
1017                &session.user_id.to_string(),
1018            )
1019            .trace_err()
1020        {
1021            Some(proto::LiveKitConnectionInfo {
1022                server_url: live_kit.url().into(),
1023                token,
1024            })
1025        } else {
1026            None
1027        }
1028    } else {
1029        None
1030    };
1031
1032    response.send(proto::JoinRoomResponse {
1033        room: Some(joined_room.room),
1034        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1035        live_kit_connection_info,
1036    })?;
1037
1038    update_user_contacts(session.user_id, &session).await?;
1039    Ok(())
1040}
1041
1042async fn rejoin_room(
1043    request: proto::RejoinRoom,
1044    response: Response<proto::RejoinRoom>,
1045    session: Session,
1046) -> Result<()> {
1047    let room;
1048    let channel_id;
1049    let channel_members;
1050    {
1051        let mut rejoined_room = session
1052            .db()
1053            .await
1054            .rejoin_room(request, session.user_id, session.connection_id)
1055            .await?;
1056
1057        response.send(proto::RejoinRoomResponse {
1058            room: Some(rejoined_room.room.clone()),
1059            reshared_projects: rejoined_room
1060                .reshared_projects
1061                .iter()
1062                .map(|project| proto::ResharedProject {
1063                    id: project.id.to_proto(),
1064                    collaborators: project
1065                        .collaborators
1066                        .iter()
1067                        .map(|collaborator| collaborator.to_proto())
1068                        .collect(),
1069                })
1070                .collect(),
1071            rejoined_projects: rejoined_room
1072                .rejoined_projects
1073                .iter()
1074                .map(|rejoined_project| proto::RejoinedProject {
1075                    id: rejoined_project.id.to_proto(),
1076                    worktrees: rejoined_project
1077                        .worktrees
1078                        .iter()
1079                        .map(|worktree| proto::WorktreeMetadata {
1080                            id: worktree.id,
1081                            root_name: worktree.root_name.clone(),
1082                            visible: worktree.visible,
1083                            abs_path: worktree.abs_path.clone(),
1084                        })
1085                        .collect(),
1086                    collaborators: rejoined_project
1087                        .collaborators
1088                        .iter()
1089                        .map(|collaborator| collaborator.to_proto())
1090                        .collect(),
1091                    language_servers: rejoined_project.language_servers.clone(),
1092                })
1093                .collect(),
1094        })?;
1095        room_updated(&rejoined_room.room, &session.peer);
1096
1097        for project in &rejoined_room.reshared_projects {
1098            for collaborator in &project.collaborators {
1099                session
1100                    .peer
1101                    .send(
1102                        collaborator.connection_id,
1103                        proto::UpdateProjectCollaborator {
1104                            project_id: project.id.to_proto(),
1105                            old_peer_id: Some(project.old_connection_id.into()),
1106                            new_peer_id: Some(session.connection_id.into()),
1107                        },
1108                    )
1109                    .trace_err();
1110            }
1111
1112            broadcast(
1113                Some(session.connection_id),
1114                project
1115                    .collaborators
1116                    .iter()
1117                    .map(|collaborator| collaborator.connection_id),
1118                |connection_id| {
1119                    session.peer.forward_send(
1120                        session.connection_id,
1121                        connection_id,
1122                        proto::UpdateProject {
1123                            project_id: project.id.to_proto(),
1124                            worktrees: project.worktrees.clone(),
1125                        },
1126                    )
1127                },
1128            );
1129        }
1130
1131        for project in &rejoined_room.rejoined_projects {
1132            for collaborator in &project.collaborators {
1133                session
1134                    .peer
1135                    .send(
1136                        collaborator.connection_id,
1137                        proto::UpdateProjectCollaborator {
1138                            project_id: project.id.to_proto(),
1139                            old_peer_id: Some(project.old_connection_id.into()),
1140                            new_peer_id: Some(session.connection_id.into()),
1141                        },
1142                    )
1143                    .trace_err();
1144            }
1145        }
1146
1147        for project in &mut rejoined_room.rejoined_projects {
1148            for worktree in mem::take(&mut project.worktrees) {
1149                #[cfg(any(test, feature = "test-support"))]
1150                const MAX_CHUNK_SIZE: usize = 2;
1151                #[cfg(not(any(test, feature = "test-support")))]
1152                const MAX_CHUNK_SIZE: usize = 256;
1153
1154                // Stream this worktree's entries.
1155                let message = proto::UpdateWorktree {
1156                    project_id: project.id.to_proto(),
1157                    worktree_id: worktree.id,
1158                    abs_path: worktree.abs_path.clone(),
1159                    root_name: worktree.root_name,
1160                    updated_entries: worktree.updated_entries,
1161                    removed_entries: worktree.removed_entries,
1162                    scan_id: worktree.scan_id,
1163                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1164                    updated_repositories: worktree.updated_repositories,
1165                    removed_repositories: worktree.removed_repositories,
1166                };
1167                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1168                    session.peer.send(session.connection_id, update.clone())?;
1169                }
1170
1171                // Stream this worktree's diagnostics.
1172                for summary in worktree.diagnostic_summaries {
1173                    session.peer.send(
1174                        session.connection_id,
1175                        proto::UpdateDiagnosticSummary {
1176                            project_id: project.id.to_proto(),
1177                            worktree_id: worktree.id,
1178                            summary: Some(summary),
1179                        },
1180                    )?;
1181                }
1182
1183                for settings_file in worktree.settings_files {
1184                    session.peer.send(
1185                        session.connection_id,
1186                        proto::UpdateWorktreeSettings {
1187                            project_id: project.id.to_proto(),
1188                            worktree_id: worktree.id,
1189                            path: settings_file.path,
1190                            content: Some(settings_file.content),
1191                        },
1192                    )?;
1193                }
1194            }
1195
1196            for language_server in &project.language_servers {
1197                session.peer.send(
1198                    session.connection_id,
1199                    proto::UpdateLanguageServer {
1200                        project_id: project.id.to_proto(),
1201                        language_server_id: language_server.id,
1202                        variant: Some(
1203                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1204                                proto::LspDiskBasedDiagnosticsUpdated {},
1205                            ),
1206                        ),
1207                    },
1208                )?;
1209            }
1210        }
1211
1212        let rejoined_room = rejoined_room.into_inner();
1213
1214        room = rejoined_room.room;
1215        channel_id = rejoined_room.channel_id;
1216        channel_members = rejoined_room.channel_members;
1217    }
1218
1219    if let Some(channel_id) = channel_id {
1220        channel_updated(
1221            channel_id,
1222            &room,
1223            &channel_members,
1224            &session.peer,
1225            &*session.connection_pool().await,
1226        );
1227    }
1228
1229    update_user_contacts(session.user_id, &session).await?;
1230    Ok(())
1231}
1232
1233async fn leave_room(
1234    _: proto::LeaveRoom,
1235    response: Response<proto::LeaveRoom>,
1236    session: Session,
1237) -> Result<()> {
1238    leave_room_for_session(&session).await?;
1239    response.send(proto::Ack {})?;
1240    Ok(())
1241}
1242
1243async fn call(
1244    request: proto::Call,
1245    response: Response<proto::Call>,
1246    session: Session,
1247) -> Result<()> {
1248    let room_id = RoomId::from_proto(request.room_id);
1249    let calling_user_id = session.user_id;
1250    let calling_connection_id = session.connection_id;
1251    let called_user_id = UserId::from_proto(request.called_user_id);
1252    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1253    if !session
1254        .db()
1255        .await
1256        .has_contact(calling_user_id, called_user_id)
1257        .await?
1258    {
1259        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1260    }
1261
1262    let incoming_call = {
1263        let (room, incoming_call) = &mut *session
1264            .db()
1265            .await
1266            .call(
1267                room_id,
1268                calling_user_id,
1269                calling_connection_id,
1270                called_user_id,
1271                initial_project_id,
1272            )
1273            .await?;
1274        room_updated(&room, &session.peer);
1275        mem::take(incoming_call)
1276    };
1277    update_user_contacts(called_user_id, &session).await?;
1278
1279    let mut calls = session
1280        .connection_pool()
1281        .await
1282        .user_connection_ids(called_user_id)
1283        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1284        .collect::<FuturesUnordered<_>>();
1285
1286    while let Some(call_response) = calls.next().await {
1287        match call_response.as_ref() {
1288            Ok(_) => {
1289                response.send(proto::Ack {})?;
1290                return Ok(());
1291            }
1292            Err(_) => {
1293                call_response.trace_err();
1294            }
1295        }
1296    }
1297
1298    {
1299        let room = session
1300            .db()
1301            .await
1302            .call_failed(room_id, called_user_id)
1303            .await?;
1304        room_updated(&room, &session.peer);
1305    }
1306    update_user_contacts(called_user_id, &session).await?;
1307
1308    Err(anyhow!("failed to ring user"))?
1309}
1310
1311async fn cancel_call(
1312    request: proto::CancelCall,
1313    response: Response<proto::CancelCall>,
1314    session: Session,
1315) -> Result<()> {
1316    let called_user_id = UserId::from_proto(request.called_user_id);
1317    let room_id = RoomId::from_proto(request.room_id);
1318    {
1319        let room = session
1320            .db()
1321            .await
1322            .cancel_call(room_id, session.connection_id, called_user_id)
1323            .await?;
1324        room_updated(&room, &session.peer);
1325    }
1326
1327    for connection_id in session
1328        .connection_pool()
1329        .await
1330        .user_connection_ids(called_user_id)
1331    {
1332        session
1333            .peer
1334            .send(
1335                connection_id,
1336                proto::CallCanceled {
1337                    room_id: room_id.to_proto(),
1338                },
1339            )
1340            .trace_err();
1341    }
1342    response.send(proto::Ack {})?;
1343
1344    update_user_contacts(called_user_id, &session).await?;
1345    Ok(())
1346}
1347
1348async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1349    let room_id = RoomId::from_proto(message.room_id);
1350    {
1351        let room = session
1352            .db()
1353            .await
1354            .decline_call(Some(room_id), session.user_id)
1355            .await?
1356            .ok_or_else(|| anyhow!("failed to decline call"))?;
1357        room_updated(&room, &session.peer);
1358    }
1359
1360    for connection_id in session
1361        .connection_pool()
1362        .await
1363        .user_connection_ids(session.user_id)
1364    {
1365        session
1366            .peer
1367            .send(
1368                connection_id,
1369                proto::CallCanceled {
1370                    room_id: room_id.to_proto(),
1371                },
1372            )
1373            .trace_err();
1374    }
1375    update_user_contacts(session.user_id, &session).await?;
1376    Ok(())
1377}
1378
1379async fn update_participant_location(
1380    request: proto::UpdateParticipantLocation,
1381    response: Response<proto::UpdateParticipantLocation>,
1382    session: Session,
1383) -> Result<()> {
1384    let room_id = RoomId::from_proto(request.room_id);
1385    let location = request
1386        .location
1387        .ok_or_else(|| anyhow!("invalid location"))?;
1388
1389    let db = session.db().await;
1390    let room = db
1391        .update_room_participant_location(room_id, session.connection_id, location)
1392        .await?;
1393
1394    room_updated(&room, &session.peer);
1395    response.send(proto::Ack {})?;
1396    Ok(())
1397}
1398
1399async fn share_project(
1400    request: proto::ShareProject,
1401    response: Response<proto::ShareProject>,
1402    session: Session,
1403) -> Result<()> {
1404    let (project_id, room) = &*session
1405        .db()
1406        .await
1407        .share_project(
1408            RoomId::from_proto(request.room_id),
1409            session.connection_id,
1410            &request.worktrees,
1411        )
1412        .await?;
1413    response.send(proto::ShareProjectResponse {
1414        project_id: project_id.to_proto(),
1415    })?;
1416    room_updated(&room, &session.peer);
1417
1418    Ok(())
1419}
1420
1421async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1422    let project_id = ProjectId::from_proto(message.project_id);
1423
1424    let (room, guest_connection_ids) = &*session
1425        .db()
1426        .await
1427        .unshare_project(project_id, session.connection_id)
1428        .await?;
1429
1430    broadcast(
1431        Some(session.connection_id),
1432        guest_connection_ids.iter().copied(),
1433        |conn_id| session.peer.send(conn_id, message.clone()),
1434    );
1435    room_updated(&room, &session.peer);
1436
1437    Ok(())
1438}
1439
1440async fn join_project(
1441    request: proto::JoinProject,
1442    response: Response<proto::JoinProject>,
1443    session: Session,
1444) -> Result<()> {
1445    let project_id = ProjectId::from_proto(request.project_id);
1446    let guest_user_id = session.user_id;
1447
1448    tracing::info!(%project_id, "join project");
1449
1450    let (project, replica_id) = &mut *session
1451        .db()
1452        .await
1453        .join_project(project_id, session.connection_id)
1454        .await?;
1455
1456    let collaborators = project
1457        .collaborators
1458        .iter()
1459        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1460        .map(|collaborator| collaborator.to_proto())
1461        .collect::<Vec<_>>();
1462
1463    let worktrees = project
1464        .worktrees
1465        .iter()
1466        .map(|(id, worktree)| proto::WorktreeMetadata {
1467            id: *id,
1468            root_name: worktree.root_name.clone(),
1469            visible: worktree.visible,
1470            abs_path: worktree.abs_path.clone(),
1471        })
1472        .collect::<Vec<_>>();
1473
1474    for collaborator in &collaborators {
1475        session
1476            .peer
1477            .send(
1478                collaborator.peer_id.unwrap().into(),
1479                proto::AddProjectCollaborator {
1480                    project_id: project_id.to_proto(),
1481                    collaborator: Some(proto::Collaborator {
1482                        peer_id: Some(session.connection_id.into()),
1483                        replica_id: replica_id.0 as u32,
1484                        user_id: guest_user_id.to_proto(),
1485                    }),
1486                },
1487            )
1488            .trace_err();
1489    }
1490
1491    // First, we send the metadata associated with each worktree.
1492    response.send(proto::JoinProjectResponse {
1493        worktrees: worktrees.clone(),
1494        replica_id: replica_id.0 as u32,
1495        collaborators: collaborators.clone(),
1496        language_servers: project.language_servers.clone(),
1497    })?;
1498
1499    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1500        #[cfg(any(test, feature = "test-support"))]
1501        const MAX_CHUNK_SIZE: usize = 2;
1502        #[cfg(not(any(test, feature = "test-support")))]
1503        const MAX_CHUNK_SIZE: usize = 256;
1504
1505        // Stream this worktree's entries.
1506        let message = proto::UpdateWorktree {
1507            project_id: project_id.to_proto(),
1508            worktree_id,
1509            abs_path: worktree.abs_path.clone(),
1510            root_name: worktree.root_name,
1511            updated_entries: worktree.entries,
1512            removed_entries: Default::default(),
1513            scan_id: worktree.scan_id,
1514            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1515            updated_repositories: worktree.repository_entries.into_values().collect(),
1516            removed_repositories: Default::default(),
1517        };
1518        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1519            session.peer.send(session.connection_id, update.clone())?;
1520        }
1521
1522        // Stream this worktree's diagnostics.
1523        for summary in worktree.diagnostic_summaries {
1524            session.peer.send(
1525                session.connection_id,
1526                proto::UpdateDiagnosticSummary {
1527                    project_id: project_id.to_proto(),
1528                    worktree_id: worktree.id,
1529                    summary: Some(summary),
1530                },
1531            )?;
1532        }
1533
1534        for settings_file in worktree.settings_files {
1535            session.peer.send(
1536                session.connection_id,
1537                proto::UpdateWorktreeSettings {
1538                    project_id: project_id.to_proto(),
1539                    worktree_id: worktree.id,
1540                    path: settings_file.path,
1541                    content: Some(settings_file.content),
1542                },
1543            )?;
1544        }
1545    }
1546
1547    for language_server in &project.language_servers {
1548        session.peer.send(
1549            session.connection_id,
1550            proto::UpdateLanguageServer {
1551                project_id: project_id.to_proto(),
1552                language_server_id: language_server.id,
1553                variant: Some(
1554                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1555                        proto::LspDiskBasedDiagnosticsUpdated {},
1556                    ),
1557                ),
1558            },
1559        )?;
1560    }
1561
1562    Ok(())
1563}
1564
1565async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1566    let sender_id = session.connection_id;
1567    let project_id = ProjectId::from_proto(request.project_id);
1568
1569    let (room, project) = &*session
1570        .db()
1571        .await
1572        .leave_project(project_id, sender_id)
1573        .await?;
1574    tracing::info!(
1575        %project_id,
1576        host_user_id = %project.host_user_id,
1577        host_connection_id = %project.host_connection_id,
1578        "leave project"
1579    );
1580
1581    project_left(&project, &session);
1582    room_updated(&room, &session.peer);
1583
1584    Ok(())
1585}
1586
1587async fn update_project(
1588    request: proto::UpdateProject,
1589    response: Response<proto::UpdateProject>,
1590    session: Session,
1591) -> Result<()> {
1592    let project_id = ProjectId::from_proto(request.project_id);
1593    let (room, guest_connection_ids) = &*session
1594        .db()
1595        .await
1596        .update_project(project_id, session.connection_id, &request.worktrees)
1597        .await?;
1598    broadcast(
1599        Some(session.connection_id),
1600        guest_connection_ids.iter().copied(),
1601        |connection_id| {
1602            session
1603                .peer
1604                .forward_send(session.connection_id, connection_id, request.clone())
1605        },
1606    );
1607    room_updated(&room, &session.peer);
1608    response.send(proto::Ack {})?;
1609
1610    Ok(())
1611}
1612
1613async fn update_worktree(
1614    request: proto::UpdateWorktree,
1615    response: Response<proto::UpdateWorktree>,
1616    session: Session,
1617) -> Result<()> {
1618    let guest_connection_ids = session
1619        .db()
1620        .await
1621        .update_worktree(&request, session.connection_id)
1622        .await?;
1623
1624    broadcast(
1625        Some(session.connection_id),
1626        guest_connection_ids.iter().copied(),
1627        |connection_id| {
1628            session
1629                .peer
1630                .forward_send(session.connection_id, connection_id, request.clone())
1631        },
1632    );
1633    response.send(proto::Ack {})?;
1634    Ok(())
1635}
1636
1637async fn update_diagnostic_summary(
1638    message: proto::UpdateDiagnosticSummary,
1639    session: Session,
1640) -> Result<()> {
1641    let guest_connection_ids = session
1642        .db()
1643        .await
1644        .update_diagnostic_summary(&message, session.connection_id)
1645        .await?;
1646
1647    broadcast(
1648        Some(session.connection_id),
1649        guest_connection_ids.iter().copied(),
1650        |connection_id| {
1651            session
1652                .peer
1653                .forward_send(session.connection_id, connection_id, message.clone())
1654        },
1655    );
1656
1657    Ok(())
1658}
1659
1660async fn update_worktree_settings(
1661    message: proto::UpdateWorktreeSettings,
1662    session: Session,
1663) -> Result<()> {
1664    let guest_connection_ids = session
1665        .db()
1666        .await
1667        .update_worktree_settings(&message, session.connection_id)
1668        .await?;
1669
1670    broadcast(
1671        Some(session.connection_id),
1672        guest_connection_ids.iter().copied(),
1673        |connection_id| {
1674            session
1675                .peer
1676                .forward_send(session.connection_id, connection_id, message.clone())
1677        },
1678    );
1679
1680    Ok(())
1681}
1682
1683async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1684    broadcast_project_message(request.project_id, request, session).await
1685}
1686
1687async fn start_language_server(
1688    request: proto::StartLanguageServer,
1689    session: Session,
1690) -> Result<()> {
1691    let guest_connection_ids = session
1692        .db()
1693        .await
1694        .start_language_server(&request, session.connection_id)
1695        .await?;
1696
1697    broadcast(
1698        Some(session.connection_id),
1699        guest_connection_ids.iter().copied(),
1700        |connection_id| {
1701            session
1702                .peer
1703                .forward_send(session.connection_id, connection_id, request.clone())
1704        },
1705    );
1706    Ok(())
1707}
1708
1709async fn update_language_server(
1710    request: proto::UpdateLanguageServer,
1711    session: Session,
1712) -> Result<()> {
1713    session.executor.record_backtrace();
1714    let project_id = ProjectId::from_proto(request.project_id);
1715    let project_connection_ids = session
1716        .db()
1717        .await
1718        .project_connection_ids(project_id, session.connection_id)
1719        .await?;
1720    broadcast(
1721        Some(session.connection_id),
1722        project_connection_ids.iter().copied(),
1723        |connection_id| {
1724            session
1725                .peer
1726                .forward_send(session.connection_id, connection_id, request.clone())
1727        },
1728    );
1729    Ok(())
1730}
1731
1732async fn forward_project_request<T>(
1733    request: T,
1734    response: Response<T>,
1735    session: Session,
1736) -> Result<()>
1737where
1738    T: EntityMessage + RequestMessage,
1739{
1740    session.executor.record_backtrace();
1741    let project_id = ProjectId::from_proto(request.remote_entity_id());
1742    let host_connection_id = {
1743        let collaborators = session
1744            .db()
1745            .await
1746            .project_collaborators(project_id, session.connection_id)
1747            .await?;
1748        collaborators
1749            .iter()
1750            .find(|collaborator| collaborator.is_host)
1751            .ok_or_else(|| anyhow!("host not found"))?
1752            .connection_id
1753    };
1754
1755    let payload = session
1756        .peer
1757        .forward_request(session.connection_id, host_connection_id, request)
1758        .await?;
1759
1760    response.send(payload)?;
1761    Ok(())
1762}
1763
1764async fn create_buffer_for_peer(
1765    request: proto::CreateBufferForPeer,
1766    session: Session,
1767) -> Result<()> {
1768    session.executor.record_backtrace();
1769    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1770    session
1771        .peer
1772        .forward_send(session.connection_id, peer_id.into(), request)?;
1773    Ok(())
1774}
1775
1776async fn update_buffer(
1777    request: proto::UpdateBuffer,
1778    response: Response<proto::UpdateBuffer>,
1779    session: Session,
1780) -> Result<()> {
1781    session.executor.record_backtrace();
1782    let project_id = ProjectId::from_proto(request.project_id);
1783    let mut guest_connection_ids;
1784    let mut host_connection_id = None;
1785    {
1786        let collaborators = session
1787            .db()
1788            .await
1789            .project_collaborators(project_id, session.connection_id)
1790            .await?;
1791        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1792        for collaborator in collaborators.iter() {
1793            if collaborator.is_host {
1794                host_connection_id = Some(collaborator.connection_id);
1795            } else {
1796                guest_connection_ids.push(collaborator.connection_id);
1797            }
1798        }
1799    }
1800    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1801
1802    session.executor.record_backtrace();
1803    broadcast(
1804        Some(session.connection_id),
1805        guest_connection_ids,
1806        |connection_id| {
1807            session
1808                .peer
1809                .forward_send(session.connection_id, connection_id, request.clone())
1810        },
1811    );
1812    if host_connection_id != session.connection_id {
1813        session
1814            .peer
1815            .forward_request(session.connection_id, host_connection_id, request.clone())
1816            .await?;
1817    }
1818
1819    response.send(proto::Ack {})?;
1820    Ok(())
1821}
1822
1823async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1824    let project_id = ProjectId::from_proto(request.project_id);
1825    let project_connection_ids = session
1826        .db()
1827        .await
1828        .project_connection_ids(project_id, session.connection_id)
1829        .await?;
1830
1831    broadcast(
1832        Some(session.connection_id),
1833        project_connection_ids.iter().copied(),
1834        |connection_id| {
1835            session
1836                .peer
1837                .forward_send(session.connection_id, connection_id, request.clone())
1838        },
1839    );
1840    Ok(())
1841}
1842
1843async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1844    let project_id = ProjectId::from_proto(request.project_id);
1845    let project_connection_ids = session
1846        .db()
1847        .await
1848        .project_connection_ids(project_id, session.connection_id)
1849        .await?;
1850    broadcast(
1851        Some(session.connection_id),
1852        project_connection_ids.iter().copied(),
1853        |connection_id| {
1854            session
1855                .peer
1856                .forward_send(session.connection_id, connection_id, request.clone())
1857        },
1858    );
1859    Ok(())
1860}
1861
1862async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1863    broadcast_project_message(request.project_id, request, session).await
1864}
1865
1866async fn broadcast_project_message<T: EnvelopedMessage>(
1867    project_id: u64,
1868    request: T,
1869    session: Session,
1870) -> Result<()> {
1871    let project_id = ProjectId::from_proto(project_id);
1872    let project_connection_ids = session
1873        .db()
1874        .await
1875        .project_connection_ids(project_id, session.connection_id)
1876        .await?;
1877    broadcast(
1878        Some(session.connection_id),
1879        project_connection_ids.iter().copied(),
1880        |connection_id| {
1881            session
1882                .peer
1883                .forward_send(session.connection_id, connection_id, request.clone())
1884        },
1885    );
1886    Ok(())
1887}
1888
1889async fn follow(
1890    request: proto::Follow,
1891    response: Response<proto::Follow>,
1892    session: Session,
1893) -> Result<()> {
1894    let room_id = RoomId::from_proto(request.room_id);
1895    let project_id = request.project_id.map(ProjectId::from_proto);
1896    let leader_id = request
1897        .leader_id
1898        .ok_or_else(|| anyhow!("invalid leader id"))?
1899        .into();
1900    let follower_id = session.connection_id;
1901
1902    session
1903        .db()
1904        .await
1905        .check_room_participants(room_id, leader_id, session.connection_id)
1906        .await?;
1907
1908    let mut response_payload = session
1909        .peer
1910        .forward_request(session.connection_id, leader_id, request)
1911        .await?;
1912    response_payload
1913        .views
1914        .retain(|view| view.leader_id != Some(follower_id.into()));
1915    response.send(response_payload)?;
1916
1917    if let Some(project_id) = project_id {
1918        let room = session
1919            .db()
1920            .await
1921            .follow(room_id, project_id, leader_id, follower_id)
1922            .await?;
1923        room_updated(&room, &session.peer);
1924    }
1925
1926    Ok(())
1927}
1928
1929async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1930    let room_id = RoomId::from_proto(request.room_id);
1931    let project_id = request.project_id.map(ProjectId::from_proto);
1932    let leader_id = request
1933        .leader_id
1934        .ok_or_else(|| anyhow!("invalid leader id"))?
1935        .into();
1936    let follower_id = session.connection_id;
1937
1938    session
1939        .db()
1940        .await
1941        .check_room_participants(room_id, leader_id, session.connection_id)
1942        .await?;
1943
1944    session
1945        .peer
1946        .forward_send(session.connection_id, leader_id, request)?;
1947
1948    if let Some(project_id) = project_id {
1949        let room = session
1950            .db()
1951            .await
1952            .unfollow(room_id, project_id, leader_id, follower_id)
1953            .await?;
1954        room_updated(&room, &session.peer);
1955    }
1956
1957    Ok(())
1958}
1959
1960async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1961    let room_id = RoomId::from_proto(request.room_id);
1962    let database = session.db.lock().await;
1963
1964    let connection_ids = if let Some(project_id) = request.project_id {
1965        let project_id = ProjectId::from_proto(project_id);
1966        database
1967            .project_connection_ids(project_id, session.connection_id)
1968            .await?
1969    } else {
1970        database
1971            .room_connection_ids(room_id, session.connection_id)
1972            .await?
1973    };
1974
1975    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1976        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1977        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1978        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1979    });
1980    for follower_peer_id in request.follower_ids.iter().copied() {
1981        let follower_connection_id = follower_peer_id.into();
1982        if Some(follower_peer_id) != leader_id && connection_ids.contains(&follower_connection_id) {
1983            session.peer.forward_send(
1984                session.connection_id,
1985                follower_connection_id,
1986                request.clone(),
1987            )?;
1988        }
1989    }
1990    Ok(())
1991}
1992
1993async fn get_users(
1994    request: proto::GetUsers,
1995    response: Response<proto::GetUsers>,
1996    session: Session,
1997) -> Result<()> {
1998    let user_ids = request
1999        .user_ids
2000        .into_iter()
2001        .map(UserId::from_proto)
2002        .collect();
2003    let users = session
2004        .db()
2005        .await
2006        .get_users_by_ids(user_ids)
2007        .await?
2008        .into_iter()
2009        .map(|user| proto::User {
2010            id: user.id.to_proto(),
2011            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2012            github_login: user.github_login,
2013        })
2014        .collect();
2015    response.send(proto::UsersResponse { users })?;
2016    Ok(())
2017}
2018
2019async fn fuzzy_search_users(
2020    request: proto::FuzzySearchUsers,
2021    response: Response<proto::FuzzySearchUsers>,
2022    session: Session,
2023) -> Result<()> {
2024    let query = request.query;
2025    let users = match query.len() {
2026        0 => vec![],
2027        1 | 2 => session
2028            .db()
2029            .await
2030            .get_user_by_github_login(&query)
2031            .await?
2032            .into_iter()
2033            .collect(),
2034        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2035    };
2036    let users = users
2037        .into_iter()
2038        .filter(|user| user.id != session.user_id)
2039        .map(|user| proto::User {
2040            id: user.id.to_proto(),
2041            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2042            github_login: user.github_login,
2043        })
2044        .collect();
2045    response.send(proto::UsersResponse { users })?;
2046    Ok(())
2047}
2048
2049async fn request_contact(
2050    request: proto::RequestContact,
2051    response: Response<proto::RequestContact>,
2052    session: Session,
2053) -> Result<()> {
2054    let requester_id = session.user_id;
2055    let responder_id = UserId::from_proto(request.responder_id);
2056    if requester_id == responder_id {
2057        return Err(anyhow!("cannot add yourself as a contact"))?;
2058    }
2059
2060    session
2061        .db()
2062        .await
2063        .send_contact_request(requester_id, responder_id)
2064        .await?;
2065
2066    // Update outgoing contact requests of requester
2067    let mut update = proto::UpdateContacts::default();
2068    update.outgoing_requests.push(responder_id.to_proto());
2069    for connection_id in session
2070        .connection_pool()
2071        .await
2072        .user_connection_ids(requester_id)
2073    {
2074        session.peer.send(connection_id, update.clone())?;
2075    }
2076
2077    // Update incoming contact requests of responder
2078    let mut update = proto::UpdateContacts::default();
2079    update
2080        .incoming_requests
2081        .push(proto::IncomingContactRequest {
2082            requester_id: requester_id.to_proto(),
2083            should_notify: true,
2084        });
2085    for connection_id in session
2086        .connection_pool()
2087        .await
2088        .user_connection_ids(responder_id)
2089    {
2090        session.peer.send(connection_id, update.clone())?;
2091    }
2092
2093    response.send(proto::Ack {})?;
2094    Ok(())
2095}
2096
2097async fn respond_to_contact_request(
2098    request: proto::RespondToContactRequest,
2099    response: Response<proto::RespondToContactRequest>,
2100    session: Session,
2101) -> Result<()> {
2102    let responder_id = session.user_id;
2103    let requester_id = UserId::from_proto(request.requester_id);
2104    let db = session.db().await;
2105    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2106        db.dismiss_contact_notification(responder_id, requester_id)
2107            .await?;
2108    } else {
2109        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2110
2111        db.respond_to_contact_request(responder_id, requester_id, accept)
2112            .await?;
2113        let requester_busy = db.is_user_busy(requester_id).await?;
2114        let responder_busy = db.is_user_busy(responder_id).await?;
2115
2116        let pool = session.connection_pool().await;
2117        // Update responder with new contact
2118        let mut update = proto::UpdateContacts::default();
2119        if accept {
2120            update
2121                .contacts
2122                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2123        }
2124        update
2125            .remove_incoming_requests
2126            .push(requester_id.to_proto());
2127        for connection_id in pool.user_connection_ids(responder_id) {
2128            session.peer.send(connection_id, update.clone())?;
2129        }
2130
2131        // Update requester with new contact
2132        let mut update = proto::UpdateContacts::default();
2133        if accept {
2134            update
2135                .contacts
2136                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2137        }
2138        update
2139            .remove_outgoing_requests
2140            .push(responder_id.to_proto());
2141        for connection_id in pool.user_connection_ids(requester_id) {
2142            session.peer.send(connection_id, update.clone())?;
2143        }
2144    }
2145
2146    response.send(proto::Ack {})?;
2147    Ok(())
2148}
2149
2150async fn remove_contact(
2151    request: proto::RemoveContact,
2152    response: Response<proto::RemoveContact>,
2153    session: Session,
2154) -> Result<()> {
2155    let requester_id = session.user_id;
2156    let responder_id = UserId::from_proto(request.user_id);
2157    let db = session.db().await;
2158    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2159
2160    let pool = session.connection_pool().await;
2161    // Update outgoing contact requests of requester
2162    let mut update = proto::UpdateContacts::default();
2163    if contact_accepted {
2164        update.remove_contacts.push(responder_id.to_proto());
2165    } else {
2166        update
2167            .remove_outgoing_requests
2168            .push(responder_id.to_proto());
2169    }
2170    for connection_id in pool.user_connection_ids(requester_id) {
2171        session.peer.send(connection_id, update.clone())?;
2172    }
2173
2174    // Update incoming contact requests of responder
2175    let mut update = proto::UpdateContacts::default();
2176    if contact_accepted {
2177        update.remove_contacts.push(requester_id.to_proto());
2178    } else {
2179        update
2180            .remove_incoming_requests
2181            .push(requester_id.to_proto());
2182    }
2183    for connection_id in pool.user_connection_ids(responder_id) {
2184        session.peer.send(connection_id, update.clone())?;
2185    }
2186
2187    response.send(proto::Ack {})?;
2188    Ok(())
2189}
2190
2191async fn create_channel(
2192    request: proto::CreateChannel,
2193    response: Response<proto::CreateChannel>,
2194    session: Session,
2195) -> Result<()> {
2196    let db = session.db().await;
2197    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2198
2199    if let Some(live_kit) = session.live_kit_client.as_ref() {
2200        live_kit.create_room(live_kit_room.clone()).await?;
2201    }
2202
2203    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2204    let id = db
2205        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2206        .await?;
2207
2208    let channel = proto::Channel {
2209        id: id.to_proto(),
2210        name: request.name,
2211    };
2212
2213    response.send(proto::CreateChannelResponse {
2214        channel: Some(channel.clone()),
2215        parent_id: request.parent_id,
2216    })?;
2217
2218    let Some(parent_id) = parent_id else {
2219        return Ok(());
2220    };
2221
2222    let update = proto::UpdateChannels {
2223        channels: vec![channel],
2224        insert_edge: vec![ChannelEdge {
2225            parent_id: parent_id.to_proto(),
2226            channel_id: id.to_proto(),
2227        }],
2228        ..Default::default()
2229    };
2230
2231    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2232
2233    let connection_pool = session.connection_pool().await;
2234    for user_id in user_ids_to_notify {
2235        for connection_id in connection_pool.user_connection_ids(user_id) {
2236            if user_id == session.user_id {
2237                continue;
2238            }
2239            session.peer.send(connection_id, update.clone())?;
2240        }
2241    }
2242
2243    Ok(())
2244}
2245
2246async fn delete_channel(
2247    request: proto::DeleteChannel,
2248    response: Response<proto::DeleteChannel>,
2249    session: Session,
2250) -> Result<()> {
2251    let db = session.db().await;
2252
2253    let channel_id = request.channel_id;
2254    let (removed_channels, member_ids) = db
2255        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2256        .await?;
2257    response.send(proto::Ack {})?;
2258
2259    // Notify members of removed channels
2260    let mut update = proto::UpdateChannels::default();
2261    update
2262        .delete_channels
2263        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2264
2265    let connection_pool = session.connection_pool().await;
2266    for member_id in member_ids {
2267        for connection_id in connection_pool.user_connection_ids(member_id) {
2268            session.peer.send(connection_id, update.clone())?;
2269        }
2270    }
2271
2272    Ok(())
2273}
2274
2275async fn invite_channel_member(
2276    request: proto::InviteChannelMember,
2277    response: Response<proto::InviteChannelMember>,
2278    session: Session,
2279) -> Result<()> {
2280    let db = session.db().await;
2281    let channel_id = ChannelId::from_proto(request.channel_id);
2282    let invitee_id = UserId::from_proto(request.user_id);
2283    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2284        .await?;
2285
2286    let (channel, _) = db
2287        .get_channel(channel_id, session.user_id)
2288        .await?
2289        .ok_or_else(|| anyhow!("channel not found"))?;
2290
2291    let mut update = proto::UpdateChannels::default();
2292    update.channel_invitations.push(proto::Channel {
2293        id: channel.id.to_proto(),
2294        name: channel.name,
2295    });
2296    for connection_id in session
2297        .connection_pool()
2298        .await
2299        .user_connection_ids(invitee_id)
2300    {
2301        session.peer.send(connection_id, update.clone())?;
2302    }
2303
2304    response.send(proto::Ack {})?;
2305    Ok(())
2306}
2307
2308async fn remove_channel_member(
2309    request: proto::RemoveChannelMember,
2310    response: Response<proto::RemoveChannelMember>,
2311    session: Session,
2312) -> Result<()> {
2313    let db = session.db().await;
2314    let channel_id = ChannelId::from_proto(request.channel_id);
2315    let member_id = UserId::from_proto(request.user_id);
2316
2317    db.remove_channel_member(channel_id, member_id, session.user_id)
2318        .await?;
2319
2320    let mut update = proto::UpdateChannels::default();
2321    update.delete_channels.push(channel_id.to_proto());
2322
2323    for connection_id in session
2324        .connection_pool()
2325        .await
2326        .user_connection_ids(member_id)
2327    {
2328        session.peer.send(connection_id, update.clone())?;
2329    }
2330
2331    response.send(proto::Ack {})?;
2332    Ok(())
2333}
2334
2335async fn set_channel_member_admin(
2336    request: proto::SetChannelMemberAdmin,
2337    response: Response<proto::SetChannelMemberAdmin>,
2338    session: Session,
2339) -> Result<()> {
2340    let db = session.db().await;
2341    let channel_id = ChannelId::from_proto(request.channel_id);
2342    let member_id = UserId::from_proto(request.user_id);
2343    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2344        .await?;
2345
2346    let (channel, has_accepted) = db
2347        .get_channel(channel_id, member_id)
2348        .await?
2349        .ok_or_else(|| anyhow!("channel not found"))?;
2350
2351    let mut update = proto::UpdateChannels::default();
2352    if has_accepted {
2353        update.channel_permissions.push(proto::ChannelPermission {
2354            channel_id: channel.id.to_proto(),
2355            is_admin: request.admin,
2356        });
2357    }
2358
2359    for connection_id in session
2360        .connection_pool()
2361        .await
2362        .user_connection_ids(member_id)
2363    {
2364        session.peer.send(connection_id, update.clone())?;
2365    }
2366
2367    response.send(proto::Ack {})?;
2368    Ok(())
2369}
2370
2371async fn rename_channel(
2372    request: proto::RenameChannel,
2373    response: Response<proto::RenameChannel>,
2374    session: Session,
2375) -> Result<()> {
2376    let db = session.db().await;
2377    let channel_id = ChannelId::from_proto(request.channel_id);
2378    let new_name = db
2379        .rename_channel(channel_id, session.user_id, &request.name)
2380        .await?;
2381
2382    let channel = proto::Channel {
2383        id: request.channel_id,
2384        name: new_name,
2385    };
2386    response.send(proto::RenameChannelResponse {
2387        channel: Some(channel.clone()),
2388    })?;
2389    let mut update = proto::UpdateChannels::default();
2390    update.channels.push(channel);
2391
2392    let member_ids = db.get_channel_members(channel_id).await?;
2393
2394    let connection_pool = session.connection_pool().await;
2395    for member_id in member_ids {
2396        for connection_id in connection_pool.user_connection_ids(member_id) {
2397            session.peer.send(connection_id, update.clone())?;
2398        }
2399    }
2400
2401    Ok(())
2402}
2403
2404async fn link_channel(
2405    request: proto::LinkChannel,
2406    response: Response<proto::LinkChannel>,
2407    session: Session,
2408) -> Result<()> {
2409    let db = session.db().await;
2410    let channel_id = ChannelId::from_proto(request.channel_id);
2411    let to = ChannelId::from_proto(request.to);
2412    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2413
2414    let members = db.get_channel_members(to).await?;
2415    let connection_pool = session.connection_pool().await;
2416    let update = proto::UpdateChannels {
2417        channels: channels_to_send
2418            .channels
2419            .into_iter()
2420            .map(|channel| proto::Channel {
2421                id: channel.id.to_proto(),
2422                name: channel.name,
2423            })
2424            .collect(),
2425        insert_edge: channels_to_send.edges,
2426        ..Default::default()
2427    };
2428    for member_id in members {
2429        for connection_id in connection_pool.user_connection_ids(member_id) {
2430            session.peer.send(connection_id, update.clone())?;
2431        }
2432    }
2433
2434    response.send(Ack {})?;
2435
2436    Ok(())
2437}
2438
2439async fn unlink_channel(
2440    request: proto::UnlinkChannel,
2441    response: Response<proto::UnlinkChannel>,
2442    session: Session,
2443) -> Result<()> {
2444    let db = session.db().await;
2445    let channel_id = ChannelId::from_proto(request.channel_id);
2446    let from = ChannelId::from_proto(request.from);
2447
2448    db.unlink_channel(session.user_id, channel_id, from).await?;
2449
2450    let members = db.get_channel_members(from).await?;
2451
2452    let update = proto::UpdateChannels {
2453        delete_edge: vec![proto::ChannelEdge {
2454            channel_id: channel_id.to_proto(),
2455            parent_id: from.to_proto(),
2456        }],
2457        ..Default::default()
2458    };
2459    let connection_pool = session.connection_pool().await;
2460    for member_id in members {
2461        for connection_id in connection_pool.user_connection_ids(member_id) {
2462            session.peer.send(connection_id, update.clone())?;
2463        }
2464    }
2465
2466    response.send(Ack {})?;
2467
2468    Ok(())
2469}
2470
2471async fn move_channel(
2472    request: proto::MoveChannel,
2473    response: Response<proto::MoveChannel>,
2474    session: Session,
2475) -> Result<()> {
2476    let db = session.db().await;
2477    let channel_id = ChannelId::from_proto(request.channel_id);
2478    let from_parent = ChannelId::from_proto(request.from);
2479    let to = ChannelId::from_proto(request.to);
2480
2481    let channels_to_send = db
2482        .move_channel(session.user_id, channel_id, from_parent, to)
2483        .await?;
2484
2485    if channels_to_send.is_empty() {
2486        response.send(Ack {})?;
2487        return Ok(());
2488    }
2489
2490    let members_from = db.get_channel_members(from_parent).await?;
2491    let members_to = db.get_channel_members(to).await?;
2492
2493    let update = proto::UpdateChannels {
2494        delete_edge: vec![proto::ChannelEdge {
2495            channel_id: channel_id.to_proto(),
2496            parent_id: from_parent.to_proto(),
2497        }],
2498        ..Default::default()
2499    };
2500    let connection_pool = session.connection_pool().await;
2501    for member_id in members_from {
2502        for connection_id in connection_pool.user_connection_ids(member_id) {
2503            session.peer.send(connection_id, update.clone())?;
2504        }
2505    }
2506
2507    let update = proto::UpdateChannels {
2508        channels: channels_to_send
2509            .channels
2510            .into_iter()
2511            .map(|channel| proto::Channel {
2512                id: channel.id.to_proto(),
2513                name: channel.name,
2514            })
2515            .collect(),
2516        insert_edge: channels_to_send.edges,
2517        ..Default::default()
2518    };
2519    for member_id in members_to {
2520        for connection_id in connection_pool.user_connection_ids(member_id) {
2521            session.peer.send(connection_id, update.clone())?;
2522        }
2523    }
2524
2525    response.send(Ack {})?;
2526
2527    Ok(())
2528}
2529
2530async fn get_channel_members(
2531    request: proto::GetChannelMembers,
2532    response: Response<proto::GetChannelMembers>,
2533    session: Session,
2534) -> Result<()> {
2535    let db = session.db().await;
2536    let channel_id = ChannelId::from_proto(request.channel_id);
2537    let members = db
2538        .get_channel_member_details(channel_id, session.user_id)
2539        .await?;
2540    response.send(proto::GetChannelMembersResponse { members })?;
2541    Ok(())
2542}
2543
2544async fn respond_to_channel_invite(
2545    request: proto::RespondToChannelInvite,
2546    response: Response<proto::RespondToChannelInvite>,
2547    session: Session,
2548) -> Result<()> {
2549    let db = session.db().await;
2550    let channel_id = ChannelId::from_proto(request.channel_id);
2551    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2552        .await?;
2553
2554    let mut update = proto::UpdateChannels::default();
2555    update
2556        .remove_channel_invitations
2557        .push(channel_id.to_proto());
2558    if request.accept {
2559        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2560        update
2561            .channels
2562            .extend(
2563                result
2564                    .channels
2565                    .channels
2566                    .into_iter()
2567                    .map(|channel| proto::Channel {
2568                        id: channel.id.to_proto(),
2569                        name: channel.name,
2570                    }),
2571            );
2572        update.unseen_channel_messages = result.channel_messages;
2573        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2574        update.insert_edge = result.channels.edges;
2575        update
2576            .channel_participants
2577            .extend(
2578                result
2579                    .channel_participants
2580                    .into_iter()
2581                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2582                        channel_id: channel_id.to_proto(),
2583                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2584                    }),
2585            );
2586        update
2587            .channel_permissions
2588            .extend(
2589                result
2590                    .channels_with_admin_privileges
2591                    .into_iter()
2592                    .map(|channel_id| proto::ChannelPermission {
2593                        channel_id: channel_id.to_proto(),
2594                        is_admin: true,
2595                    }),
2596            );
2597    }
2598    session.peer.send(session.connection_id, update)?;
2599    response.send(proto::Ack {})?;
2600
2601    Ok(())
2602}
2603
2604async fn join_channel(
2605    request: proto::JoinChannel,
2606    response: Response<proto::JoinChannel>,
2607    session: Session,
2608) -> Result<()> {
2609    let channel_id = ChannelId::from_proto(request.channel_id);
2610
2611    let joined_room = {
2612        leave_room_for_session(&session).await?;
2613        let db = session.db().await;
2614
2615        let room_id = db.room_id_for_channel(channel_id).await?;
2616
2617        let joined_room = db
2618            .join_room(room_id, session.user_id, session.connection_id)
2619            .await?;
2620
2621        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2622            let token = live_kit
2623                .room_token(
2624                    &joined_room.room.live_kit_room,
2625                    &session.user_id.to_string(),
2626                )
2627                .trace_err()?;
2628
2629            Some(LiveKitConnectionInfo {
2630                server_url: live_kit.url().into(),
2631                token,
2632            })
2633        });
2634
2635        response.send(proto::JoinRoomResponse {
2636            room: Some(joined_room.room.clone()),
2637            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2638            live_kit_connection_info,
2639        })?;
2640
2641        room_updated(&joined_room.room, &session.peer);
2642
2643        joined_room.into_inner()
2644    };
2645
2646    channel_updated(
2647        channel_id,
2648        &joined_room.room,
2649        &joined_room.channel_members,
2650        &session.peer,
2651        &*session.connection_pool().await,
2652    );
2653
2654    update_user_contacts(session.user_id, &session).await?;
2655
2656    Ok(())
2657}
2658
2659async fn join_channel_buffer(
2660    request: proto::JoinChannelBuffer,
2661    response: Response<proto::JoinChannelBuffer>,
2662    session: Session,
2663) -> Result<()> {
2664    let db = session.db().await;
2665    let channel_id = ChannelId::from_proto(request.channel_id);
2666
2667    let open_response = db
2668        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2669        .await?;
2670
2671    let collaborators = open_response.collaborators.clone();
2672    response.send(open_response)?;
2673
2674    let update = UpdateChannelBufferCollaborators {
2675        channel_id: channel_id.to_proto(),
2676        collaborators: collaborators.clone(),
2677    };
2678    channel_buffer_updated(
2679        session.connection_id,
2680        collaborators
2681            .iter()
2682            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2683        &update,
2684        &session.peer,
2685    );
2686
2687    Ok(())
2688}
2689
2690async fn update_channel_buffer(
2691    request: proto::UpdateChannelBuffer,
2692    session: Session,
2693) -> Result<()> {
2694    let db = session.db().await;
2695    let channel_id = ChannelId::from_proto(request.channel_id);
2696
2697    let (collaborators, non_collaborators, epoch, version) = db
2698        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2699        .await?;
2700
2701    channel_buffer_updated(
2702        session.connection_id,
2703        collaborators,
2704        &proto::UpdateChannelBuffer {
2705            channel_id: channel_id.to_proto(),
2706            operations: request.operations,
2707        },
2708        &session.peer,
2709    );
2710
2711    let pool = &*session.connection_pool().await;
2712
2713    broadcast(
2714        None,
2715        non_collaborators
2716            .iter()
2717            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2718        |peer_id| {
2719            session.peer.send(
2720                peer_id.into(),
2721                proto::UpdateChannels {
2722                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2723                        channel_id: channel_id.to_proto(),
2724                        epoch: epoch as u64,
2725                        version: version.clone(),
2726                    }],
2727                    ..Default::default()
2728                },
2729            )
2730        },
2731    );
2732
2733    Ok(())
2734}
2735
2736async fn rejoin_channel_buffers(
2737    request: proto::RejoinChannelBuffers,
2738    response: Response<proto::RejoinChannelBuffers>,
2739    session: Session,
2740) -> Result<()> {
2741    let db = session.db().await;
2742    let buffers = db
2743        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2744        .await?;
2745
2746    for rejoined_buffer in &buffers {
2747        let collaborators_to_notify = rejoined_buffer
2748            .buffer
2749            .collaborators
2750            .iter()
2751            .filter_map(|c| Some(c.peer_id?.into()));
2752        channel_buffer_updated(
2753            session.connection_id,
2754            collaborators_to_notify,
2755            &proto::UpdateChannelBufferCollaborators {
2756                channel_id: rejoined_buffer.buffer.channel_id,
2757                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2758            },
2759            &session.peer,
2760        );
2761    }
2762
2763    response.send(proto::RejoinChannelBuffersResponse {
2764        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2765    })?;
2766
2767    Ok(())
2768}
2769
2770async fn leave_channel_buffer(
2771    request: proto::LeaveChannelBuffer,
2772    response: Response<proto::LeaveChannelBuffer>,
2773    session: Session,
2774) -> Result<()> {
2775    let db = session.db().await;
2776    let channel_id = ChannelId::from_proto(request.channel_id);
2777
2778    let left_buffer = db
2779        .leave_channel_buffer(channel_id, session.connection_id)
2780        .await?;
2781
2782    response.send(Ack {})?;
2783
2784    channel_buffer_updated(
2785        session.connection_id,
2786        left_buffer.connections,
2787        &proto::UpdateChannelBufferCollaborators {
2788            channel_id: channel_id.to_proto(),
2789            collaborators: left_buffer.collaborators,
2790        },
2791        &session.peer,
2792    );
2793
2794    Ok(())
2795}
2796
2797fn channel_buffer_updated<T: EnvelopedMessage>(
2798    sender_id: ConnectionId,
2799    collaborators: impl IntoIterator<Item = ConnectionId>,
2800    message: &T,
2801    peer: &Peer,
2802) {
2803    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2804        peer.send(peer_id.into(), message.clone())
2805    });
2806}
2807
2808async fn send_channel_message(
2809    request: proto::SendChannelMessage,
2810    response: Response<proto::SendChannelMessage>,
2811    session: Session,
2812) -> Result<()> {
2813    // Validate the message body.
2814    let body = request.body.trim().to_string();
2815    if body.len() > MAX_MESSAGE_LEN {
2816        return Err(anyhow!("message is too long"))?;
2817    }
2818    if body.is_empty() {
2819        return Err(anyhow!("message can't be blank"))?;
2820    }
2821
2822    let timestamp = OffsetDateTime::now_utc();
2823    let nonce = request
2824        .nonce
2825        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2826
2827    let channel_id = ChannelId::from_proto(request.channel_id);
2828    let (message_id, connection_ids, non_participants) = session
2829        .db()
2830        .await
2831        .create_channel_message(
2832            channel_id,
2833            session.user_id,
2834            &body,
2835            timestamp,
2836            nonce.clone().into(),
2837        )
2838        .await?;
2839    let message = proto::ChannelMessage {
2840        sender_id: session.user_id.to_proto(),
2841        id: message_id.to_proto(),
2842        body,
2843        timestamp: timestamp.unix_timestamp() as u64,
2844        nonce: Some(nonce),
2845    };
2846    broadcast(Some(session.connection_id), connection_ids, |connection| {
2847        session.peer.send(
2848            connection,
2849            proto::ChannelMessageSent {
2850                channel_id: channel_id.to_proto(),
2851                message: Some(message.clone()),
2852            },
2853        )
2854    });
2855    response.send(proto::SendChannelMessageResponse {
2856        message: Some(message),
2857    })?;
2858
2859    let pool = &*session.connection_pool().await;
2860    broadcast(
2861        None,
2862        non_participants
2863            .iter()
2864            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2865        |peer_id| {
2866            session.peer.send(
2867                peer_id.into(),
2868                proto::UpdateChannels {
2869                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2870                        channel_id: channel_id.to_proto(),
2871                        message_id: message_id.to_proto(),
2872                    }],
2873                    ..Default::default()
2874                },
2875            )
2876        },
2877    );
2878
2879    Ok(())
2880}
2881
2882async fn remove_channel_message(
2883    request: proto::RemoveChannelMessage,
2884    response: Response<proto::RemoveChannelMessage>,
2885    session: Session,
2886) -> Result<()> {
2887    let channel_id = ChannelId::from_proto(request.channel_id);
2888    let message_id = MessageId::from_proto(request.message_id);
2889    let connection_ids = session
2890        .db()
2891        .await
2892        .remove_channel_message(channel_id, message_id, session.user_id)
2893        .await?;
2894    broadcast(Some(session.connection_id), connection_ids, |connection| {
2895        session.peer.send(connection, request.clone())
2896    });
2897    response.send(proto::Ack {})?;
2898    Ok(())
2899}
2900
2901async fn acknowledge_channel_message(
2902    request: proto::AckChannelMessage,
2903    session: Session,
2904) -> Result<()> {
2905    let channel_id = ChannelId::from_proto(request.channel_id);
2906    let message_id = MessageId::from_proto(request.message_id);
2907    session
2908        .db()
2909        .await
2910        .observe_channel_message(channel_id, session.user_id, message_id)
2911        .await?;
2912    Ok(())
2913}
2914
2915async fn join_channel_chat(
2916    request: proto::JoinChannelChat,
2917    response: Response<proto::JoinChannelChat>,
2918    session: Session,
2919) -> Result<()> {
2920    let channel_id = ChannelId::from_proto(request.channel_id);
2921
2922    let db = session.db().await;
2923    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2924        .await?;
2925    let messages = db
2926        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2927        .await?;
2928    response.send(proto::JoinChannelChatResponse {
2929        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2930        messages,
2931    })?;
2932    Ok(())
2933}
2934
2935async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2936    let channel_id = ChannelId::from_proto(request.channel_id);
2937    session
2938        .db()
2939        .await
2940        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2941        .await?;
2942    Ok(())
2943}
2944
2945async fn get_channel_messages(
2946    request: proto::GetChannelMessages,
2947    response: Response<proto::GetChannelMessages>,
2948    session: Session,
2949) -> Result<()> {
2950    let channel_id = ChannelId::from_proto(request.channel_id);
2951    let messages = session
2952        .db()
2953        .await
2954        .get_channel_messages(
2955            channel_id,
2956            session.user_id,
2957            MESSAGE_COUNT_PER_PAGE,
2958            Some(MessageId::from_proto(request.before_message_id)),
2959        )
2960        .await?;
2961    response.send(proto::GetChannelMessagesResponse {
2962        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2963        messages,
2964    })?;
2965    Ok(())
2966}
2967
2968async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2969    let project_id = ProjectId::from_proto(request.project_id);
2970    let project_connection_ids = session
2971        .db()
2972        .await
2973        .project_connection_ids(project_id, session.connection_id)
2974        .await?;
2975    broadcast(
2976        Some(session.connection_id),
2977        project_connection_ids.iter().copied(),
2978        |connection_id| {
2979            session
2980                .peer
2981                .forward_send(session.connection_id, connection_id, request.clone())
2982        },
2983    );
2984    Ok(())
2985}
2986
2987async fn get_private_user_info(
2988    _request: proto::GetPrivateUserInfo,
2989    response: Response<proto::GetPrivateUserInfo>,
2990    session: Session,
2991) -> Result<()> {
2992    let db = session.db().await;
2993
2994    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2995    let user = db
2996        .get_user_by_id(session.user_id)
2997        .await?
2998        .ok_or_else(|| anyhow!("user not found"))?;
2999    let flags = db.get_user_flags(session.user_id).await?;
3000
3001    response.send(proto::GetPrivateUserInfoResponse {
3002        metrics_id,
3003        staff: user.admin,
3004        flags,
3005    })?;
3006    Ok(())
3007}
3008
3009fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3010    match message {
3011        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3012        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3013        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3014        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3015        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3016            code: frame.code.into(),
3017            reason: frame.reason,
3018        })),
3019    }
3020}
3021
3022fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3023    match message {
3024        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3025        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3026        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3027        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3028        AxumMessage::Close(frame) => {
3029            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3030                code: frame.code.into(),
3031                reason: frame.reason,
3032            }))
3033        }
3034    }
3035}
3036
3037fn build_initial_channels_update(
3038    channels: ChannelsForUser,
3039    channel_invites: Vec<db::Channel>,
3040) -> proto::UpdateChannels {
3041    let mut update = proto::UpdateChannels::default();
3042
3043    for channel in channels.channels.channels {
3044        update.channels.push(proto::Channel {
3045            id: channel.id.to_proto(),
3046            name: channel.name,
3047        });
3048    }
3049
3050    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3051    update.unseen_channel_messages = channels.channel_messages;
3052    update.insert_edge = channels.channels.edges;
3053
3054    for (channel_id, participants) in channels.channel_participants {
3055        update
3056            .channel_participants
3057            .push(proto::ChannelParticipants {
3058                channel_id: channel_id.to_proto(),
3059                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3060            });
3061    }
3062
3063    update
3064        .channel_permissions
3065        .extend(
3066            channels
3067                .channels_with_admin_privileges
3068                .into_iter()
3069                .map(|id| proto::ChannelPermission {
3070                    channel_id: id.to_proto(),
3071                    is_admin: true,
3072                }),
3073        );
3074
3075    for channel in channel_invites {
3076        update.channel_invitations.push(proto::Channel {
3077            id: channel.id.to_proto(),
3078            name: channel.name,
3079        });
3080    }
3081
3082    update
3083}
3084
3085fn build_initial_contacts_update(
3086    contacts: Vec<db::Contact>,
3087    pool: &ConnectionPool,
3088) -> proto::UpdateContacts {
3089    let mut update = proto::UpdateContacts::default();
3090
3091    for contact in contacts {
3092        match contact {
3093            db::Contact::Accepted {
3094                user_id,
3095                should_notify,
3096                busy,
3097            } => {
3098                update
3099                    .contacts
3100                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3101            }
3102            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3103            db::Contact::Incoming {
3104                user_id,
3105                should_notify,
3106            } => update
3107                .incoming_requests
3108                .push(proto::IncomingContactRequest {
3109                    requester_id: user_id.to_proto(),
3110                    should_notify,
3111                }),
3112        }
3113    }
3114
3115    update
3116}
3117
3118fn contact_for_user(
3119    user_id: UserId,
3120    should_notify: bool,
3121    busy: bool,
3122    pool: &ConnectionPool,
3123) -> proto::Contact {
3124    proto::Contact {
3125        user_id: user_id.to_proto(),
3126        online: pool.is_user_online(user_id),
3127        busy,
3128        should_notify,
3129    }
3130}
3131
3132fn room_updated(room: &proto::Room, peer: &Peer) {
3133    broadcast(
3134        None,
3135        room.participants
3136            .iter()
3137            .filter_map(|participant| Some(participant.peer_id?.into())),
3138        |peer_id| {
3139            peer.send(
3140                peer_id.into(),
3141                proto::RoomUpdated {
3142                    room: Some(room.clone()),
3143                },
3144            )
3145        },
3146    );
3147}
3148
3149fn channel_updated(
3150    channel_id: ChannelId,
3151    room: &proto::Room,
3152    channel_members: &[UserId],
3153    peer: &Peer,
3154    pool: &ConnectionPool,
3155) {
3156    let participants = room
3157        .participants
3158        .iter()
3159        .map(|p| p.user_id)
3160        .collect::<Vec<_>>();
3161
3162    broadcast(
3163        None,
3164        channel_members
3165            .iter()
3166            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3167        |peer_id| {
3168            peer.send(
3169                peer_id.into(),
3170                proto::UpdateChannels {
3171                    channel_participants: vec![proto::ChannelParticipants {
3172                        channel_id: channel_id.to_proto(),
3173                        participant_user_ids: participants.clone(),
3174                    }],
3175                    ..Default::default()
3176                },
3177            )
3178        },
3179    );
3180}
3181
3182async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3183    let db = session.db().await;
3184
3185    let contacts = db.get_contacts(user_id).await?;
3186    let busy = db.is_user_busy(user_id).await?;
3187
3188    let pool = session.connection_pool().await;
3189    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3190    for contact in contacts {
3191        if let db::Contact::Accepted {
3192            user_id: contact_user_id,
3193            ..
3194        } = contact
3195        {
3196            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3197                session
3198                    .peer
3199                    .send(
3200                        contact_conn_id,
3201                        proto::UpdateContacts {
3202                            contacts: vec![updated_contact.clone()],
3203                            remove_contacts: Default::default(),
3204                            incoming_requests: Default::default(),
3205                            remove_incoming_requests: Default::default(),
3206                            outgoing_requests: Default::default(),
3207                            remove_outgoing_requests: Default::default(),
3208                        },
3209                    )
3210                    .trace_err();
3211            }
3212        }
3213    }
3214    Ok(())
3215}
3216
3217async fn leave_room_for_session(session: &Session) -> Result<()> {
3218    let mut contacts_to_update = HashSet::default();
3219
3220    let room_id;
3221    let canceled_calls_to_user_ids;
3222    let live_kit_room;
3223    let delete_live_kit_room;
3224    let room;
3225    let channel_members;
3226    let channel_id;
3227
3228    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3229        contacts_to_update.insert(session.user_id);
3230
3231        for project in left_room.left_projects.values() {
3232            project_left(project, session);
3233        }
3234
3235        room_id = RoomId::from_proto(left_room.room.id);
3236        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3237        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3238        delete_live_kit_room = left_room.deleted;
3239        room = mem::take(&mut left_room.room);
3240        channel_members = mem::take(&mut left_room.channel_members);
3241        channel_id = left_room.channel_id;
3242
3243        room_updated(&room, &session.peer);
3244    } else {
3245        return Ok(());
3246    }
3247
3248    if let Some(channel_id) = channel_id {
3249        channel_updated(
3250            channel_id,
3251            &room,
3252            &channel_members,
3253            &session.peer,
3254            &*session.connection_pool().await,
3255        );
3256    }
3257
3258    {
3259        let pool = session.connection_pool().await;
3260        for canceled_user_id in canceled_calls_to_user_ids {
3261            for connection_id in pool.user_connection_ids(canceled_user_id) {
3262                session
3263                    .peer
3264                    .send(
3265                        connection_id,
3266                        proto::CallCanceled {
3267                            room_id: room_id.to_proto(),
3268                        },
3269                    )
3270                    .trace_err();
3271            }
3272            contacts_to_update.insert(canceled_user_id);
3273        }
3274    }
3275
3276    for contact_user_id in contacts_to_update {
3277        update_user_contacts(contact_user_id, &session).await?;
3278    }
3279
3280    if let Some(live_kit) = session.live_kit_client.as_ref() {
3281        live_kit
3282            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3283            .await
3284            .trace_err();
3285
3286        if delete_live_kit_room {
3287            live_kit.delete_room(live_kit_room).await.trace_err();
3288        }
3289    }
3290
3291    Ok(())
3292}
3293
3294async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3295    let left_channel_buffers = session
3296        .db()
3297        .await
3298        .leave_channel_buffers(session.connection_id)
3299        .await?;
3300
3301    for left_buffer in left_channel_buffers {
3302        channel_buffer_updated(
3303            session.connection_id,
3304            left_buffer.connections,
3305            &proto::UpdateChannelBufferCollaborators {
3306                channel_id: left_buffer.channel_id.to_proto(),
3307                collaborators: left_buffer.collaborators,
3308            },
3309            &session.peer,
3310        );
3311    }
3312
3313    Ok(())
3314}
3315
3316fn project_left(project: &db::LeftProject, session: &Session) {
3317    for connection_id in &project.connection_ids {
3318        if project.host_user_id == session.user_id {
3319            session
3320                .peer
3321                .send(
3322                    *connection_id,
3323                    proto::UnshareProject {
3324                        project_id: project.id.to_proto(),
3325                    },
3326                )
3327                .trace_err();
3328        } else {
3329            session
3330                .peer
3331                .send(
3332                    *connection_id,
3333                    proto::RemoveProjectCollaborator {
3334                        project_id: project.id.to_proto(),
3335                        peer_id: Some(session.connection_id.into()),
3336                    },
3337                )
3338                .trace_err();
3339        }
3340    }
3341}
3342
3343pub trait ResultExt {
3344    type Ok;
3345
3346    fn trace_err(self) -> Option<Self::Ok>;
3347}
3348
3349impl<T, E> ResultExt for Result<T, E>
3350where
3351    E: std::fmt::Debug,
3352{
3353    type Ok = T;
3354
3355    fn trace_err(self) -> Option<T> {
3356        match self {
3357            Ok(value) => Some(value),
3358            Err(error) => {
3359                tracing::error!("{:?}", error);
3360                None
3361            }
3362        }
3363    }
3364}