rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
 144pub struct Server {
 145    id: parking_lot::Mutex<ServerId>,
 146    peer: Arc<Peer>,
 147    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 148    app_state: Arc<AppState>,
 149    executor: Executor,
 150    handlers: HashMap<TypeId, MessageHandler>,
 151    teardown: watch::Sender<()>,
 152}
 153
 154pub(crate) struct ConnectionPoolGuard<'a> {
 155    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 156    _not_send: PhantomData<Rc<()>>,
 157}
 158
 159#[derive(Serialize)]
 160pub struct ServerSnapshot<'a> {
 161    peer: &'a Peer,
 162    #[serde(serialize_with = "serialize_deref")]
 163    connection_pool: ConnectionPoolGuard<'a>,
 164}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(respond_to_channel_invite)
 250            .add_request_handler(join_channel)
 251            .add_request_handler(follow)
 252            .add_message_handler(unfollow)
 253            .add_message_handler(update_followers)
 254            .add_message_handler(update_diff_base)
 255            .add_request_handler(get_private_user_info);
 256
 257        Arc::new(server)
 258    }
 259
 260    pub async fn start(&self) -> Result<()> {
 261        let server_id = *self.id.lock();
 262        let app_state = self.app_state.clone();
 263        let peer = self.peer.clone();
 264        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 265        let pool = self.connection_pool.clone();
 266        let live_kit_client = self.app_state.live_kit_client.clone();
 267
 268        let span = info_span!("start server");
 269        self.executor.spawn_detached(
 270            async move {
 271                tracing::info!("waiting for cleanup timeout");
 272                timeout.await;
 273                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 274                if let Some(room_ids) = app_state
 275                    .db
 276                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 277                    .await
 278                    .trace_err()
 279                {
 280                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 281                    for room_id in room_ids {
 282                        let mut contacts_to_update = HashSet::default();
 283                        let mut canceled_calls_to_user_ids = Vec::new();
 284                        let mut live_kit_room = String::new();
 285                        let mut delete_live_kit_room = false;
 286
 287                        if let Some(mut refreshed_room) = app_state
 288                            .db
 289                            .refresh_room(room_id, server_id)
 290                            .await
 291                            .trace_err()
 292                        {
 293                            tracing::info!(
 294                                room_id = room_id.0,
 295                                new_participant_count = refreshed_room.room.participants.len(),
 296                                "refreshed room"
 297                            );
 298                            room_updated(&refreshed_room.room, &peer);
 299                            if let Some(channel_id) = refreshed_room.channel_id {
 300                                channel_updated(
 301                                    channel_id,
 302                                    &refreshed_room.room,
 303                                    &refreshed_room.channel_members,
 304                                    &peer,
 305                                    &*pool.lock(),
 306                                );
 307                            }
 308                            contacts_to_update
 309                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 310                            contacts_to_update
 311                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 312                            canceled_calls_to_user_ids =
 313                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 314                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 315                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 316                        }
 317
 318                        {
 319                            let pool = pool.lock();
 320                            for canceled_user_id in canceled_calls_to_user_ids {
 321                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 322                                    peer.send(
 323                                        connection_id,
 324                                        proto::CallCanceled {
 325                                            room_id: room_id.to_proto(),
 326                                        },
 327                                    )
 328                                    .trace_err();
 329                                }
 330                            }
 331                        }
 332
 333                        for user_id in contacts_to_update {
 334                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 335                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 336                            if let Some((busy, contacts)) = busy.zip(contacts) {
 337                                let pool = pool.lock();
 338                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 339                                for contact in contacts {
 340                                    if let db::Contact::Accepted {
 341                                        user_id: contact_user_id,
 342                                        ..
 343                                    } = contact
 344                                    {
 345                                        for contact_conn_id in
 346                                            pool.user_connection_ids(contact_user_id)
 347                                        {
 348                                            peer.send(
 349                                                contact_conn_id,
 350                                                proto::UpdateContacts {
 351                                                    contacts: vec![updated_contact.clone()],
 352                                                    remove_contacts: Default::default(),
 353                                                    incoming_requests: Default::default(),
 354                                                    remove_incoming_requests: Default::default(),
 355                                                    outgoing_requests: Default::default(),
 356                                                    remove_outgoing_requests: Default::default(),
 357                                                },
 358                                            )
 359                                            .trace_err();
 360                                        }
 361                                    }
 362                                }
 363                            }
 364                        }
 365
 366                        if let Some(live_kit) = live_kit_client.as_ref() {
 367                            if delete_live_kit_room {
 368                                live_kit.delete_room(live_kit_room).await.trace_err();
 369                            }
 370                        }
 371                    }
 372                }
 373
 374                app_state
 375                    .db
 376                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 377                    .await
 378                    .trace_err();
 379            }
 380            .instrument(span),
 381        );
 382        Ok(())
 383    }
 384
 385    pub fn teardown(&self) {
 386        self.peer.teardown();
 387        self.connection_pool.lock().reset();
 388        let _ = self.teardown.send(());
 389    }
 390
 391    #[cfg(test)]
 392    pub fn reset(&self, id: ServerId) {
 393        self.teardown();
 394        *self.id.lock() = id;
 395        self.peer.reset(id.0 as u32);
 396    }
 397
 398    #[cfg(test)]
 399    pub fn id(&self) -> ServerId {
 400        *self.id.lock()
 401    }
 402
 403    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 404    where
 405        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 406        Fut: 'static + Send + Future<Output = Result<()>>,
 407        M: EnvelopedMessage,
 408    {
 409        let prev_handler = self.handlers.insert(
 410            TypeId::of::<M>(),
 411            Box::new(move |envelope, session| {
 412                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 413                let span = info_span!(
 414                    "handle message",
 415                    payload_type = envelope.payload_type_name()
 416                );
 417                span.in_scope(|| {
 418                    tracing::info!(
 419                        payload_type = envelope.payload_type_name(),
 420                        "message received"
 421                    );
 422                });
 423                let start_time = Instant::now();
 424                let future = (handler)(*envelope, session);
 425                async move {
 426                    let result = future.await;
 427                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 428                    match result {
 429                        Err(error) => {
 430                            tracing::error!(%error, ?duration_ms, "error handling message")
 431                        }
 432                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 433                    }
 434                }
 435                .instrument(span)
 436                .boxed()
 437            }),
 438        );
 439        if prev_handler.is_some() {
 440            panic!("registered a handler for the same message twice");
 441        }
 442        self
 443    }
 444
 445    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 446    where
 447        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 448        Fut: 'static + Send + Future<Output = Result<()>>,
 449        M: EnvelopedMessage,
 450    {
 451        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 452        self
 453    }
 454
 455    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 458        Fut: Send + Future<Output = Result<()>>,
 459        M: RequestMessage,
 460    {
 461        let handler = Arc::new(handler);
 462        self.add_handler(move |envelope, session| {
 463            let receipt = envelope.receipt();
 464            let handler = handler.clone();
 465            async move {
 466                let peer = session.peer.clone();
 467                let responded = Arc::new(AtomicBool::default());
 468                let response = Response {
 469                    peer: peer.clone(),
 470                    responded: responded.clone(),
 471                    receipt,
 472                };
 473                match (handler)(envelope.payload, response, session).await {
 474                    Ok(()) => {
 475                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 476                            Ok(())
 477                        } else {
 478                            Err(anyhow!("handler did not send a response"))?
 479                        }
 480                    }
 481                    Err(error) => {
 482                        peer.respond_with_error(
 483                            receipt,
 484                            proto::Error {
 485                                message: error.to_string(),
 486                            },
 487                        )?;
 488                        Err(error)
 489                    }
 490                }
 491            }
 492        })
 493    }
 494
 495    pub fn handle_connection(
 496        self: &Arc<Self>,
 497        connection: Connection,
 498        address: String,
 499        user: User,
 500        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 501        executor: Executor,
 502    ) -> impl Future<Output = Result<()>> {
 503        let this = self.clone();
 504        let user_id = user.id;
 505        let login = user.github_login;
 506        let span = info_span!("handle connection", %user_id, %login, %address);
 507        let mut teardown = self.teardown.subscribe();
 508        async move {
 509            let (connection_id, handle_io, mut incoming_rx) = this
 510                .peer
 511                .add_connection(connection, {
 512                    let executor = executor.clone();
 513                    move |duration| executor.sleep(duration)
 514                });
 515
 516            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 517            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 518            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 519
 520            if let Some(send_connection_id) = send_connection_id.take() {
 521                let _ = send_connection_id.send(connection_id);
 522            }
 523
 524            if !user.connected_once {
 525                this.peer.send(connection_id, proto::ShowContacts {})?;
 526                this.app_state.db.set_user_connected_once(user_id, true).await?;
 527            }
 528
 529            let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
 530                this.app_state.db.get_contacts(user_id),
 531                this.app_state.db.get_invite_code_for_user(user_id),
 532                this.app_state.db.get_channels(user_id),
 533                this.app_state.db.get_channel_invites(user_id)
 534            ).await?;
 535
 536            {
 537                let mut pool = this.connection_pool.lock();
 538                pool.add_connection(connection_id, user_id, user.admin);
 539                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 540                this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;
 541
 542                if let Some((code, count)) = invite_code {
 543                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 544                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 545                        count: count as u32,
 546                    })?;
 547                }
 548            }
 549
 550            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 551                this.peer.send(connection_id, incoming_call)?;
 552            }
 553
 554            let session = Session {
 555                user_id,
 556                connection_id,
 557                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 558                peer: this.peer.clone(),
 559                connection_pool: this.connection_pool.clone(),
 560                live_kit_client: this.app_state.live_kit_client.clone(),
 561                executor: executor.clone(),
 562            };
 563            update_user_contacts(user_id, &session).await?;
 564
 565            let handle_io = handle_io.fuse();
 566            futures::pin_mut!(handle_io);
 567
 568            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 569            // This prevents deadlocks when e.g., client A performs a request to client B and
 570            // client B performs a request to client A. If both clients stop processing further
 571            // messages until their respective request completes, they won't have a chance to
 572            // respond to the other client's request and cause a deadlock.
 573            //
 574            // This arrangement ensures we will attempt to process earlier messages first, but fall
 575            // back to processing messages arrived later in the spirit of making progress.
 576            let mut foreground_message_handlers = FuturesUnordered::new();
 577            let concurrent_handlers = Arc::new(Semaphore::new(256));
 578            loop {
 579                let next_message = async {
 580                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 581                    let message = incoming_rx.next().await;
 582                    (permit, message)
 583                }.fuse();
 584                futures::pin_mut!(next_message);
 585                futures::select_biased! {
 586                    _ = teardown.changed().fuse() => return Ok(()),
 587                    result = handle_io => {
 588                        if let Err(error) = result {
 589                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 590                        }
 591                        break;
 592                    }
 593                    _ = foreground_message_handlers.next() => {}
 594                    next_message = next_message => {
 595                        let (permit, message) = next_message;
 596                        if let Some(message) = message {
 597                            let type_name = message.payload_type_name();
 598                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 599                            let span_enter = span.enter();
 600                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 601                                let is_background = message.is_background();
 602                                let handle_message = (handler)(message, session.clone());
 603                                drop(span_enter);
 604
 605                                let handle_message = async move {
 606                                    handle_message.await;
 607                                    drop(permit);
 608                                }.instrument(span);
 609                                if is_background {
 610                                    executor.spawn_detached(handle_message);
 611                                } else {
 612                                    foreground_message_handlers.push(handle_message);
 613                                }
 614                            } else {
 615                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 616                            }
 617                        } else {
 618                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 619                            break;
 620                        }
 621                    }
 622                }
 623            }
 624
 625            drop(foreground_message_handlers);
 626            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 627            if let Err(error) = connection_lost(session, teardown, executor).await {
 628                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 629            }
 630
 631            Ok(())
 632        }.instrument(span)
 633    }
 634
 635    pub async fn invite_code_redeemed(
 636        self: &Arc<Self>,
 637        inviter_id: UserId,
 638        invitee_id: UserId,
 639    ) -> Result<()> {
 640        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 641            if let Some(code) = &user.invite_code {
 642                let pool = self.connection_pool.lock();
 643                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 644                for connection_id in pool.user_connection_ids(inviter_id) {
 645                    self.peer.send(
 646                        connection_id,
 647                        proto::UpdateContacts {
 648                            contacts: vec![invitee_contact.clone()],
 649                            ..Default::default()
 650                        },
 651                    )?;
 652                    self.peer.send(
 653                        connection_id,
 654                        proto::UpdateInviteInfo {
 655                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 656                            count: user.invite_count as u32,
 657                        },
 658                    )?;
 659                }
 660            }
 661        }
 662        Ok(())
 663    }
 664
 665    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 666        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 667            if let Some(invite_code) = &user.invite_code {
 668                let pool = self.connection_pool.lock();
 669                for connection_id in pool.user_connection_ids(user_id) {
 670                    self.peer.send(
 671                        connection_id,
 672                        proto::UpdateInviteInfo {
 673                            url: format!(
 674                                "{}{}",
 675                                self.app_state.config.invite_link_prefix, invite_code
 676                            ),
 677                            count: user.invite_count as u32,
 678                        },
 679                    )?;
 680                }
 681            }
 682        }
 683        Ok(())
 684    }
 685
 686    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 687        ServerSnapshot {
 688            connection_pool: ConnectionPoolGuard {
 689                guard: self.connection_pool.lock(),
 690                _not_send: PhantomData,
 691            },
 692            peer: &self.peer,
 693        }
 694    }
 695}
 696
 697impl<'a> Deref for ConnectionPoolGuard<'a> {
 698    type Target = ConnectionPool;
 699
 700    fn deref(&self) -> &Self::Target {
 701        &*self.guard
 702    }
 703}
 704
 705impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 706    fn deref_mut(&mut self) -> &mut Self::Target {
 707        &mut *self.guard
 708    }
 709}
 710
 711impl<'a> Drop for ConnectionPoolGuard<'a> {
 712    fn drop(&mut self) {
 713        #[cfg(test)]
 714        self.check_invariants();
 715    }
 716}
 717
 718fn broadcast<F>(
 719    sender_id: Option<ConnectionId>,
 720    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 721    mut f: F,
 722) where
 723    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 724{
 725    for receiver_id in receiver_ids {
 726        if Some(receiver_id) != sender_id {
 727            if let Err(error) = f(receiver_id) {
 728                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 729            }
 730        }
 731    }
 732}
 733
 734lazy_static! {
 735    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 736}
 737
 738pub struct ProtocolVersion(u32);
 739
 740impl Header for ProtocolVersion {
 741    fn name() -> &'static HeaderName {
 742        &ZED_PROTOCOL_VERSION
 743    }
 744
 745    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 746    where
 747        Self: Sized,
 748        I: Iterator<Item = &'i axum::http::HeaderValue>,
 749    {
 750        let version = values
 751            .next()
 752            .ok_or_else(axum::headers::Error::invalid)?
 753            .to_str()
 754            .map_err(|_| axum::headers::Error::invalid())?
 755            .parse()
 756            .map_err(|_| axum::headers::Error::invalid())?;
 757        Ok(Self(version))
 758    }
 759
 760    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 761        values.extend([self.0.to_string().parse().unwrap()]);
 762    }
 763}
 764
/// Builds the RPC router: `/rpc` (websocket upgrade) and `/metrics`.
///
/// Ordering matters: axum's `Router::layer` only wraps routes registered before it.
/// As a result:
/// - the `ServiceBuilder` layer (app-state extension plus `auth::validate_header`)
///   applies to `/rpc` only, since `/metrics` is added afterwards;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Authenticated section: only the routes registered above receive these layers.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Shared by every route above.
        .layer(Extension(server))
}
 776
/// Axum handler for `/rpc`. It checks the client's protocol version and then upgrades
/// the request to a websocket that is served by `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required` when the `x-zed-protocol-version` header does not
/// equal `rpc::PROTOCOL_VERSION`. The `User` extension is presumably inserted by the auth
/// middleware that `routes` installs on `/rpc`. NOTE(review): confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket (stream + sink of axum messages) into a stream/sink of
        // tungstenite messages, which is what `Connection::new` is given here.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Connection errors are logged rather than propagated; the upgrade future
            // itself has no error channel.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 807
 808pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 809    let connections = server
 810        .connection_pool
 811        .lock()
 812        .connections()
 813        .filter(|connection| !connection.admin)
 814        .count();
 815
 816    METRIC_CONNECTIONS.set(connections as _);
 817
 818    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 819    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 820
 821    let encoder = prometheus::TextEncoder::new();
 822    let metric_families = prometheus::gather();
 823    let encoded_metrics = encoder
 824        .encode_to_string(&metric_families)
 825        .map_err(|err| anyhow!("{}", err))?;
 826    Ok(encoded_metrics)
 827}
 828
/// Cleans up after a client connection is lost.
///
/// First, unconditionally:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - records the loss in the database (a database error here is only traced).
///
/// Then it waits for whichever happens first; `select_biased!` checks the timeout arm
/// first when both are ready:
/// - `RECONNECT_TIMEOUT` elapses:
///   - the session leaves its room;
///   - if the user has no other online connection, `decline_call(None, user)` is applied
///     and any resulting room update is broadcast;
///   - the user's contacts are updated.
/// - `teardown` changes: nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure to record the lost connection must not abort cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time.
            leave_room_for_session(&session).await.trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 869
 870async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 871    response.send(proto::Ack {})?;
 872    Ok(())
 873}
 874
 875async fn create_room(
 876    _request: proto::CreateRoom,
 877    response: Response<proto::CreateRoom>,
 878    session: Session,
 879) -> Result<()> {
 880    let live_kit_room = nanoid::nanoid!(30);
 881
 882    let live_kit_connection_info = {
 883        let live_kit_room = live_kit_room.clone();
 884        let live_kit = session.live_kit_client.as_ref();
 885
 886        util::async_iife!({
 887            let live_kit = live_kit?;
 888
 889            live_kit
 890                .create_room(live_kit_room.clone())
 891                .await
 892                .trace_err()?;
 893
 894            let token = live_kit
 895                .room_token(&live_kit_room, &session.user_id.to_string())
 896                .trace_err()?;
 897
 898            Some(proto::LiveKitConnectionInfo {
 899                server_url: live_kit.url().into(),
 900                token,
 901            })
 902        })
 903    }
 904    .await;
 905
 906    let room = session
 907        .db()
 908        .await
 909        .create_room(session.user_id, session.connection_id, &live_kit_room)
 910        .await?;
 911
 912    response.send(proto::CreateRoomResponse {
 913        room: Some(room.clone()),
 914        live_kit_connection_info,
 915    })?;
 916
 917    update_user_contacts(session.user_id, &session).await?;
 918    Ok(())
 919}
 920
 921async fn join_room(
 922    request: proto::JoinRoom,
 923    response: Response<proto::JoinRoom>,
 924    session: Session,
 925) -> Result<()> {
 926    let room_id = RoomId::from_proto(request.id);
 927    let room = {
 928        let room = session
 929            .db()
 930            .await
 931            .join_room(room_id, session.user_id, None, session.connection_id)
 932            .await?;
 933        room_updated(&room.room, &session.peer);
 934        room.room.clone()
 935    };
 936
 937    for connection_id in session
 938        .connection_pool()
 939        .await
 940        .user_connection_ids(session.user_id)
 941    {
 942        session
 943            .peer
 944            .send(
 945                connection_id,
 946                proto::CallCanceled {
 947                    room_id: room_id.to_proto(),
 948                },
 949            )
 950            .trace_err();
 951    }
 952
 953    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 954        if let Some(token) = live_kit
 955            .room_token(&room.live_kit_room, &session.user_id.to_string())
 956            .trace_err()
 957        {
 958            Some(proto::LiveKitConnectionInfo {
 959                server_url: live_kit.url().into(),
 960                token,
 961            })
 962        } else {
 963            None
 964        }
 965    } else {
 966        None
 967    };
 968
 969    response.send(proto::JoinRoomResponse {
 970        room: Some(room),
 971        live_kit_connection_info,
 972    })?;
 973
 974    update_user_contacts(session.user_id, &session).await?;
 975    Ok(())
 976}
 977
 978async fn rejoin_room(
 979    request: proto::RejoinRoom,
 980    response: Response<proto::RejoinRoom>,
 981    session: Session,
 982) -> Result<()> {
 983    let room;
 984    let channel_id;
 985    let channel_members;
 986    {
 987        let mut rejoined_room = session
 988            .db()
 989            .await
 990            .rejoin_room(request, session.user_id, session.connection_id)
 991            .await?;
 992
 993        response.send(proto::RejoinRoomResponse {
 994            room: Some(rejoined_room.room.clone()),
 995            reshared_projects: rejoined_room
 996                .reshared_projects
 997                .iter()
 998                .map(|project| proto::ResharedProject {
 999                    id: project.id.to_proto(),
1000                    collaborators: project
1001                        .collaborators
1002                        .iter()
1003                        .map(|collaborator| collaborator.to_proto())
1004                        .collect(),
1005                })
1006                .collect(),
1007            rejoined_projects: rejoined_room
1008                .rejoined_projects
1009                .iter()
1010                .map(|rejoined_project| proto::RejoinedProject {
1011                    id: rejoined_project.id.to_proto(),
1012                    worktrees: rejoined_project
1013                        .worktrees
1014                        .iter()
1015                        .map(|worktree| proto::WorktreeMetadata {
1016                            id: worktree.id,
1017                            root_name: worktree.root_name.clone(),
1018                            visible: worktree.visible,
1019                            abs_path: worktree.abs_path.clone(),
1020                        })
1021                        .collect(),
1022                    collaborators: rejoined_project
1023                        .collaborators
1024                        .iter()
1025                        .map(|collaborator| collaborator.to_proto())
1026                        .collect(),
1027                    language_servers: rejoined_project.language_servers.clone(),
1028                })
1029                .collect(),
1030        })?;
1031        room_updated(&rejoined_room.room, &session.peer);
1032
1033        for project in &rejoined_room.reshared_projects {
1034            for collaborator in &project.collaborators {
1035                session
1036                    .peer
1037                    .send(
1038                        collaborator.connection_id,
1039                        proto::UpdateProjectCollaborator {
1040                            project_id: project.id.to_proto(),
1041                            old_peer_id: Some(project.old_connection_id.into()),
1042                            new_peer_id: Some(session.connection_id.into()),
1043                        },
1044                    )
1045                    .trace_err();
1046            }
1047
1048            broadcast(
1049                Some(session.connection_id),
1050                project
1051                    .collaborators
1052                    .iter()
1053                    .map(|collaborator| collaborator.connection_id),
1054                |connection_id| {
1055                    session.peer.forward_send(
1056                        session.connection_id,
1057                        connection_id,
1058                        proto::UpdateProject {
1059                            project_id: project.id.to_proto(),
1060                            worktrees: project.worktrees.clone(),
1061                        },
1062                    )
1063                },
1064            );
1065        }
1066
1067        for project in &rejoined_room.rejoined_projects {
1068            for collaborator in &project.collaborators {
1069                session
1070                    .peer
1071                    .send(
1072                        collaborator.connection_id,
1073                        proto::UpdateProjectCollaborator {
1074                            project_id: project.id.to_proto(),
1075                            old_peer_id: Some(project.old_connection_id.into()),
1076                            new_peer_id: Some(session.connection_id.into()),
1077                        },
1078                    )
1079                    .trace_err();
1080            }
1081        }
1082
1083        for project in &mut rejoined_room.rejoined_projects {
1084            for worktree in mem::take(&mut project.worktrees) {
1085                #[cfg(any(test, feature = "test-support"))]
1086                const MAX_CHUNK_SIZE: usize = 2;
1087                #[cfg(not(any(test, feature = "test-support")))]
1088                const MAX_CHUNK_SIZE: usize = 256;
1089
1090                // Stream this worktree's entries.
1091                let message = proto::UpdateWorktree {
1092                    project_id: project.id.to_proto(),
1093                    worktree_id: worktree.id,
1094                    abs_path: worktree.abs_path.clone(),
1095                    root_name: worktree.root_name,
1096                    updated_entries: worktree.updated_entries,
1097                    removed_entries: worktree.removed_entries,
1098                    scan_id: worktree.scan_id,
1099                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1100                    updated_repositories: worktree.updated_repositories,
1101                    removed_repositories: worktree.removed_repositories,
1102                };
1103                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1104                    session.peer.send(session.connection_id, update.clone())?;
1105                }
1106
1107                // Stream this worktree's diagnostics.
1108                for summary in worktree.diagnostic_summaries {
1109                    session.peer.send(
1110                        session.connection_id,
1111                        proto::UpdateDiagnosticSummary {
1112                            project_id: project.id.to_proto(),
1113                            worktree_id: worktree.id,
1114                            summary: Some(summary),
1115                        },
1116                    )?;
1117                }
1118
1119                for settings_file in worktree.settings_files {
1120                    session.peer.send(
1121                        session.connection_id,
1122                        proto::UpdateWorktreeSettings {
1123                            project_id: project.id.to_proto(),
1124                            worktree_id: worktree.id,
1125                            path: settings_file.path,
1126                            content: Some(settings_file.content),
1127                        },
1128                    )?;
1129                }
1130            }
1131
1132            for language_server in &project.language_servers {
1133                session.peer.send(
1134                    session.connection_id,
1135                    proto::UpdateLanguageServer {
1136                        project_id: project.id.to_proto(),
1137                        language_server_id: language_server.id,
1138                        variant: Some(
1139                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1140                                proto::LspDiskBasedDiagnosticsUpdated {},
1141                            ),
1142                        ),
1143                    },
1144                )?;
1145            }
1146        }
1147
1148        room = mem::take(&mut rejoined_room.room);
1149        channel_id = rejoined_room.channel_id;
1150        channel_members = mem::take(&mut rejoined_room.channel_members);
1151    }
1152
1153    //TODO: move this into the room guard
1154    if let Some(channel_id) = channel_id {
1155        channel_updated(
1156            channel_id,
1157            &room,
1158            &channel_members,
1159            &session.peer,
1160            &*session.connection_pool().await,
1161        );
1162    }
1163
1164    update_user_contacts(session.user_id, &session).await?;
1165    Ok(())
1166}
1167
1168async fn leave_room(
1169    _: proto::LeaveRoom,
1170    response: Response<proto::LeaveRoom>,
1171    session: Session,
1172) -> Result<()> {
1173    leave_room_for_session(&session).await?;
1174    response.send(proto::Ack {})?;
1175    Ok(())
1176}
1177
/// Handles `Call`: rings `called_user_id` on the caller's behalf for the given room.
///
/// The callee must be one of the caller's contacts. The call is recorded through
/// `Database::call`, the updated room is broadcast, and the returned incoming-call request
/// is sent concurrently to every connection of the callee. The first successful reply
/// acknowledges the caller.
///
/// If every request fails, or the callee has no connections, the following happens:
/// - `call_failed` is recorded and the room is re-broadcast;
/// - the callee's contacts are updated;
/// - "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The `&mut *` binding extends the lifetime of the value returned by `call` to the end
    // of this block. The incoming call is moved out before that value is dropped.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently. The pool guard is a temporary of
    // this statement, so it is released before any response is awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful ring completes the request.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // Trace the failure and keep waiting on the remaining connections.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failure and broadcast the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1245
1246async fn cancel_call(
1247    request: proto::CancelCall,
1248    response: Response<proto::CancelCall>,
1249    session: Session,
1250) -> Result<()> {
1251    let called_user_id = UserId::from_proto(request.called_user_id);
1252    let room_id = RoomId::from_proto(request.room_id);
1253    {
1254        let room = session
1255            .db()
1256            .await
1257            .cancel_call(room_id, session.connection_id, called_user_id)
1258            .await?;
1259        room_updated(&room, &session.peer);
1260    }
1261
1262    for connection_id in session
1263        .connection_pool()
1264        .await
1265        .user_connection_ids(called_user_id)
1266    {
1267        session
1268            .peer
1269            .send(
1270                connection_id,
1271                proto::CallCanceled {
1272                    room_id: room_id.to_proto(),
1273                },
1274            )
1275            .trace_err();
1276    }
1277    response.send(proto::Ack {})?;
1278
1279    update_user_contacts(called_user_id, &session).await?;
1280    Ok(())
1281}
1282
1283async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1284    let room_id = RoomId::from_proto(message.room_id);
1285    {
1286        let room = session
1287            .db()
1288            .await
1289            .decline_call(Some(room_id), session.user_id)
1290            .await?
1291            .ok_or_else(|| anyhow!("failed to decline call"))?;
1292        room_updated(&room, &session.peer);
1293    }
1294
1295    for connection_id in session
1296        .connection_pool()
1297        .await
1298        .user_connection_ids(session.user_id)
1299    {
1300        session
1301            .peer
1302            .send(
1303                connection_id,
1304                proto::CallCanceled {
1305                    room_id: room_id.to_proto(),
1306                },
1307            )
1308            .trace_err();
1309    }
1310    update_user_contacts(session.user_id, &session).await?;
1311    Ok(())
1312}
1313
1314async fn update_participant_location(
1315    request: proto::UpdateParticipantLocation,
1316    response: Response<proto::UpdateParticipantLocation>,
1317    session: Session,
1318) -> Result<()> {
1319    let room_id = RoomId::from_proto(request.room_id);
1320    let location = request
1321        .location
1322        .ok_or_else(|| anyhow!("invalid location"))?;
1323
1324    let db = session.db().await;
1325    let room = db
1326        .update_room_participant_location(room_id, session.connection_id, location)
1327        .await?;
1328
1329    room_updated(&room, &session.peer);
1330    response.send(proto::Ack {})?;
1331    Ok(())
1332}
1333
1334async fn share_project(
1335    request: proto::ShareProject,
1336    response: Response<proto::ShareProject>,
1337    session: Session,
1338) -> Result<()> {
1339    let (project_id, room) = &*session
1340        .db()
1341        .await
1342        .share_project(
1343            RoomId::from_proto(request.room_id),
1344            session.connection_id,
1345            &request.worktrees,
1346        )
1347        .await?;
1348    response.send(proto::ShareProjectResponse {
1349        project_id: project_id.to_proto(),
1350    })?;
1351    room_updated(&room, &session.peer);
1352
1353    Ok(())
1354}
1355
1356async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1357    let project_id = ProjectId::from_proto(message.project_id);
1358
1359    let (room, guest_connection_ids) = &*session
1360        .db()
1361        .await
1362        .unshare_project(project_id, session.connection_id)
1363        .await?;
1364
1365    broadcast(
1366        Some(session.connection_id),
1367        guest_connection_ids.iter().copied(),
1368        |conn_id| session.peer.send(conn_id, message.clone()),
1369    );
1370    room_updated(&room, &session.peer);
1371
1372    Ok(())
1373}
1374
1375async fn join_project(
1376    request: proto::JoinProject,
1377    response: Response<proto::JoinProject>,
1378    session: Session,
1379) -> Result<()> {
1380    let project_id = ProjectId::from_proto(request.project_id);
1381    let guest_user_id = session.user_id;
1382
1383    tracing::info!(%project_id, "join project");
1384
1385    let (project, replica_id) = &mut *session
1386        .db()
1387        .await
1388        .join_project(project_id, session.connection_id)
1389        .await?;
1390
1391    let collaborators = project
1392        .collaborators
1393        .iter()
1394        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1395        .map(|collaborator| collaborator.to_proto())
1396        .collect::<Vec<_>>();
1397
1398    let worktrees = project
1399        .worktrees
1400        .iter()
1401        .map(|(id, worktree)| proto::WorktreeMetadata {
1402            id: *id,
1403            root_name: worktree.root_name.clone(),
1404            visible: worktree.visible,
1405            abs_path: worktree.abs_path.clone(),
1406        })
1407        .collect::<Vec<_>>();
1408
1409    for collaborator in &collaborators {
1410        session
1411            .peer
1412            .send(
1413                collaborator.peer_id.unwrap().into(),
1414                proto::AddProjectCollaborator {
1415                    project_id: project_id.to_proto(),
1416                    collaborator: Some(proto::Collaborator {
1417                        peer_id: Some(session.connection_id.into()),
1418                        replica_id: replica_id.0 as u32,
1419                        user_id: guest_user_id.to_proto(),
1420                    }),
1421                },
1422            )
1423            .trace_err();
1424    }
1425
1426    // First, we send the metadata associated with each worktree.
1427    response.send(proto::JoinProjectResponse {
1428        worktrees: worktrees.clone(),
1429        replica_id: replica_id.0 as u32,
1430        collaborators: collaborators.clone(),
1431        language_servers: project.language_servers.clone(),
1432    })?;
1433
1434    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1435        #[cfg(any(test, feature = "test-support"))]
1436        const MAX_CHUNK_SIZE: usize = 2;
1437        #[cfg(not(any(test, feature = "test-support")))]
1438        const MAX_CHUNK_SIZE: usize = 256;
1439
1440        // Stream this worktree's entries.
1441        let message = proto::UpdateWorktree {
1442            project_id: project_id.to_proto(),
1443            worktree_id,
1444            abs_path: worktree.abs_path.clone(),
1445            root_name: worktree.root_name,
1446            updated_entries: worktree.entries,
1447            removed_entries: Default::default(),
1448            scan_id: worktree.scan_id,
1449            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1450            updated_repositories: worktree.repository_entries.into_values().collect(),
1451            removed_repositories: Default::default(),
1452        };
1453        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1454            session.peer.send(session.connection_id, update.clone())?;
1455        }
1456
1457        // Stream this worktree's diagnostics.
1458        for summary in worktree.diagnostic_summaries {
1459            session.peer.send(
1460                session.connection_id,
1461                proto::UpdateDiagnosticSummary {
1462                    project_id: project_id.to_proto(),
1463                    worktree_id: worktree.id,
1464                    summary: Some(summary),
1465                },
1466            )?;
1467        }
1468
1469        for settings_file in worktree.settings_files {
1470            session.peer.send(
1471                session.connection_id,
1472                proto::UpdateWorktreeSettings {
1473                    project_id: project_id.to_proto(),
1474                    worktree_id: worktree.id,
1475                    path: settings_file.path,
1476                    content: Some(settings_file.content),
1477                },
1478            )?;
1479        }
1480    }
1481
1482    for language_server in &project.language_servers {
1483        session.peer.send(
1484            session.connection_id,
1485            proto::UpdateLanguageServer {
1486                project_id: project_id.to_proto(),
1487                language_server_id: language_server.id,
1488                variant: Some(
1489                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1490                        proto::LspDiskBasedDiagnosticsUpdated {},
1491                    ),
1492                ),
1493            },
1494        )?;
1495    }
1496
1497    Ok(())
1498}
1499
1500async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1501    let sender_id = session.connection_id;
1502    let project_id = ProjectId::from_proto(request.project_id);
1503
1504    let (room, project) = &*session
1505        .db()
1506        .await
1507        .leave_project(project_id, sender_id)
1508        .await?;
1509    tracing::info!(
1510        %project_id,
1511        host_user_id = %project.host_user_id,
1512        host_connection_id = %project.host_connection_id,
1513        "leave project"
1514    );
1515
1516    project_left(&project, &session);
1517    room_updated(&room, &session.peer);
1518
1519    Ok(())
1520}
1521
1522async fn update_project(
1523    request: proto::UpdateProject,
1524    response: Response<proto::UpdateProject>,
1525    session: Session,
1526) -> Result<()> {
1527    let project_id = ProjectId::from_proto(request.project_id);
1528    let (room, guest_connection_ids) = &*session
1529        .db()
1530        .await
1531        .update_project(project_id, session.connection_id, &request.worktrees)
1532        .await?;
1533    broadcast(
1534        Some(session.connection_id),
1535        guest_connection_ids.iter().copied(),
1536        |connection_id| {
1537            session
1538                .peer
1539                .forward_send(session.connection_id, connection_id, request.clone())
1540        },
1541    );
1542    room_updated(&room, &session.peer);
1543    response.send(proto::Ack {})?;
1544
1545    Ok(())
1546}
1547
1548async fn update_worktree(
1549    request: proto::UpdateWorktree,
1550    response: Response<proto::UpdateWorktree>,
1551    session: Session,
1552) -> Result<()> {
1553    let guest_connection_ids = session
1554        .db()
1555        .await
1556        .update_worktree(&request, session.connection_id)
1557        .await?;
1558
1559    broadcast(
1560        Some(session.connection_id),
1561        guest_connection_ids.iter().copied(),
1562        |connection_id| {
1563            session
1564                .peer
1565                .forward_send(session.connection_id, connection_id, request.clone())
1566        },
1567    );
1568    response.send(proto::Ack {})?;
1569    Ok(())
1570}
1571
1572async fn update_diagnostic_summary(
1573    message: proto::UpdateDiagnosticSummary,
1574    session: Session,
1575) -> Result<()> {
1576    let guest_connection_ids = session
1577        .db()
1578        .await
1579        .update_diagnostic_summary(&message, session.connection_id)
1580        .await?;
1581
1582    broadcast(
1583        Some(session.connection_id),
1584        guest_connection_ids.iter().copied(),
1585        |connection_id| {
1586            session
1587                .peer
1588                .forward_send(session.connection_id, connection_id, message.clone())
1589        },
1590    );
1591
1592    Ok(())
1593}
1594
1595async fn update_worktree_settings(
1596    message: proto::UpdateWorktreeSettings,
1597    session: Session,
1598) -> Result<()> {
1599    let guest_connection_ids = session
1600        .db()
1601        .await
1602        .update_worktree_settings(&message, session.connection_id)
1603        .await?;
1604
1605    broadcast(
1606        Some(session.connection_id),
1607        guest_connection_ids.iter().copied(),
1608        |connection_id| {
1609            session
1610                .peer
1611                .forward_send(session.connection_id, connection_id, message.clone())
1612        },
1613    );
1614
1615    Ok(())
1616}
1617
1618async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1619    broadcast_project_message(request.project_id, request, session).await
1620}
1621
1622async fn start_language_server(
1623    request: proto::StartLanguageServer,
1624    session: Session,
1625) -> Result<()> {
1626    let guest_connection_ids = session
1627        .db()
1628        .await
1629        .start_language_server(&request, session.connection_id)
1630        .await?;
1631
1632    broadcast(
1633        Some(session.connection_id),
1634        guest_connection_ids.iter().copied(),
1635        |connection_id| {
1636            session
1637                .peer
1638                .forward_send(session.connection_id, connection_id, request.clone())
1639        },
1640    );
1641    Ok(())
1642}
1643
1644async fn update_language_server(
1645    request: proto::UpdateLanguageServer,
1646    session: Session,
1647) -> Result<()> {
1648    session.executor.record_backtrace();
1649    let project_id = ProjectId::from_proto(request.project_id);
1650    let project_connection_ids = session
1651        .db()
1652        .await
1653        .project_connection_ids(project_id, session.connection_id)
1654        .await?;
1655    broadcast(
1656        Some(session.connection_id),
1657        project_connection_ids.iter().copied(),
1658        |connection_id| {
1659            session
1660                .peer
1661                .forward_send(session.connection_id, connection_id, request.clone())
1662        },
1663    );
1664    Ok(())
1665}
1666
1667async fn forward_project_request<T>(
1668    request: T,
1669    response: Response<T>,
1670    session: Session,
1671) -> Result<()>
1672where
1673    T: EntityMessage + RequestMessage,
1674{
1675    session.executor.record_backtrace();
1676    let project_id = ProjectId::from_proto(request.remote_entity_id());
1677    let host_connection_id = {
1678        let collaborators = session
1679            .db()
1680            .await
1681            .project_collaborators(project_id, session.connection_id)
1682            .await?;
1683        collaborators
1684            .iter()
1685            .find(|collaborator| collaborator.is_host)
1686            .ok_or_else(|| anyhow!("host not found"))?
1687            .connection_id
1688    };
1689
1690    let payload = session
1691        .peer
1692        .forward_request(session.connection_id, host_connection_id, request)
1693        .await?;
1694
1695    response.send(payload)?;
1696    Ok(())
1697}
1698
1699async fn create_buffer_for_peer(
1700    request: proto::CreateBufferForPeer,
1701    session: Session,
1702) -> Result<()> {
1703    session.executor.record_backtrace();
1704    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1705    session
1706        .peer
1707        .forward_send(session.connection_id, peer_id.into(), request)?;
1708    Ok(())
1709}
1710
1711async fn update_buffer(
1712    request: proto::UpdateBuffer,
1713    response: Response<proto::UpdateBuffer>,
1714    session: Session,
1715) -> Result<()> {
1716    session.executor.record_backtrace();
1717    let project_id = ProjectId::from_proto(request.project_id);
1718    let mut guest_connection_ids;
1719    let mut host_connection_id = None;
1720    {
1721        let collaborators = session
1722            .db()
1723            .await
1724            .project_collaborators(project_id, session.connection_id)
1725            .await?;
1726        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1727        for collaborator in collaborators.iter() {
1728            if collaborator.is_host {
1729                host_connection_id = Some(collaborator.connection_id);
1730            } else {
1731                guest_connection_ids.push(collaborator.connection_id);
1732            }
1733        }
1734    }
1735    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1736
1737    session.executor.record_backtrace();
1738    broadcast(
1739        Some(session.connection_id),
1740        guest_connection_ids,
1741        |connection_id| {
1742            session
1743                .peer
1744                .forward_send(session.connection_id, connection_id, request.clone())
1745        },
1746    );
1747    if host_connection_id != session.connection_id {
1748        session
1749            .peer
1750            .forward_request(session.connection_id, host_connection_id, request.clone())
1751            .await?;
1752    }
1753
1754    response.send(proto::Ack {})?;
1755    Ok(())
1756}
1757
1758async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1759    let project_id = ProjectId::from_proto(request.project_id);
1760    let project_connection_ids = session
1761        .db()
1762        .await
1763        .project_connection_ids(project_id, session.connection_id)
1764        .await?;
1765
1766    broadcast(
1767        Some(session.connection_id),
1768        project_connection_ids.iter().copied(),
1769        |connection_id| {
1770            session
1771                .peer
1772                .forward_send(session.connection_id, connection_id, request.clone())
1773        },
1774    );
1775    Ok(())
1776}
1777
1778async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1779    let project_id = ProjectId::from_proto(request.project_id);
1780    let project_connection_ids = session
1781        .db()
1782        .await
1783        .project_connection_ids(project_id, session.connection_id)
1784        .await?;
1785    broadcast(
1786        Some(session.connection_id),
1787        project_connection_ids.iter().copied(),
1788        |connection_id| {
1789            session
1790                .peer
1791                .forward_send(session.connection_id, connection_id, request.clone())
1792        },
1793    );
1794    Ok(())
1795}
1796
1797async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1798    broadcast_project_message(request.project_id, request, session).await
1799}
1800
1801async fn broadcast_project_message<T: EnvelopedMessage>(
1802    project_id: u64,
1803    request: T,
1804    session: Session,
1805) -> Result<()> {
1806    let project_id = ProjectId::from_proto(project_id);
1807    let project_connection_ids = session
1808        .db()
1809        .await
1810        .project_connection_ids(project_id, session.connection_id)
1811        .await?;
1812    broadcast(
1813        Some(session.connection_id),
1814        project_connection_ids.iter().copied(),
1815        |connection_id| {
1816            session
1817                .peer
1818                .forward_send(session.connection_id, connection_id, request.clone())
1819        },
1820    );
1821    Ok(())
1822}
1823
1824async fn follow(
1825    request: proto::Follow,
1826    response: Response<proto::Follow>,
1827    session: Session,
1828) -> Result<()> {
1829    let project_id = ProjectId::from_proto(request.project_id);
1830    let leader_id = request
1831        .leader_id
1832        .ok_or_else(|| anyhow!("invalid leader id"))?
1833        .into();
1834    let follower_id = session.connection_id;
1835
1836    {
1837        let project_connection_ids = session
1838            .db()
1839            .await
1840            .project_connection_ids(project_id, session.connection_id)
1841            .await?;
1842
1843        if !project_connection_ids.contains(&leader_id) {
1844            Err(anyhow!("no such peer"))?;
1845        }
1846    }
1847
1848    let mut response_payload = session
1849        .peer
1850        .forward_request(session.connection_id, leader_id, request)
1851        .await?;
1852    response_payload
1853        .views
1854        .retain(|view| view.leader_id != Some(follower_id.into()));
1855    response.send(response_payload)?;
1856
1857    let room = session
1858        .db()
1859        .await
1860        .follow(project_id, leader_id, follower_id)
1861        .await?;
1862    room_updated(&room, &session.peer);
1863
1864    Ok(())
1865}
1866
1867async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1868    let project_id = ProjectId::from_proto(request.project_id);
1869    let leader_id = request
1870        .leader_id
1871        .ok_or_else(|| anyhow!("invalid leader id"))?
1872        .into();
1873    let follower_id = session.connection_id;
1874
1875    if !session
1876        .db()
1877        .await
1878        .project_connection_ids(project_id, session.connection_id)
1879        .await?
1880        .contains(&leader_id)
1881    {
1882        Err(anyhow!("no such peer"))?;
1883    }
1884
1885    session
1886        .peer
1887        .forward_send(session.connection_id, leader_id, request)?;
1888
1889    let room = session
1890        .db()
1891        .await
1892        .unfollow(project_id, leader_id, follower_id)
1893        .await?;
1894    room_updated(&room, &session.peer);
1895
1896    Ok(())
1897}
1898
1899async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1900    let project_id = ProjectId::from_proto(request.project_id);
1901    let project_connection_ids = session
1902        .db
1903        .lock()
1904        .await
1905        .project_connection_ids(project_id, session.connection_id)
1906        .await?;
1907
1908    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1909        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1910        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1911        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1912    });
1913    for follower_peer_id in request.follower_ids.iter().copied() {
1914        let follower_connection_id = follower_peer_id.into();
1915        if project_connection_ids.contains(&follower_connection_id)
1916            && Some(follower_peer_id) != leader_id
1917        {
1918            session.peer.forward_send(
1919                session.connection_id,
1920                follower_connection_id,
1921                request.clone(),
1922            )?;
1923        }
1924    }
1925    Ok(())
1926}
1927
1928async fn get_users(
1929    request: proto::GetUsers,
1930    response: Response<proto::GetUsers>,
1931    session: Session,
1932) -> Result<()> {
1933    let user_ids = request
1934        .user_ids
1935        .into_iter()
1936        .map(UserId::from_proto)
1937        .collect();
1938    let users = session
1939        .db()
1940        .await
1941        .get_users_by_ids(user_ids)
1942        .await?
1943        .into_iter()
1944        .map(|user| proto::User {
1945            id: user.id.to_proto(),
1946            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1947            github_login: user.github_login,
1948        })
1949        .collect();
1950    response.send(proto::UsersResponse { users })?;
1951    Ok(())
1952}
1953
1954async fn fuzzy_search_users(
1955    request: proto::FuzzySearchUsers,
1956    response: Response<proto::FuzzySearchUsers>,
1957    session: Session,
1958) -> Result<()> {
1959    let query = request.query;
1960    let users = match query.len() {
1961        0 => vec![],
1962        1 | 2 => session
1963            .db()
1964            .await
1965            .get_user_by_github_login(&query)
1966            .await?
1967            .into_iter()
1968            .collect(),
1969        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1970    };
1971    let users = users
1972        .into_iter()
1973        .filter(|user| user.id != session.user_id)
1974        .map(|user| proto::User {
1975            id: user.id.to_proto(),
1976            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1977            github_login: user.github_login,
1978        })
1979        .collect();
1980    response.send(proto::UsersResponse { users })?;
1981    Ok(())
1982}
1983
1984async fn request_contact(
1985    request: proto::RequestContact,
1986    response: Response<proto::RequestContact>,
1987    session: Session,
1988) -> Result<()> {
1989    let requester_id = session.user_id;
1990    let responder_id = UserId::from_proto(request.responder_id);
1991    if requester_id == responder_id {
1992        return Err(anyhow!("cannot add yourself as a contact"))?;
1993    }
1994
1995    session
1996        .db()
1997        .await
1998        .send_contact_request(requester_id, responder_id)
1999        .await?;
2000
2001    // Update outgoing contact requests of requester
2002    let mut update = proto::UpdateContacts::default();
2003    update.outgoing_requests.push(responder_id.to_proto());
2004    for connection_id in session
2005        .connection_pool()
2006        .await
2007        .user_connection_ids(requester_id)
2008    {
2009        session.peer.send(connection_id, update.clone())?;
2010    }
2011
2012    // Update incoming contact requests of responder
2013    let mut update = proto::UpdateContacts::default();
2014    update
2015        .incoming_requests
2016        .push(proto::IncomingContactRequest {
2017            requester_id: requester_id.to_proto(),
2018            should_notify: true,
2019        });
2020    for connection_id in session
2021        .connection_pool()
2022        .await
2023        .user_connection_ids(responder_id)
2024    {
2025        session.peer.send(connection_id, update.clone())?;
2026    }
2027
2028    response.send(proto::Ack {})?;
2029    Ok(())
2030}
2031
2032async fn respond_to_contact_request(
2033    request: proto::RespondToContactRequest,
2034    response: Response<proto::RespondToContactRequest>,
2035    session: Session,
2036) -> Result<()> {
2037    let responder_id = session.user_id;
2038    let requester_id = UserId::from_proto(request.requester_id);
2039    let db = session.db().await;
2040    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2041        db.dismiss_contact_notification(responder_id, requester_id)
2042            .await?;
2043    } else {
2044        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2045
2046        db.respond_to_contact_request(responder_id, requester_id, accept)
2047            .await?;
2048        let requester_busy = db.is_user_busy(requester_id).await?;
2049        let responder_busy = db.is_user_busy(responder_id).await?;
2050
2051        let pool = session.connection_pool().await;
2052        // Update responder with new contact
2053        let mut update = proto::UpdateContacts::default();
2054        if accept {
2055            update
2056                .contacts
2057                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2058        }
2059        update
2060            .remove_incoming_requests
2061            .push(requester_id.to_proto());
2062        for connection_id in pool.user_connection_ids(responder_id) {
2063            session.peer.send(connection_id, update.clone())?;
2064        }
2065
2066        // Update requester with new contact
2067        let mut update = proto::UpdateContacts::default();
2068        if accept {
2069            update
2070                .contacts
2071                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2072        }
2073        update
2074            .remove_outgoing_requests
2075            .push(responder_id.to_proto());
2076        for connection_id in pool.user_connection_ids(requester_id) {
2077            session.peer.send(connection_id, update.clone())?;
2078        }
2079    }
2080
2081    response.send(proto::Ack {})?;
2082    Ok(())
2083}
2084
2085async fn remove_contact(
2086    request: proto::RemoveContact,
2087    response: Response<proto::RemoveContact>,
2088    session: Session,
2089) -> Result<()> {
2090    let requester_id = session.user_id;
2091    let responder_id = UserId::from_proto(request.user_id);
2092    let db = session.db().await;
2093    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2094
2095    let pool = session.connection_pool().await;
2096    // Update outgoing contact requests of requester
2097    let mut update = proto::UpdateContacts::default();
2098    if contact_accepted {
2099        update.remove_contacts.push(responder_id.to_proto());
2100    } else {
2101        update
2102            .remove_outgoing_requests
2103            .push(responder_id.to_proto());
2104    }
2105    for connection_id in pool.user_connection_ids(requester_id) {
2106        session.peer.send(connection_id, update.clone())?;
2107    }
2108
2109    // Update incoming contact requests of responder
2110    let mut update = proto::UpdateContacts::default();
2111    if contact_accepted {
2112        update.remove_contacts.push(requester_id.to_proto());
2113    } else {
2114        update
2115            .remove_incoming_requests
2116            .push(requester_id.to_proto());
2117    }
2118    for connection_id in pool.user_connection_ids(responder_id) {
2119        session.peer.send(connection_id, update.clone())?;
2120    }
2121
2122    response.send(proto::Ack {})?;
2123    Ok(())
2124}
2125
2126async fn create_channel(
2127    request: proto::CreateChannel,
2128    response: Response<proto::CreateChannel>,
2129    session: Session,
2130) -> Result<()> {
2131    let db = session.db().await;
2132    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2133
2134    if let Some(live_kit) = session.live_kit_client.as_ref() {
2135        live_kit.create_room(live_kit_room.clone()).await?;
2136    }
2137
2138    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2139    let id = db
2140        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2141        .await?;
2142
2143    response.send(proto::CreateChannelResponse {
2144        channel_id: id.to_proto(),
2145    })?;
2146
2147    let mut update = proto::UpdateChannels::default();
2148    update.channels.push(proto::Channel {
2149        id: id.to_proto(),
2150        name: request.name,
2151        parent_id: request.parent_id,
2152    });
2153
2154    if let Some(parent_id) = parent_id {
2155        let member_ids = db.get_channel_members(parent_id).await?;
2156        let connection_pool = session.connection_pool().await;
2157        for member_id in member_ids {
2158            for connection_id in connection_pool.user_connection_ids(member_id) {
2159                session.peer.send(connection_id, update.clone())?;
2160            }
2161        }
2162    } else {
2163        session.peer.send(session.connection_id, update)?;
2164    }
2165
2166    Ok(())
2167}
2168
2169async fn remove_channel(
2170    request: proto::RemoveChannel,
2171    response: Response<proto::RemoveChannel>,
2172    session: Session,
2173) -> Result<()> {
2174    let db = session.db().await;
2175
2176    let channel_id = request.channel_id;
2177    let (removed_channels, member_ids) = db
2178        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2179        .await?;
2180    response.send(proto::Ack {})?;
2181
2182    // Notify members of removed channels
2183    let mut update = proto::UpdateChannels::default();
2184    update
2185        .remove_channels
2186        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2187
2188    let connection_pool = session.connection_pool().await;
2189    for member_id in member_ids {
2190        for connection_id in connection_pool.user_connection_ids(member_id) {
2191            session.peer.send(connection_id, update.clone())?;
2192        }
2193    }
2194
2195    Ok(())
2196}
2197
2198async fn invite_channel_member(
2199    request: proto::InviteChannelMember,
2200    response: Response<proto::InviteChannelMember>,
2201    session: Session,
2202) -> Result<()> {
2203    let db = session.db().await;
2204    let channel_id = ChannelId::from_proto(request.channel_id);
2205    let channel = db
2206        .get_channel(channel_id)
2207        .await?
2208        .ok_or_else(|| anyhow!("channel not found"))?;
2209    let invitee_id = UserId::from_proto(request.user_id);
2210    db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2211        .await?;
2212
2213    let mut update = proto::UpdateChannels::default();
2214    update.channel_invitations.push(proto::Channel {
2215        id: channel.id.to_proto(),
2216        name: channel.name,
2217        parent_id: None,
2218    });
2219    for connection_id in session
2220        .connection_pool()
2221        .await
2222        .user_connection_ids(invitee_id)
2223    {
2224        session.peer.send(connection_id, update.clone())?;
2225    }
2226
2227    response.send(proto::Ack {})?;
2228    Ok(())
2229}
2230
2231async fn remove_channel_member(
2232    _request: proto::RemoveChannelMember,
2233    _response: Response<proto::RemoveChannelMember>,
2234    _session: Session,
2235) -> Result<()> {
2236    Ok(())
2237}
2238
2239async fn respond_to_channel_invite(
2240    request: proto::RespondToChannelInvite,
2241    response: Response<proto::RespondToChannelInvite>,
2242    session: Session,
2243) -> Result<()> {
2244    let db = session.db().await;
2245    let channel_id = ChannelId::from_proto(request.channel_id);
2246    let channel = db
2247        .get_channel(channel_id)
2248        .await?
2249        .ok_or_else(|| anyhow!("no such channel"))?;
2250    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2251        .await?;
2252
2253    let mut update = proto::UpdateChannels::default();
2254    update
2255        .remove_channel_invitations
2256        .push(channel_id.to_proto());
2257    if request.accept {
2258        update.channels.push(proto::Channel {
2259            id: channel.id.to_proto(),
2260            name: channel.name,
2261            parent_id: None,
2262        });
2263    }
2264    session.peer.send(session.connection_id, update)?;
2265    response.send(proto::Ack {})?;
2266
2267    Ok(())
2268}
2269
/// Joins the current user to the room associated with `request.channel_id`.
///
/// While the database guard is held, this:
/// - joins the room,
/// - replies with the room state, plus LiveKit connection info when a LiveKit
///   client is configured and a token could be created,
/// - broadcasts the updated room to its participants.
///
/// After the guard is released, it notifies channel members who are not in the
/// room via `channel_updated`. It then pushes the user's updated contact entry to
/// their accepted contacts via `update_user_contacts`.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The database guard `db` lives only inside this block. `joined_room` is
    // cloned out so it can still be used after the guard is dropped.
    let joined_room = {
        let db = session.db().await;
        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // `None` when there is no LiveKit client, or when `room_token` fails
        // (`trace_err` turns the error into `None`, presumably after logging it).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2329
2330async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2331    let project_id = ProjectId::from_proto(request.project_id);
2332    let project_connection_ids = session
2333        .db()
2334        .await
2335        .project_connection_ids(project_id, session.connection_id)
2336        .await?;
2337    broadcast(
2338        Some(session.connection_id),
2339        project_connection_ids.iter().copied(),
2340        |connection_id| {
2341            session
2342                .peer
2343                .forward_send(session.connection_id, connection_id, request.clone())
2344        },
2345    );
2346    Ok(())
2347}
2348
2349async fn get_private_user_info(
2350    _request: proto::GetPrivateUserInfo,
2351    response: Response<proto::GetPrivateUserInfo>,
2352    session: Session,
2353) -> Result<()> {
2354    let metrics_id = session
2355        .db()
2356        .await
2357        .get_user_metrics_id(session.user_id)
2358        .await?;
2359    let user = session
2360        .db()
2361        .await
2362        .get_user_by_id(session.user_id)
2363        .await?
2364        .ok_or_else(|| anyhow!("user not found"))?;
2365    response.send(proto::GetPrivateUserInfoResponse {
2366        metrics_id,
2367        staff: user.admin,
2368    })?;
2369    Ok(())
2370}
2371
2372fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2373    match message {
2374        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2375        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2376        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2377        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2378        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2379            code: frame.code.into(),
2380            reason: frame.reason,
2381        })),
2382    }
2383}
2384
2385fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2386    match message {
2387        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2388        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2389        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2390        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2391        AxumMessage::Close(frame) => {
2392            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2393                code: frame.code.into(),
2394                reason: frame.reason,
2395            }))
2396        }
2397    }
2398}
2399
2400fn build_initial_channels_update(
2401    channels: Vec<db::Channel>,
2402    channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2403    channel_invites: Vec<db::Channel>,
2404) -> proto::UpdateChannels {
2405    let mut update = proto::UpdateChannels::default();
2406
2407    for channel in channels {
2408        update.channels.push(proto::Channel {
2409            id: channel.id.to_proto(),
2410            name: channel.name,
2411            parent_id: channel.parent_id.map(|id| id.to_proto()),
2412        });
2413    }
2414
2415    for channel in channel_invites {
2416        update.channel_invitations.push(proto::Channel {
2417            id: channel.id.to_proto(),
2418            name: channel.name,
2419            parent_id: None,
2420        });
2421    }
2422
2423    update
2424}
2425
2426fn build_initial_contacts_update(
2427    contacts: Vec<db::Contact>,
2428    pool: &ConnectionPool,
2429) -> proto::UpdateContacts {
2430    let mut update = proto::UpdateContacts::default();
2431
2432    for contact in contacts {
2433        match contact {
2434            db::Contact::Accepted {
2435                user_id,
2436                should_notify,
2437                busy,
2438            } => {
2439                update
2440                    .contacts
2441                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2442            }
2443            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2444            db::Contact::Incoming {
2445                user_id,
2446                should_notify,
2447            } => update
2448                .incoming_requests
2449                .push(proto::IncomingContactRequest {
2450                    requester_id: user_id.to_proto(),
2451                    should_notify,
2452                }),
2453        }
2454    }
2455
2456    update
2457}
2458
2459fn contact_for_user(
2460    user_id: UserId,
2461    should_notify: bool,
2462    busy: bool,
2463    pool: &ConnectionPool,
2464) -> proto::Contact {
2465    proto::Contact {
2466        user_id: user_id.to_proto(),
2467        online: pool.is_user_online(user_id),
2468        busy,
2469        should_notify,
2470    }
2471}
2472
2473fn room_updated(room: &proto::Room, peer: &Peer) {
2474    broadcast(
2475        None,
2476        room.participants
2477            .iter()
2478            .filter_map(|participant| Some(participant.peer_id?.into())),
2479        |peer_id| {
2480            peer.send(
2481                peer_id.into(),
2482                proto::RoomUpdated {
2483                    room: Some(room.clone()),
2484                },
2485            )
2486        },
2487    );
2488}
2489
2490fn channel_updated(
2491    channel_id: ChannelId,
2492    room: &proto::Room,
2493    channel_members: &[UserId],
2494    peer: &Peer,
2495    pool: &ConnectionPool,
2496) {
2497    let participants = room
2498        .participants
2499        .iter()
2500        .map(|p| p.user_id)
2501        .collect::<Vec<_>>();
2502
2503    broadcast(
2504        None,
2505        channel_members
2506            .iter()
2507            .filter(|user_id| {
2508                !room
2509                    .participants
2510                    .iter()
2511                    .any(|p| p.user_id == user_id.to_proto())
2512            })
2513            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2514        |peer_id| {
2515            peer.send(
2516                peer_id.into(),
2517                proto::UpdateChannels {
2518                    channel_participants: vec![proto::ChannelParticipants {
2519                        channel_id: channel_id.to_proto(),
2520                        participant_user_ids: participants.clone(),
2521                    }],
2522                    ..Default::default()
2523                },
2524            )
2525        },
2526    );
2527}
2528
2529async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2530    let db = session.db().await;
2531    let contacts = db.get_contacts(user_id).await?;
2532    let busy = db.is_user_busy(user_id).await?;
2533
2534    let pool = session.connection_pool().await;
2535    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2536    for contact in contacts {
2537        if let db::Contact::Accepted {
2538            user_id: contact_user_id,
2539            ..
2540        } = contact
2541        {
2542            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2543                session
2544                    .peer
2545                    .send(
2546                        contact_conn_id,
2547                        proto::UpdateContacts {
2548                            contacts: vec![updated_contact.clone()],
2549                            remove_contacts: Default::default(),
2550                            incoming_requests: Default::default(),
2551                            remove_incoming_requests: Default::default(),
2552                            outgoing_requests: Default::default(),
2553                            remove_outgoing_requests: Default::default(),
2554                        },
2555                    )
2556                    .trace_err();
2557            }
2558        }
2559    }
2560    Ok(())
2561}
2562
2563async fn leave_room_for_session(session: &Session) -> Result<()> {
2564    let mut contacts_to_update = HashSet::default();
2565
2566    let room_id;
2567    let canceled_calls_to_user_ids;
2568    let live_kit_room;
2569    let delete_live_kit_room;
2570    let room;
2571    let channel_members;
2572    let channel_id;
2573
2574    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2575        contacts_to_update.insert(session.user_id);
2576
2577        for project in left_room.left_projects.values() {
2578            project_left(project, session);
2579        }
2580
2581        room_id = RoomId::from_proto(left_room.room.id);
2582        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2583        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2584        delete_live_kit_room = left_room.deleted;
2585        room = mem::take(&mut left_room.room);
2586        channel_members = mem::take(&mut left_room.channel_members);
2587        channel_id = left_room.channel_id;
2588
2589        room_updated(&room, &session.peer);
2590    } else {
2591        return Ok(());
2592    }
2593
2594    // TODO - do this while holding the room guard.
2595    if let Some(channel_id) = channel_id {
2596        channel_updated(
2597            channel_id,
2598            &room,
2599            &channel_members,
2600            &session.peer,
2601            &*session.connection_pool().await,
2602        );
2603    }
2604
2605    {
2606        let pool = session.connection_pool().await;
2607        for canceled_user_id in canceled_calls_to_user_ids {
2608            for connection_id in pool.user_connection_ids(canceled_user_id) {
2609                session
2610                    .peer
2611                    .send(
2612                        connection_id,
2613                        proto::CallCanceled {
2614                            room_id: room_id.to_proto(),
2615                        },
2616                    )
2617                    .trace_err();
2618            }
2619            contacts_to_update.insert(canceled_user_id);
2620        }
2621    }
2622
2623    for contact_user_id in contacts_to_update {
2624        update_user_contacts(contact_user_id, &session).await?;
2625    }
2626
2627    if let Some(live_kit) = session.live_kit_client.as_ref() {
2628        live_kit
2629            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2630            .await
2631            .trace_err();
2632
2633        if delete_live_kit_room {
2634            live_kit.delete_room(live_kit_room).await.trace_err();
2635        }
2636    }
2637
2638    Ok(())
2639}
2640
2641fn project_left(project: &db::LeftProject, session: &Session) {
2642    for connection_id in &project.connection_ids {
2643        if project.host_user_id == session.user_id {
2644            session
2645                .peer
2646                .send(
2647                    *connection_id,
2648                    proto::UnshareProject {
2649                        project_id: project.id.to_proto(),
2650                    },
2651                )
2652                .trace_err();
2653        } else {
2654            session
2655                .peer
2656                .send(
2657                    *connection_id,
2658                    proto::RemoveProjectCollaborator {
2659                        project_id: project.id.to_proto(),
2660                        peer_id: Some(session.connection_id.into()),
2661                    },
2662                )
2663                .trace_err();
2664        }
2665    }
2666}
2667
2668pub trait ResultExt {
2669    type Ok;
2670
2671    fn trace_err(self) -> Option<Self::Ok>;
2672}
2673
2674impl<T, E> ResultExt for Result<T, E>
2675where
2676    E: std::fmt::Debug,
2677{
2678    type Ok = T;
2679
2680    fn trace_err(self) -> Option<T> {
2681        match self {
2682            Ok(value) => Some(value),
2683            Err(error) => {
2684                tracing::error!("{:?}", error);
2685                None
2686            }
2687        }
2688    }
2689}