// rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
 144pub struct Server {
 145    id: parking_lot::Mutex<ServerId>,
 146    peer: Arc<Peer>,
 147    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 148    app_state: Arc<AppState>,
 149    executor: Executor,
 150    handlers: HashMap<TypeId, MessageHandler>,
 151    teardown: watch::Sender<()>,
 152}
 153
 154pub(crate) struct ConnectionPoolGuard<'a> {
 155    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 156    _not_send: PhantomData<Rc<()>>,
 157}
 158
 159#[derive(Serialize)]
 160pub struct ServerSnapshot<'a> {
 161    peer: &'a Peer,
 162    #[serde(serialize_with = "serialize_deref")]
 163    connection_pool: ConnectionPoolGuard<'a>,
 164}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(respond_to_channel_invite)
 250            .add_request_handler(join_channel)
 251            .add_request_handler(follow)
 252            .add_message_handler(unfollow)
 253            .add_message_handler(update_followers)
 254            .add_message_handler(update_diff_base)
 255            .add_request_handler(get_private_user_info);
 256
 257        Arc::new(server)
 258    }
 259
 260    pub async fn start(&self) -> Result<()> {
 261        let server_id = *self.id.lock();
 262        let app_state = self.app_state.clone();
 263        let peer = self.peer.clone();
 264        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 265        let pool = self.connection_pool.clone();
 266        let live_kit_client = self.app_state.live_kit_client.clone();
 267
 268        let span = info_span!("start server");
 269        self.executor.spawn_detached(
 270            async move {
 271                tracing::info!("waiting for cleanup timeout");
 272                timeout.await;
 273                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 274                if let Some(room_ids) = app_state
 275                    .db
 276                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 277                    .await
 278                    .trace_err()
 279                {
 280                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 281                    for room_id in room_ids {
 282                        let mut contacts_to_update = HashSet::default();
 283                        let mut canceled_calls_to_user_ids = Vec::new();
 284                        let mut live_kit_room = String::new();
 285                        let mut delete_live_kit_room = false;
 286
 287                        if let Some(mut refreshed_room) = app_state
 288                            .db
 289                            .refresh_room(room_id, server_id)
 290                            .await
 291                            .trace_err()
 292                        {
 293                            tracing::info!(
 294                                room_id = room_id.0,
 295                                new_participant_count = refreshed_room.room.participants.len(),
 296                                "refreshed room"
 297                            );
 298                            room_updated(&refreshed_room.room, &peer);
 299                            contacts_to_update
 300                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 301                            contacts_to_update
 302                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 303                            canceled_calls_to_user_ids =
 304                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 305                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 306                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 307                        }
 308
 309                        {
 310                            let pool = pool.lock();
 311                            for canceled_user_id in canceled_calls_to_user_ids {
 312                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 313                                    peer.send(
 314                                        connection_id,
 315                                        proto::CallCanceled {
 316                                            room_id: room_id.to_proto(),
 317                                        },
 318                                    )
 319                                    .trace_err();
 320                                }
 321                            }
 322                        }
 323
 324                        for user_id in contacts_to_update {
 325                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 326                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 327                            if let Some((busy, contacts)) = busy.zip(contacts) {
 328                                let pool = pool.lock();
 329                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 330                                for contact in contacts {
 331                                    if let db::Contact::Accepted {
 332                                        user_id: contact_user_id,
 333                                        ..
 334                                    } = contact
 335                                    {
 336                                        for contact_conn_id in
 337                                            pool.user_connection_ids(contact_user_id)
 338                                        {
 339                                            peer.send(
 340                                                contact_conn_id,
 341                                                proto::UpdateContacts {
 342                                                    contacts: vec![updated_contact.clone()],
 343                                                    remove_contacts: Default::default(),
 344                                                    incoming_requests: Default::default(),
 345                                                    remove_incoming_requests: Default::default(),
 346                                                    outgoing_requests: Default::default(),
 347                                                    remove_outgoing_requests: Default::default(),
 348                                                },
 349                                            )
 350                                            .trace_err();
 351                                        }
 352                                    }
 353                                }
 354                            }
 355                        }
 356
 357                        if let Some(live_kit) = live_kit_client.as_ref() {
 358                            if delete_live_kit_room {
 359                                live_kit.delete_room(live_kit_room).await.trace_err();
 360                            }
 361                        }
 362                    }
 363                }
 364
 365                app_state
 366                    .db
 367                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 368                    .await
 369                    .trace_err();
 370            }
 371            .instrument(span),
 372        );
 373        Ok(())
 374    }
 375
 376    pub fn teardown(&self) {
 377        self.peer.teardown();
 378        self.connection_pool.lock().reset();
 379        let _ = self.teardown.send(());
 380    }
 381
 382    #[cfg(test)]
 383    pub fn reset(&self, id: ServerId) {
 384        self.teardown();
 385        *self.id.lock() = id;
 386        self.peer.reset(id.0 as u32);
 387    }
 388
 389    #[cfg(test)]
 390    pub fn id(&self) -> ServerId {
 391        *self.id.lock()
 392    }
 393
 394    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 395    where
 396        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 397        Fut: 'static + Send + Future<Output = Result<()>>,
 398        M: EnvelopedMessage,
 399    {
 400        let prev_handler = self.handlers.insert(
 401            TypeId::of::<M>(),
 402            Box::new(move |envelope, session| {
 403                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 404                let span = info_span!(
 405                    "handle message",
 406                    payload_type = envelope.payload_type_name()
 407                );
 408                span.in_scope(|| {
 409                    tracing::info!(
 410                        payload_type = envelope.payload_type_name(),
 411                        "message received"
 412                    );
 413                });
 414                let start_time = Instant::now();
 415                let future = (handler)(*envelope, session);
 416                async move {
 417                    let result = future.await;
 418                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 419                    match result {
 420                        Err(error) => {
 421                            tracing::error!(%error, ?duration_ms, "error handling message")
 422                        }
 423                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 424                    }
 425                }
 426                .instrument(span)
 427                .boxed()
 428            }),
 429        );
 430        if prev_handler.is_some() {
 431            panic!("registered a handler for the same message twice");
 432        }
 433        self
 434    }
 435
 436    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 437    where
 438        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 439        Fut: 'static + Send + Future<Output = Result<()>>,
 440        M: EnvelopedMessage,
 441    {
 442        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 443        self
 444    }
 445
 446    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 447    where
 448        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 449        Fut: Send + Future<Output = Result<()>>,
 450        M: RequestMessage,
 451    {
 452        let handler = Arc::new(handler);
 453        self.add_handler(move |envelope, session| {
 454            let receipt = envelope.receipt();
 455            let handler = handler.clone();
 456            async move {
 457                let peer = session.peer.clone();
 458                let responded = Arc::new(AtomicBool::default());
 459                let response = Response {
 460                    peer: peer.clone(),
 461                    responded: responded.clone(),
 462                    receipt,
 463                };
 464                match (handler)(envelope.payload, response, session).await {
 465                    Ok(()) => {
 466                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 467                            Ok(())
 468                        } else {
 469                            Err(anyhow!("handler did not send a response"))?
 470                        }
 471                    }
 472                    Err(error) => {
 473                        peer.respond_with_error(
 474                            receipt,
 475                            proto::Error {
 476                                message: error.to_string(),
 477                            },
 478                        )?;
 479                        Err(error)
 480                    }
 481                }
 482            }
 483        })
 484    }
 485
 486    pub fn handle_connection(
 487        self: &Arc<Self>,
 488        connection: Connection,
 489        address: String,
 490        user: User,
 491        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 492        executor: Executor,
 493    ) -> impl Future<Output = Result<()>> {
 494        let this = self.clone();
 495        let user_id = user.id;
 496        let login = user.github_login;
 497        let span = info_span!("handle connection", %user_id, %login, %address);
 498        let mut teardown = self.teardown.subscribe();
 499        async move {
 500            let (connection_id, handle_io, mut incoming_rx) = this
 501                .peer
 502                .add_connection(connection, {
 503                    let executor = executor.clone();
 504                    move |duration| executor.sleep(duration)
 505                });
 506
 507            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 508            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 509            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 510
 511            if let Some(send_connection_id) = send_connection_id.take() {
 512                let _ = send_connection_id.send(connection_id);
 513            }
 514
 515            if !user.connected_once {
 516                this.peer.send(connection_id, proto::ShowContacts {})?;
 517                this.app_state.db.set_user_connected_once(user_id, true).await?;
 518            }
 519
 520            let (contacts, invite_code, channels, channel_invites) = future::try_join4(
 521                this.app_state.db.get_contacts(user_id),
 522                this.app_state.db.get_invite_code_for_user(user_id),
 523                this.app_state.db.get_channels(user_id),
 524                this.app_state.db.get_channel_invites(user_id)
 525            ).await?;
 526
 527            {
 528                let mut pool = this.connection_pool.lock();
 529                pool.add_connection(connection_id, user_id, user.admin);
 530                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 531                this.peer.send(connection_id, build_initial_channels_update(channels, channel_invites))?;
 532
 533                if let Some((code, count)) = invite_code {
 534                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 535                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 536                        count: count as u32,
 537                    })?;
 538                }
 539            }
 540
 541            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 542                this.peer.send(connection_id, incoming_call)?;
 543            }
 544
 545            let session = Session {
 546                user_id,
 547                connection_id,
 548                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 549                peer: this.peer.clone(),
 550                connection_pool: this.connection_pool.clone(),
 551                live_kit_client: this.app_state.live_kit_client.clone(),
 552                executor: executor.clone(),
 553            };
 554            update_user_contacts(user_id, &session).await?;
 555
 556            let handle_io = handle_io.fuse();
 557            futures::pin_mut!(handle_io);
 558
 559            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 560            // This prevents deadlocks when e.g., client A performs a request to client B and
 561            // client B performs a request to client A. If both clients stop processing further
 562            // messages until their respective request completes, they won't have a chance to
 563            // respond to the other client's request and cause a deadlock.
 564            //
 565            // This arrangement ensures we will attempt to process earlier messages first, but fall
 566            // back to processing messages arrived later in the spirit of making progress.
 567            let mut foreground_message_handlers = FuturesUnordered::new();
 568            let concurrent_handlers = Arc::new(Semaphore::new(256));
 569            loop {
 570                let next_message = async {
 571                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 572                    let message = incoming_rx.next().await;
 573                    (permit, message)
 574                }.fuse();
 575                futures::pin_mut!(next_message);
 576                futures::select_biased! {
 577                    _ = teardown.changed().fuse() => return Ok(()),
 578                    result = handle_io => {
 579                        if let Err(error) = result {
 580                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 581                        }
 582                        break;
 583                    }
 584                    _ = foreground_message_handlers.next() => {}
 585                    next_message = next_message => {
 586                        let (permit, message) = next_message;
 587                        if let Some(message) = message {
 588                            let type_name = message.payload_type_name();
 589                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 590                            let span_enter = span.enter();
 591                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 592                                let is_background = message.is_background();
 593                                let handle_message = (handler)(message, session.clone());
 594                                drop(span_enter);
 595
 596                                let handle_message = async move {
 597                                    handle_message.await;
 598                                    drop(permit);
 599                                }.instrument(span);
 600                                if is_background {
 601                                    executor.spawn_detached(handle_message);
 602                                } else {
 603                                    foreground_message_handlers.push(handle_message);
 604                                }
 605                            } else {
 606                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 607                            }
 608                        } else {
 609                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 610                            break;
 611                        }
 612                    }
 613                }
 614            }
 615
 616            drop(foreground_message_handlers);
 617            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 618            if let Err(error) = connection_lost(session, teardown, executor).await {
 619                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 620            }
 621
 622            Ok(())
 623        }.instrument(span)
 624    }
 625
 626    pub async fn invite_code_redeemed(
 627        self: &Arc<Self>,
 628        inviter_id: UserId,
 629        invitee_id: UserId,
 630    ) -> Result<()> {
 631        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 632            if let Some(code) = &user.invite_code {
 633                let pool = self.connection_pool.lock();
 634                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 635                for connection_id in pool.user_connection_ids(inviter_id) {
 636                    self.peer.send(
 637                        connection_id,
 638                        proto::UpdateContacts {
 639                            contacts: vec![invitee_contact.clone()],
 640                            ..Default::default()
 641                        },
 642                    )?;
 643                    self.peer.send(
 644                        connection_id,
 645                        proto::UpdateInviteInfo {
 646                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 647                            count: user.invite_count as u32,
 648                        },
 649                    )?;
 650                }
 651            }
 652        }
 653        Ok(())
 654    }
 655
 656    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 657        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 658            if let Some(invite_code) = &user.invite_code {
 659                let pool = self.connection_pool.lock();
 660                for connection_id in pool.user_connection_ids(user_id) {
 661                    self.peer.send(
 662                        connection_id,
 663                        proto::UpdateInviteInfo {
 664                            url: format!(
 665                                "{}{}",
 666                                self.app_state.config.invite_link_prefix, invite_code
 667                            ),
 668                            count: user.invite_count as u32,
 669                        },
 670                    )?;
 671                }
 672            }
 673        }
 674        Ok(())
 675    }
 676
 677    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 678        ServerSnapshot {
 679            connection_pool: ConnectionPoolGuard {
 680                guard: self.connection_pool.lock(),
 681                _not_send: PhantomData,
 682            },
 683            peer: &self.peer,
 684        }
 685    }
 686}
 687
 688impl<'a> Deref for ConnectionPoolGuard<'a> {
 689    type Target = ConnectionPool;
 690
 691    fn deref(&self) -> &Self::Target {
 692        &*self.guard
 693    }
 694}
 695
 696impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 697    fn deref_mut(&mut self) -> &mut Self::Target {
 698        &mut *self.guard
 699    }
 700}
 701
 702impl<'a> Drop for ConnectionPoolGuard<'a> {
 703    fn drop(&mut self) {
 704        #[cfg(test)]
 705        self.check_invariants();
 706    }
 707}
 708
 709fn broadcast<F>(
 710    sender_id: Option<ConnectionId>,
 711    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 712    mut f: F,
 713) where
 714    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 715{
 716    for receiver_id in receiver_ids {
 717        if Some(receiver_id) != sender_id {
 718            if let Err(error) = f(receiver_id) {
 719                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 720            }
 721        }
 722    }
 723}
 724
lazy_static! {
    // Header name clients use to send their RPC protocol version; it is
    // returned by `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed `x-zed-protocol-version` header holding the client's protocol version.
pub struct ProtocolVersion(u32);
 730
 731impl Header for ProtocolVersion {
 732    fn name() -> &'static HeaderName {
 733        &ZED_PROTOCOL_VERSION
 734    }
 735
 736    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 737    where
 738        Self: Sized,
 739        I: Iterator<Item = &'i axum::http::HeaderValue>,
 740    {
 741        let version = values
 742            .next()
 743            .ok_or_else(axum::headers::Error::invalid)?
 744            .to_str()
 745            .map_err(|_| axum::headers::Error::invalid())?
 746            .parse()
 747            .map_err(|_| axum::headers::Error::invalid())?;
 748        Ok(Self(version))
 749    }
 750
 751    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 752        values.extend([self.0.to_string().parse().unwrap()]);
 753    }
 754}
 755
/// Builds the HTTP router for the RPC server.
///
/// `/rpc` is registered before the `ServiceBuilder` layer. In axum,
/// `Router::layer` wraps only the routes added before it, so `/rpc` alone gets
/// the `AppState` extension and the `auth::validate_header` middleware.
/// `/metrics` is registered afterwards and does not pass through that auth
/// middleware. The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 767
/// Upgrades an `/rpc` request to a WebSocket and passes the resulting
/// connection to `Server::handle_connection` using `Executor::Production`.
///
/// If the client's protocol version header differs from `rpc::PROTOCOL_VERSION`,
/// the client gets `426 Upgrade Required` instead of an upgrade. The `User`
/// extension is presumably inserted by the auth middleware on this route.
/// NOTE(review): confirm this against `auth::validate_header`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // The rpc `Connection` works with tungstenite messages, but axum uses its
        // own `Message` type. Incoming items are mapped to tungstenite and
        // outgoing items are mapped back to axum.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // The connection's final result goes through `log_err` rather than
            // being propagated.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 798
 799pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 800    let connections = server
 801        .connection_pool
 802        .lock()
 803        .connections()
 804        .filter(|connection| !connection.admin)
 805        .count();
 806
 807    METRIC_CONNECTIONS.set(connections as _);
 808
 809    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 810    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 811
 812    let encoder = prometheus::TextEncoder::new();
 813    let metric_families = prometheus::gather();
 814    let encoded_metrics = encoder
 815        .encode_to_string(&metric_families)
 816        .map_err(|err| anyhow!("{}", err))?;
 817    Ok(encoded_metrics)
 818}
 819
/// Cleans up after a connection goes away.
///
/// First it does three things immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - records the loss in the database (failures there are only traced).
///
/// It then waits for whichever of these comes first:
/// - `RECONNECT_TIMEOUT` elapses. The session leaves its room. If the user has
///   no other online connection, `decline_call(None, ..)` is run and any
///   returned room is re-broadcast. Finally, the user's contacts are refreshed.
/// - `teardown` changes. Nothing further is done.
///
/// `select_biased!` polls the timeout branch first when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline calls when no other connection of this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 860
 861async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 862    response.send(proto::Ack {})?;
 863    Ok(())
 864}
 865
/// Creates a new room for the session's user and connection.
///
/// The LiveKit room name is a 30-character nanoid. If a LiveKit client is
/// configured, the LiveKit room is created and a token for this user is minted.
/// A failure there is traced and leaves `live_kit_connection_info` as `None`
/// instead of failing the request. The database room is created either way and
/// returned in the response, and then the user's contacts are refreshed.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    let live_kit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = {
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        // Best effort: each `?` inside short-circuits the block to `None`. That
        // covers a missing client and any traced LiveKit error.
        util::async_iife!({
            let live_kit = live_kit?;

            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
 911
 912async fn join_room(
 913    request: proto::JoinRoom,
 914    response: Response<proto::JoinRoom>,
 915    session: Session,
 916) -> Result<()> {
 917    let room_id = RoomId::from_proto(request.id);
 918    let room = {
 919        let room = session
 920            .db()
 921            .await
 922            .join_room(room_id, session.user_id, None, session.connection_id)
 923            .await?;
 924        room_updated(&room, &session.peer);
 925        room.clone()
 926    };
 927
 928    for connection_id in session
 929        .connection_pool()
 930        .await
 931        .user_connection_ids(session.user_id)
 932    {
 933        session
 934            .peer
 935            .send(
 936                connection_id,
 937                proto::CallCanceled {
 938                    room_id: room_id.to_proto(),
 939                },
 940            )
 941            .trace_err();
 942    }
 943
 944    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 945        if let Some(token) = live_kit
 946            .room_token(&room.live_kit_room, &session.user_id.to_string())
 947            .trace_err()
 948        {
 949            Some(proto::LiveKitConnectionInfo {
 950                server_url: live_kit.url().into(),
 951                token,
 952            })
 953        } else {
 954            None
 955        }
 956    } else {
 957        None
 958    };
 959
 960    response.send(proto::JoinRoomResponse {
 961        room: Some(room),
 962        live_kit_connection_info,
 963    })?;
 964
 965    update_user_contacts(session.user_id, &session).await?;
 966    Ok(())
 967}
 968
 969async fn rejoin_room(
 970    request: proto::RejoinRoom,
 971    response: Response<proto::RejoinRoom>,
 972    session: Session,
 973) -> Result<()> {
 974    {
 975        let mut rejoined_room = session
 976            .db()
 977            .await
 978            .rejoin_room(request, session.user_id, session.connection_id)
 979            .await?;
 980
 981        response.send(proto::RejoinRoomResponse {
 982            room: Some(rejoined_room.room.clone()),
 983            reshared_projects: rejoined_room
 984                .reshared_projects
 985                .iter()
 986                .map(|project| proto::ResharedProject {
 987                    id: project.id.to_proto(),
 988                    collaborators: project
 989                        .collaborators
 990                        .iter()
 991                        .map(|collaborator| collaborator.to_proto())
 992                        .collect(),
 993                })
 994                .collect(),
 995            rejoined_projects: rejoined_room
 996                .rejoined_projects
 997                .iter()
 998                .map(|rejoined_project| proto::RejoinedProject {
 999                    id: rejoined_project.id.to_proto(),
1000                    worktrees: rejoined_project
1001                        .worktrees
1002                        .iter()
1003                        .map(|worktree| proto::WorktreeMetadata {
1004                            id: worktree.id,
1005                            root_name: worktree.root_name.clone(),
1006                            visible: worktree.visible,
1007                            abs_path: worktree.abs_path.clone(),
1008                        })
1009                        .collect(),
1010                    collaborators: rejoined_project
1011                        .collaborators
1012                        .iter()
1013                        .map(|collaborator| collaborator.to_proto())
1014                        .collect(),
1015                    language_servers: rejoined_project.language_servers.clone(),
1016                })
1017                .collect(),
1018        })?;
1019        room_updated(&rejoined_room.room, &session.peer);
1020
1021        for project in &rejoined_room.reshared_projects {
1022            for collaborator in &project.collaborators {
1023                session
1024                    .peer
1025                    .send(
1026                        collaborator.connection_id,
1027                        proto::UpdateProjectCollaborator {
1028                            project_id: project.id.to_proto(),
1029                            old_peer_id: Some(project.old_connection_id.into()),
1030                            new_peer_id: Some(session.connection_id.into()),
1031                        },
1032                    )
1033                    .trace_err();
1034            }
1035
1036            broadcast(
1037                Some(session.connection_id),
1038                project
1039                    .collaborators
1040                    .iter()
1041                    .map(|collaborator| collaborator.connection_id),
1042                |connection_id| {
1043                    session.peer.forward_send(
1044                        session.connection_id,
1045                        connection_id,
1046                        proto::UpdateProject {
1047                            project_id: project.id.to_proto(),
1048                            worktrees: project.worktrees.clone(),
1049                        },
1050                    )
1051                },
1052            );
1053        }
1054
1055        for project in &rejoined_room.rejoined_projects {
1056            for collaborator in &project.collaborators {
1057                session
1058                    .peer
1059                    .send(
1060                        collaborator.connection_id,
1061                        proto::UpdateProjectCollaborator {
1062                            project_id: project.id.to_proto(),
1063                            old_peer_id: Some(project.old_connection_id.into()),
1064                            new_peer_id: Some(session.connection_id.into()),
1065                        },
1066                    )
1067                    .trace_err();
1068            }
1069        }
1070
1071        for project in &mut rejoined_room.rejoined_projects {
1072            for worktree in mem::take(&mut project.worktrees) {
1073                #[cfg(any(test, feature = "test-support"))]
1074                const MAX_CHUNK_SIZE: usize = 2;
1075                #[cfg(not(any(test, feature = "test-support")))]
1076                const MAX_CHUNK_SIZE: usize = 256;
1077
1078                // Stream this worktree's entries.
1079                let message = proto::UpdateWorktree {
1080                    project_id: project.id.to_proto(),
1081                    worktree_id: worktree.id,
1082                    abs_path: worktree.abs_path.clone(),
1083                    root_name: worktree.root_name,
1084                    updated_entries: worktree.updated_entries,
1085                    removed_entries: worktree.removed_entries,
1086                    scan_id: worktree.scan_id,
1087                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1088                    updated_repositories: worktree.updated_repositories,
1089                    removed_repositories: worktree.removed_repositories,
1090                };
1091                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1092                    session.peer.send(session.connection_id, update.clone())?;
1093                }
1094
1095                // Stream this worktree's diagnostics.
1096                for summary in worktree.diagnostic_summaries {
1097                    session.peer.send(
1098                        session.connection_id,
1099                        proto::UpdateDiagnosticSummary {
1100                            project_id: project.id.to_proto(),
1101                            worktree_id: worktree.id,
1102                            summary: Some(summary),
1103                        },
1104                    )?;
1105                }
1106
1107                for settings_file in worktree.settings_files {
1108                    session.peer.send(
1109                        session.connection_id,
1110                        proto::UpdateWorktreeSettings {
1111                            project_id: project.id.to_proto(),
1112                            worktree_id: worktree.id,
1113                            path: settings_file.path,
1114                            content: Some(settings_file.content),
1115                        },
1116                    )?;
1117                }
1118            }
1119
1120            for language_server in &project.language_servers {
1121                session.peer.send(
1122                    session.connection_id,
1123                    proto::UpdateLanguageServer {
1124                        project_id: project.id.to_proto(),
1125                        language_server_id: language_server.id,
1126                        variant: Some(
1127                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1128                                proto::LspDiskBasedDiagnosticsUpdated {},
1129                            ),
1130                        ),
1131                    },
1132                )?;
1133            }
1134        }
1135    }
1136
1137    update_user_contacts(session.user_id, &session).await?;
1138    Ok(())
1139}
1140
1141async fn leave_room(
1142    _: proto::LeaveRoom,
1143    response: Response<proto::LeaveRoom>,
1144    session: Session,
1145) -> Result<()> {
1146    leave_room_for_session(&session).await?;
1147    response.send(proto::Ack {})?;
1148    Ok(())
1149}
1150
1151async fn call(
1152    request: proto::Call,
1153    response: Response<proto::Call>,
1154    session: Session,
1155) -> Result<()> {
1156    let room_id = RoomId::from_proto(request.room_id);
1157    let calling_user_id = session.user_id;
1158    let calling_connection_id = session.connection_id;
1159    let called_user_id = UserId::from_proto(request.called_user_id);
1160    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1161    if !session
1162        .db()
1163        .await
1164        .has_contact(calling_user_id, called_user_id)
1165        .await?
1166    {
1167        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1168    }
1169
1170    let incoming_call = {
1171        let (room, incoming_call) = &mut *session
1172            .db()
1173            .await
1174            .call(
1175                room_id,
1176                calling_user_id,
1177                calling_connection_id,
1178                called_user_id,
1179                initial_project_id,
1180            )
1181            .await?;
1182        room_updated(&room, &session.peer);
1183        mem::take(incoming_call)
1184    };
1185    update_user_contacts(called_user_id, &session).await?;
1186
1187    let mut calls = session
1188        .connection_pool()
1189        .await
1190        .user_connection_ids(called_user_id)
1191        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1192        .collect::<FuturesUnordered<_>>();
1193
1194    while let Some(call_response) = calls.next().await {
1195        match call_response.as_ref() {
1196            Ok(_) => {
1197                response.send(proto::Ack {})?;
1198                return Ok(());
1199            }
1200            Err(_) => {
1201                call_response.trace_err();
1202            }
1203        }
1204    }
1205
1206    {
1207        let room = session
1208            .db()
1209            .await
1210            .call_failed(room_id, called_user_id)
1211            .await?;
1212        room_updated(&room, &session.peer);
1213    }
1214    update_user_contacts(called_user_id, &session).await?;
1215
1216    Err(anyhow!("failed to ring user"))?
1217}
1218
1219async fn cancel_call(
1220    request: proto::CancelCall,
1221    response: Response<proto::CancelCall>,
1222    session: Session,
1223) -> Result<()> {
1224    let called_user_id = UserId::from_proto(request.called_user_id);
1225    let room_id = RoomId::from_proto(request.room_id);
1226    {
1227        let room = session
1228            .db()
1229            .await
1230            .cancel_call(room_id, session.connection_id, called_user_id)
1231            .await?;
1232        room_updated(&room, &session.peer);
1233    }
1234
1235    for connection_id in session
1236        .connection_pool()
1237        .await
1238        .user_connection_ids(called_user_id)
1239    {
1240        session
1241            .peer
1242            .send(
1243                connection_id,
1244                proto::CallCanceled {
1245                    room_id: room_id.to_proto(),
1246                },
1247            )
1248            .trace_err();
1249    }
1250    response.send(proto::Ack {})?;
1251
1252    update_user_contacts(called_user_id, &session).await?;
1253    Ok(())
1254}
1255
1256async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1257    let room_id = RoomId::from_proto(message.room_id);
1258    {
1259        let room = session
1260            .db()
1261            .await
1262            .decline_call(Some(room_id), session.user_id)
1263            .await?
1264            .ok_or_else(|| anyhow!("failed to decline call"))?;
1265        room_updated(&room, &session.peer);
1266    }
1267
1268    for connection_id in session
1269        .connection_pool()
1270        .await
1271        .user_connection_ids(session.user_id)
1272    {
1273        session
1274            .peer
1275            .send(
1276                connection_id,
1277                proto::CallCanceled {
1278                    room_id: room_id.to_proto(),
1279                },
1280            )
1281            .trace_err();
1282    }
1283    update_user_contacts(session.user_id, &session).await?;
1284    Ok(())
1285}
1286
1287async fn update_participant_location(
1288    request: proto::UpdateParticipantLocation,
1289    response: Response<proto::UpdateParticipantLocation>,
1290    session: Session,
1291) -> Result<()> {
1292    let room_id = RoomId::from_proto(request.room_id);
1293    let location = request
1294        .location
1295        .ok_or_else(|| anyhow!("invalid location"))?;
1296
1297    let db = session.db().await;
1298    let room = db
1299        .update_room_participant_location(room_id, session.connection_id, location)
1300        .await?;
1301
1302    room_updated(&room, &session.peer);
1303    response.send(proto::Ack {})?;
1304    Ok(())
1305}
1306
1307async fn share_project(
1308    request: proto::ShareProject,
1309    response: Response<proto::ShareProject>,
1310    session: Session,
1311) -> Result<()> {
1312    let (project_id, room) = &*session
1313        .db()
1314        .await
1315        .share_project(
1316            RoomId::from_proto(request.room_id),
1317            session.connection_id,
1318            &request.worktrees,
1319        )
1320        .await?;
1321    response.send(proto::ShareProjectResponse {
1322        project_id: project_id.to_proto(),
1323    })?;
1324    room_updated(&room, &session.peer);
1325
1326    Ok(())
1327}
1328
1329async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1330    let project_id = ProjectId::from_proto(message.project_id);
1331
1332    let (room, guest_connection_ids) = &*session
1333        .db()
1334        .await
1335        .unshare_project(project_id, session.connection_id)
1336        .await?;
1337
1338    broadcast(
1339        Some(session.connection_id),
1340        guest_connection_ids.iter().copied(),
1341        |conn_id| session.peer.send(conn_id, message.clone()),
1342    );
1343    room_updated(&room, &session.peer);
1344
1345    Ok(())
1346}
1347
1348async fn join_project(
1349    request: proto::JoinProject,
1350    response: Response<proto::JoinProject>,
1351    session: Session,
1352) -> Result<()> {
1353    let project_id = ProjectId::from_proto(request.project_id);
1354    let guest_user_id = session.user_id;
1355
1356    tracing::info!(%project_id, "join project");
1357
1358    let (project, replica_id) = &mut *session
1359        .db()
1360        .await
1361        .join_project(project_id, session.connection_id)
1362        .await?;
1363
1364    let collaborators = project
1365        .collaborators
1366        .iter()
1367        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1368        .map(|collaborator| collaborator.to_proto())
1369        .collect::<Vec<_>>();
1370
1371    let worktrees = project
1372        .worktrees
1373        .iter()
1374        .map(|(id, worktree)| proto::WorktreeMetadata {
1375            id: *id,
1376            root_name: worktree.root_name.clone(),
1377            visible: worktree.visible,
1378            abs_path: worktree.abs_path.clone(),
1379        })
1380        .collect::<Vec<_>>();
1381
1382    for collaborator in &collaborators {
1383        session
1384            .peer
1385            .send(
1386                collaborator.peer_id.unwrap().into(),
1387                proto::AddProjectCollaborator {
1388                    project_id: project_id.to_proto(),
1389                    collaborator: Some(proto::Collaborator {
1390                        peer_id: Some(session.connection_id.into()),
1391                        replica_id: replica_id.0 as u32,
1392                        user_id: guest_user_id.to_proto(),
1393                    }),
1394                },
1395            )
1396            .trace_err();
1397    }
1398
1399    // First, we send the metadata associated with each worktree.
1400    response.send(proto::JoinProjectResponse {
1401        worktrees: worktrees.clone(),
1402        replica_id: replica_id.0 as u32,
1403        collaborators: collaborators.clone(),
1404        language_servers: project.language_servers.clone(),
1405    })?;
1406
1407    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1408        #[cfg(any(test, feature = "test-support"))]
1409        const MAX_CHUNK_SIZE: usize = 2;
1410        #[cfg(not(any(test, feature = "test-support")))]
1411        const MAX_CHUNK_SIZE: usize = 256;
1412
1413        // Stream this worktree's entries.
1414        let message = proto::UpdateWorktree {
1415            project_id: project_id.to_proto(),
1416            worktree_id,
1417            abs_path: worktree.abs_path.clone(),
1418            root_name: worktree.root_name,
1419            updated_entries: worktree.entries,
1420            removed_entries: Default::default(),
1421            scan_id: worktree.scan_id,
1422            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1423            updated_repositories: worktree.repository_entries.into_values().collect(),
1424            removed_repositories: Default::default(),
1425        };
1426        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1427            session.peer.send(session.connection_id, update.clone())?;
1428        }
1429
1430        // Stream this worktree's diagnostics.
1431        for summary in worktree.diagnostic_summaries {
1432            session.peer.send(
1433                session.connection_id,
1434                proto::UpdateDiagnosticSummary {
1435                    project_id: project_id.to_proto(),
1436                    worktree_id: worktree.id,
1437                    summary: Some(summary),
1438                },
1439            )?;
1440        }
1441
1442        for settings_file in worktree.settings_files {
1443            session.peer.send(
1444                session.connection_id,
1445                proto::UpdateWorktreeSettings {
1446                    project_id: project_id.to_proto(),
1447                    worktree_id: worktree.id,
1448                    path: settings_file.path,
1449                    content: Some(settings_file.content),
1450                },
1451            )?;
1452        }
1453    }
1454
1455    for language_server in &project.language_servers {
1456        session.peer.send(
1457            session.connection_id,
1458            proto::UpdateLanguageServer {
1459                project_id: project_id.to_proto(),
1460                language_server_id: language_server.id,
1461                variant: Some(
1462                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1463                        proto::LspDiskBasedDiagnosticsUpdated {},
1464                    ),
1465                ),
1466            },
1467        )?;
1468    }
1469
1470    Ok(())
1471}
1472
1473async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1474    let sender_id = session.connection_id;
1475    let project_id = ProjectId::from_proto(request.project_id);
1476
1477    let (room, project) = &*session
1478        .db()
1479        .await
1480        .leave_project(project_id, sender_id)
1481        .await?;
1482    tracing::info!(
1483        %project_id,
1484        host_user_id = %project.host_user_id,
1485        host_connection_id = %project.host_connection_id,
1486        "leave project"
1487    );
1488
1489    project_left(&project, &session);
1490    room_updated(&room, &session.peer);
1491
1492    Ok(())
1493}
1494
1495async fn update_project(
1496    request: proto::UpdateProject,
1497    response: Response<proto::UpdateProject>,
1498    session: Session,
1499) -> Result<()> {
1500    let project_id = ProjectId::from_proto(request.project_id);
1501    let (room, guest_connection_ids) = &*session
1502        .db()
1503        .await
1504        .update_project(project_id, session.connection_id, &request.worktrees)
1505        .await?;
1506    broadcast(
1507        Some(session.connection_id),
1508        guest_connection_ids.iter().copied(),
1509        |connection_id| {
1510            session
1511                .peer
1512                .forward_send(session.connection_id, connection_id, request.clone())
1513        },
1514    );
1515    room_updated(&room, &session.peer);
1516    response.send(proto::Ack {})?;
1517
1518    Ok(())
1519}
1520
1521async fn update_worktree(
1522    request: proto::UpdateWorktree,
1523    response: Response<proto::UpdateWorktree>,
1524    session: Session,
1525) -> Result<()> {
1526    let guest_connection_ids = session
1527        .db()
1528        .await
1529        .update_worktree(&request, session.connection_id)
1530        .await?;
1531
1532    broadcast(
1533        Some(session.connection_id),
1534        guest_connection_ids.iter().copied(),
1535        |connection_id| {
1536            session
1537                .peer
1538                .forward_send(session.connection_id, connection_id, request.clone())
1539        },
1540    );
1541    response.send(proto::Ack {})?;
1542    Ok(())
1543}
1544
1545async fn update_diagnostic_summary(
1546    message: proto::UpdateDiagnosticSummary,
1547    session: Session,
1548) -> Result<()> {
1549    let guest_connection_ids = session
1550        .db()
1551        .await
1552        .update_diagnostic_summary(&message, session.connection_id)
1553        .await?;
1554
1555    broadcast(
1556        Some(session.connection_id),
1557        guest_connection_ids.iter().copied(),
1558        |connection_id| {
1559            session
1560                .peer
1561                .forward_send(session.connection_id, connection_id, message.clone())
1562        },
1563    );
1564
1565    Ok(())
1566}
1567
1568async fn update_worktree_settings(
1569    message: proto::UpdateWorktreeSettings,
1570    session: Session,
1571) -> Result<()> {
1572    let guest_connection_ids = session
1573        .db()
1574        .await
1575        .update_worktree_settings(&message, session.connection_id)
1576        .await?;
1577
1578    broadcast(
1579        Some(session.connection_id),
1580        guest_connection_ids.iter().copied(),
1581        |connection_id| {
1582            session
1583                .peer
1584                .forward_send(session.connection_id, connection_id, message.clone())
1585        },
1586    );
1587
1588    Ok(())
1589}
1590
1591async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1592    broadcast_project_message(request.project_id, request, session).await
1593}
1594
1595async fn start_language_server(
1596    request: proto::StartLanguageServer,
1597    session: Session,
1598) -> Result<()> {
1599    let guest_connection_ids = session
1600        .db()
1601        .await
1602        .start_language_server(&request, session.connection_id)
1603        .await?;
1604
1605    broadcast(
1606        Some(session.connection_id),
1607        guest_connection_ids.iter().copied(),
1608        |connection_id| {
1609            session
1610                .peer
1611                .forward_send(session.connection_id, connection_id, request.clone())
1612        },
1613    );
1614    Ok(())
1615}
1616
1617async fn update_language_server(
1618    request: proto::UpdateLanguageServer,
1619    session: Session,
1620) -> Result<()> {
1621    session.executor.record_backtrace();
1622    let project_id = ProjectId::from_proto(request.project_id);
1623    let project_connection_ids = session
1624        .db()
1625        .await
1626        .project_connection_ids(project_id, session.connection_id)
1627        .await?;
1628    broadcast(
1629        Some(session.connection_id),
1630        project_connection_ids.iter().copied(),
1631        |connection_id| {
1632            session
1633                .peer
1634                .forward_send(session.connection_id, connection_id, request.clone())
1635        },
1636    );
1637    Ok(())
1638}
1639
1640async fn forward_project_request<T>(
1641    request: T,
1642    response: Response<T>,
1643    session: Session,
1644) -> Result<()>
1645where
1646    T: EntityMessage + RequestMessage,
1647{
1648    session.executor.record_backtrace();
1649    let project_id = ProjectId::from_proto(request.remote_entity_id());
1650    let host_connection_id = {
1651        let collaborators = session
1652            .db()
1653            .await
1654            .project_collaborators(project_id, session.connection_id)
1655            .await?;
1656        collaborators
1657            .iter()
1658            .find(|collaborator| collaborator.is_host)
1659            .ok_or_else(|| anyhow!("host not found"))?
1660            .connection_id
1661    };
1662
1663    let payload = session
1664        .peer
1665        .forward_request(session.connection_id, host_connection_id, request)
1666        .await?;
1667
1668    response.send(payload)?;
1669    Ok(())
1670}
1671
1672async fn create_buffer_for_peer(
1673    request: proto::CreateBufferForPeer,
1674    session: Session,
1675) -> Result<()> {
1676    session.executor.record_backtrace();
1677    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1678    session
1679        .peer
1680        .forward_send(session.connection_id, peer_id.into(), request)?;
1681    Ok(())
1682}
1683
1684async fn update_buffer(
1685    request: proto::UpdateBuffer,
1686    response: Response<proto::UpdateBuffer>,
1687    session: Session,
1688) -> Result<()> {
1689    session.executor.record_backtrace();
1690    let project_id = ProjectId::from_proto(request.project_id);
1691    let mut guest_connection_ids;
1692    let mut host_connection_id = None;
1693    {
1694        let collaborators = session
1695            .db()
1696            .await
1697            .project_collaborators(project_id, session.connection_id)
1698            .await?;
1699        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1700        for collaborator in collaborators.iter() {
1701            if collaborator.is_host {
1702                host_connection_id = Some(collaborator.connection_id);
1703            } else {
1704                guest_connection_ids.push(collaborator.connection_id);
1705            }
1706        }
1707    }
1708    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1709
1710    session.executor.record_backtrace();
1711    broadcast(
1712        Some(session.connection_id),
1713        guest_connection_ids,
1714        |connection_id| {
1715            session
1716                .peer
1717                .forward_send(session.connection_id, connection_id, request.clone())
1718        },
1719    );
1720    if host_connection_id != session.connection_id {
1721        session
1722            .peer
1723            .forward_request(session.connection_id, host_connection_id, request.clone())
1724            .await?;
1725    }
1726
1727    response.send(proto::Ack {})?;
1728    Ok(())
1729}
1730
1731async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1732    let project_id = ProjectId::from_proto(request.project_id);
1733    let project_connection_ids = session
1734        .db()
1735        .await
1736        .project_connection_ids(project_id, session.connection_id)
1737        .await?;
1738
1739    broadcast(
1740        Some(session.connection_id),
1741        project_connection_ids.iter().copied(),
1742        |connection_id| {
1743            session
1744                .peer
1745                .forward_send(session.connection_id, connection_id, request.clone())
1746        },
1747    );
1748    Ok(())
1749}
1750
1751async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1752    let project_id = ProjectId::from_proto(request.project_id);
1753    let project_connection_ids = session
1754        .db()
1755        .await
1756        .project_connection_ids(project_id, session.connection_id)
1757        .await?;
1758    broadcast(
1759        Some(session.connection_id),
1760        project_connection_ids.iter().copied(),
1761        |connection_id| {
1762            session
1763                .peer
1764                .forward_send(session.connection_id, connection_id, request.clone())
1765        },
1766    );
1767    Ok(())
1768}
1769
1770async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1771    broadcast_project_message(request.project_id, request, session).await
1772}
1773
1774async fn broadcast_project_message<T: EnvelopedMessage>(
1775    project_id: u64,
1776    request: T,
1777    session: Session,
1778) -> Result<()> {
1779    let project_id = ProjectId::from_proto(project_id);
1780    let project_connection_ids = session
1781        .db()
1782        .await
1783        .project_connection_ids(project_id, session.connection_id)
1784        .await?;
1785    broadcast(
1786        Some(session.connection_id),
1787        project_connection_ids.iter().copied(),
1788        |connection_id| {
1789            session
1790                .peer
1791                .forward_send(session.connection_id, connection_id, request.clone())
1792        },
1793    );
1794    Ok(())
1795}
1796
1797async fn follow(
1798    request: proto::Follow,
1799    response: Response<proto::Follow>,
1800    session: Session,
1801) -> Result<()> {
1802    let project_id = ProjectId::from_proto(request.project_id);
1803    let leader_id = request
1804        .leader_id
1805        .ok_or_else(|| anyhow!("invalid leader id"))?
1806        .into();
1807    let follower_id = session.connection_id;
1808
1809    {
1810        let project_connection_ids = session
1811            .db()
1812            .await
1813            .project_connection_ids(project_id, session.connection_id)
1814            .await?;
1815
1816        if !project_connection_ids.contains(&leader_id) {
1817            Err(anyhow!("no such peer"))?;
1818        }
1819    }
1820
1821    let mut response_payload = session
1822        .peer
1823        .forward_request(session.connection_id, leader_id, request)
1824        .await?;
1825    response_payload
1826        .views
1827        .retain(|view| view.leader_id != Some(follower_id.into()));
1828    response.send(response_payload)?;
1829
1830    let room = session
1831        .db()
1832        .await
1833        .follow(project_id, leader_id, follower_id)
1834        .await?;
1835    room_updated(&room, &session.peer);
1836
1837    Ok(())
1838}
1839
1840async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1841    let project_id = ProjectId::from_proto(request.project_id);
1842    let leader_id = request
1843        .leader_id
1844        .ok_or_else(|| anyhow!("invalid leader id"))?
1845        .into();
1846    let follower_id = session.connection_id;
1847
1848    if !session
1849        .db()
1850        .await
1851        .project_connection_ids(project_id, session.connection_id)
1852        .await?
1853        .contains(&leader_id)
1854    {
1855        Err(anyhow!("no such peer"))?;
1856    }
1857
1858    session
1859        .peer
1860        .forward_send(session.connection_id, leader_id, request)?;
1861
1862    let room = session
1863        .db()
1864        .await
1865        .unfollow(project_id, leader_id, follower_id)
1866        .await?;
1867    room_updated(&room, &session.peer);
1868
1869    Ok(())
1870}
1871
1872async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1873    let project_id = ProjectId::from_proto(request.project_id);
1874    let project_connection_ids = session
1875        .db
1876        .lock()
1877        .await
1878        .project_connection_ids(project_id, session.connection_id)
1879        .await?;
1880
1881    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1882        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1883        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1884        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1885    });
1886    for follower_peer_id in request.follower_ids.iter().copied() {
1887        let follower_connection_id = follower_peer_id.into();
1888        if project_connection_ids.contains(&follower_connection_id)
1889            && Some(follower_peer_id) != leader_id
1890        {
1891            session.peer.forward_send(
1892                session.connection_id,
1893                follower_connection_id,
1894                request.clone(),
1895            )?;
1896        }
1897    }
1898    Ok(())
1899}
1900
1901async fn get_users(
1902    request: proto::GetUsers,
1903    response: Response<proto::GetUsers>,
1904    session: Session,
1905) -> Result<()> {
1906    let user_ids = request
1907        .user_ids
1908        .into_iter()
1909        .map(UserId::from_proto)
1910        .collect();
1911    let users = session
1912        .db()
1913        .await
1914        .get_users_by_ids(user_ids)
1915        .await?
1916        .into_iter()
1917        .map(|user| proto::User {
1918            id: user.id.to_proto(),
1919            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1920            github_login: user.github_login,
1921        })
1922        .collect();
1923    response.send(proto::UsersResponse { users })?;
1924    Ok(())
1925}
1926
1927async fn fuzzy_search_users(
1928    request: proto::FuzzySearchUsers,
1929    response: Response<proto::FuzzySearchUsers>,
1930    session: Session,
1931) -> Result<()> {
1932    let query = request.query;
1933    let users = match query.len() {
1934        0 => vec![],
1935        1 | 2 => session
1936            .db()
1937            .await
1938            .get_user_by_github_login(&query)
1939            .await?
1940            .into_iter()
1941            .collect(),
1942        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1943    };
1944    let users = users
1945        .into_iter()
1946        .filter(|user| user.id != session.user_id)
1947        .map(|user| proto::User {
1948            id: user.id.to_proto(),
1949            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1950            github_login: user.github_login,
1951        })
1952        .collect();
1953    response.send(proto::UsersResponse { users })?;
1954    Ok(())
1955}
1956
1957async fn request_contact(
1958    request: proto::RequestContact,
1959    response: Response<proto::RequestContact>,
1960    session: Session,
1961) -> Result<()> {
1962    let requester_id = session.user_id;
1963    let responder_id = UserId::from_proto(request.responder_id);
1964    if requester_id == responder_id {
1965        return Err(anyhow!("cannot add yourself as a contact"))?;
1966    }
1967
1968    session
1969        .db()
1970        .await
1971        .send_contact_request(requester_id, responder_id)
1972        .await?;
1973
1974    // Update outgoing contact requests of requester
1975    let mut update = proto::UpdateContacts::default();
1976    update.outgoing_requests.push(responder_id.to_proto());
1977    for connection_id in session
1978        .connection_pool()
1979        .await
1980        .user_connection_ids(requester_id)
1981    {
1982        session.peer.send(connection_id, update.clone())?;
1983    }
1984
1985    // Update incoming contact requests of responder
1986    let mut update = proto::UpdateContacts::default();
1987    update
1988        .incoming_requests
1989        .push(proto::IncomingContactRequest {
1990            requester_id: requester_id.to_proto(),
1991            should_notify: true,
1992        });
1993    for connection_id in session
1994        .connection_pool()
1995        .await
1996        .user_connection_ids(responder_id)
1997    {
1998        session.peer.send(connection_id, update.clone())?;
1999    }
2000
2001    response.send(proto::Ack {})?;
2002    Ok(())
2003}
2004
2005async fn respond_to_contact_request(
2006    request: proto::RespondToContactRequest,
2007    response: Response<proto::RespondToContactRequest>,
2008    session: Session,
2009) -> Result<()> {
2010    let responder_id = session.user_id;
2011    let requester_id = UserId::from_proto(request.requester_id);
2012    let db = session.db().await;
2013    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2014        db.dismiss_contact_notification(responder_id, requester_id)
2015            .await?;
2016    } else {
2017        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2018
2019        db.respond_to_contact_request(responder_id, requester_id, accept)
2020            .await?;
2021        let requester_busy = db.is_user_busy(requester_id).await?;
2022        let responder_busy = db.is_user_busy(responder_id).await?;
2023
2024        let pool = session.connection_pool().await;
2025        // Update responder with new contact
2026        let mut update = proto::UpdateContacts::default();
2027        if accept {
2028            update
2029                .contacts
2030                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2031        }
2032        update
2033            .remove_incoming_requests
2034            .push(requester_id.to_proto());
2035        for connection_id in pool.user_connection_ids(responder_id) {
2036            session.peer.send(connection_id, update.clone())?;
2037        }
2038
2039        // Update requester with new contact
2040        let mut update = proto::UpdateContacts::default();
2041        if accept {
2042            update
2043                .contacts
2044                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2045        }
2046        update
2047            .remove_outgoing_requests
2048            .push(responder_id.to_proto());
2049        for connection_id in pool.user_connection_ids(requester_id) {
2050            session.peer.send(connection_id, update.clone())?;
2051        }
2052    }
2053
2054    response.send(proto::Ack {})?;
2055    Ok(())
2056}
2057
2058async fn remove_contact(
2059    request: proto::RemoveContact,
2060    response: Response<proto::RemoveContact>,
2061    session: Session,
2062) -> Result<()> {
2063    let requester_id = session.user_id;
2064    let responder_id = UserId::from_proto(request.user_id);
2065    let db = session.db().await;
2066    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2067
2068    let pool = session.connection_pool().await;
2069    // Update outgoing contact requests of requester
2070    let mut update = proto::UpdateContacts::default();
2071    if contact_accepted {
2072        update.remove_contacts.push(responder_id.to_proto());
2073    } else {
2074        update
2075            .remove_outgoing_requests
2076            .push(responder_id.to_proto());
2077    }
2078    for connection_id in pool.user_connection_ids(requester_id) {
2079        session.peer.send(connection_id, update.clone())?;
2080    }
2081
2082    // Update incoming contact requests of responder
2083    let mut update = proto::UpdateContacts::default();
2084    if contact_accepted {
2085        update.remove_contacts.push(requester_id.to_proto());
2086    } else {
2087        update
2088            .remove_incoming_requests
2089            .push(requester_id.to_proto());
2090    }
2091    for connection_id in pool.user_connection_ids(responder_id) {
2092        session.peer.send(connection_id, update.clone())?;
2093    }
2094
2095    response.send(proto::Ack {})?;
2096    Ok(())
2097}
2098
2099async fn create_channel(
2100    request: proto::CreateChannel,
2101    response: Response<proto::CreateChannel>,
2102    session: Session,
2103) -> Result<()> {
2104    let db = session.db().await;
2105    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2106
2107    if let Some(live_kit) = session.live_kit_client.as_ref() {
2108        live_kit.create_room(live_kit_room.clone()).await?;
2109    }
2110
2111    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2112    let id = db
2113        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2114        .await?;
2115
2116    response.send(proto::CreateChannelResponse {
2117        channel_id: id.to_proto(),
2118    })?;
2119
2120    let mut update = proto::UpdateChannels::default();
2121    update.channels.push(proto::Channel {
2122        id: id.to_proto(),
2123        name: request.name,
2124        parent_id: request.parent_id,
2125    });
2126
2127    if let Some(parent_id) = parent_id {
2128        let member_ids = db.get_channel_members(parent_id).await?;
2129        let connection_pool = session.connection_pool().await;
2130        for member_id in member_ids {
2131            for connection_id in connection_pool.user_connection_ids(member_id) {
2132                session.peer.send(connection_id, update.clone())?;
2133            }
2134        }
2135    } else {
2136        session.peer.send(session.connection_id, update)?;
2137    }
2138
2139    Ok(())
2140}
2141
2142async fn remove_channel(
2143    request: proto::RemoveChannel,
2144    response: Response<proto::RemoveChannel>,
2145    session: Session,
2146) -> Result<()> {
2147    let db = session.db().await;
2148
2149    let channel_id = request.channel_id;
2150    let (removed_channels, member_ids) = db
2151        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2152        .await?;
2153    response.send(proto::Ack {})?;
2154
2155    // Notify members of removed channels
2156    let mut update = proto::UpdateChannels::default();
2157    update
2158        .remove_channels
2159        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2160
2161    let connection_pool = session.connection_pool().await;
2162    for member_id in member_ids {
2163        for connection_id in connection_pool.user_connection_ids(member_id) {
2164            session.peer.send(connection_id, update.clone())?;
2165        }
2166    }
2167
2168    Ok(())
2169}
2170
2171async fn invite_channel_member(
2172    request: proto::InviteChannelMember,
2173    response: Response<proto::InviteChannelMember>,
2174    session: Session,
2175) -> Result<()> {
2176    let db = session.db().await;
2177    let channel_id = ChannelId::from_proto(request.channel_id);
2178    let channel = db
2179        .get_channel(channel_id)
2180        .await?
2181        .ok_or_else(|| anyhow!("channel not found"))?;
2182    let invitee_id = UserId::from_proto(request.user_id);
2183    db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2184        .await?;
2185
2186    let mut update = proto::UpdateChannels::default();
2187    update.channel_invitations.push(proto::Channel {
2188        id: channel.id.to_proto(),
2189        name: channel.name,
2190        parent_id: None,
2191    });
2192    for connection_id in session
2193        .connection_pool()
2194        .await
2195        .user_connection_ids(invitee_id)
2196    {
2197        session.peer.send(connection_id, update.clone())?;
2198    }
2199
2200    response.send(proto::Ack {})?;
2201    Ok(())
2202}
2203
2204async fn remove_channel_member(
2205    request: proto::RemoveChannelMember,
2206    response: Response<proto::RemoveChannelMember>,
2207    session: Session,
2208) -> Result<()> {
2209    Ok(())
2210}
2211
2212async fn respond_to_channel_invite(
2213    request: proto::RespondToChannelInvite,
2214    response: Response<proto::RespondToChannelInvite>,
2215    session: Session,
2216) -> Result<()> {
2217    let db = session.db().await;
2218    let channel_id = ChannelId::from_proto(request.channel_id);
2219    let channel = db
2220        .get_channel(channel_id)
2221        .await?
2222        .ok_or_else(|| anyhow!("no such channel"))?;
2223    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2224        .await?;
2225
2226    let mut update = proto::UpdateChannels::default();
2227    update
2228        .remove_channel_invitations
2229        .push(channel_id.to_proto());
2230    if request.accept {
2231        update.channels.push(proto::Channel {
2232            id: channel.id.to_proto(),
2233            name: channel.name,
2234            parent_id: None,
2235        });
2236    }
2237    session.peer.send(session.connection_id, update)?;
2238    response.send(proto::Ack {})?;
2239
2240    Ok(())
2241}
2242
/// Joins the session's user and connection to the room backing channel
/// `request.channel_id`.
///
/// Replies with the joined room and, when a LiveKit client is configured and
/// a token can be issued, LiveKit connection info for the room. The updated
/// room is then broadcast to its participants, and finally the user's contacts
/// are refreshed via `update_user_contacts`.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // This scope ensures `db` and `room` are dropped before
    // `update_user_contacts` runs, because that function calls `session.db()`
    // again. NOTE(review): presumably `db()` returns a non-reentrant lock guard,
    // so keeping it alive would deadlock; confirm.
    {
        let db = session.db().await;
        let room_id = db.get_channel_room(channel_id).await?;

        let room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // A token failure is logged by `trace_err` and yields `None`, so the
        // join still succeeds, just without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(&room.live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&room, &session.peer);
    }

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2286
2287async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2288    let project_id = ProjectId::from_proto(request.project_id);
2289    let project_connection_ids = session
2290        .db()
2291        .await
2292        .project_connection_ids(project_id, session.connection_id)
2293        .await?;
2294    broadcast(
2295        Some(session.connection_id),
2296        project_connection_ids.iter().copied(),
2297        |connection_id| {
2298            session
2299                .peer
2300                .forward_send(session.connection_id, connection_id, request.clone())
2301        },
2302    );
2303    Ok(())
2304}
2305
2306async fn get_private_user_info(
2307    _request: proto::GetPrivateUserInfo,
2308    response: Response<proto::GetPrivateUserInfo>,
2309    session: Session,
2310) -> Result<()> {
2311    let metrics_id = session
2312        .db()
2313        .await
2314        .get_user_metrics_id(session.user_id)
2315        .await?;
2316    let user = session
2317        .db()
2318        .await
2319        .get_user_by_id(session.user_id)
2320        .await?
2321        .ok_or_else(|| anyhow!("user not found"))?;
2322    response.send(proto::GetPrivateUserInfoResponse {
2323        metrics_id,
2324        staff: user.admin,
2325    })?;
2326    Ok(())
2327}
2328
2329fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2330    match message {
2331        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2332        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2333        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2334        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2335        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2336            code: frame.code.into(),
2337            reason: frame.reason,
2338        })),
2339    }
2340}
2341
2342fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2343    match message {
2344        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2345        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2346        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2347        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2348        AxumMessage::Close(frame) => {
2349            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2350                code: frame.code.into(),
2351                reason: frame.reason,
2352            }))
2353        }
2354    }
2355}
2356
2357fn build_initial_channels_update(
2358    channels: Vec<db::Channel>,
2359    channel_invites: Vec<db::Channel>,
2360) -> proto::UpdateChannels {
2361    let mut update = proto::UpdateChannels::default();
2362
2363    for channel in channels {
2364        update.channels.push(proto::Channel {
2365            id: channel.id.to_proto(),
2366            name: channel.name,
2367            parent_id: channel.parent_id.map(|id| id.to_proto()),
2368        });
2369    }
2370
2371    for channel in channel_invites {
2372        update.channel_invitations.push(proto::Channel {
2373            id: channel.id.to_proto(),
2374            name: channel.name,
2375            parent_id: None,
2376        });
2377    }
2378
2379    update
2380}
2381
2382fn build_initial_contacts_update(
2383    contacts: Vec<db::Contact>,
2384    pool: &ConnectionPool,
2385) -> proto::UpdateContacts {
2386    let mut update = proto::UpdateContacts::default();
2387
2388    for contact in contacts {
2389        match contact {
2390            db::Contact::Accepted {
2391                user_id,
2392                should_notify,
2393                busy,
2394            } => {
2395                update
2396                    .contacts
2397                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2398            }
2399            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2400            db::Contact::Incoming {
2401                user_id,
2402                should_notify,
2403            } => update
2404                .incoming_requests
2405                .push(proto::IncomingContactRequest {
2406                    requester_id: user_id.to_proto(),
2407                    should_notify,
2408                }),
2409        }
2410    }
2411
2412    update
2413}
2414
2415fn contact_for_user(
2416    user_id: UserId,
2417    should_notify: bool,
2418    busy: bool,
2419    pool: &ConnectionPool,
2420) -> proto::Contact {
2421    proto::Contact {
2422        user_id: user_id.to_proto(),
2423        online: pool.is_user_online(user_id),
2424        busy,
2425        should_notify,
2426    }
2427}
2428
2429fn room_updated(room: &proto::Room, peer: &Peer) {
2430    broadcast(
2431        None,
2432        room.participants
2433            .iter()
2434            .filter_map(|participant| Some(participant.peer_id?.into())),
2435        |peer_id| {
2436            peer.send(
2437                peer_id.into(),
2438                proto::RoomUpdated {
2439                    room: Some(room.clone()),
2440                },
2441            )
2442        },
2443    );
2444}
2445
2446async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2447    let db = session.db().await;
2448    let contacts = db.get_contacts(user_id).await?;
2449    let busy = db.is_user_busy(user_id).await?;
2450
2451    let pool = session.connection_pool().await;
2452    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2453    for contact in contacts {
2454        if let db::Contact::Accepted {
2455            user_id: contact_user_id,
2456            ..
2457        } = contact
2458        {
2459            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2460                session
2461                    .peer
2462                    .send(
2463                        contact_conn_id,
2464                        proto::UpdateContacts {
2465                            contacts: vec![updated_contact.clone()],
2466                            remove_contacts: Default::default(),
2467                            incoming_requests: Default::default(),
2468                            remove_incoming_requests: Default::default(),
2469                            outgoing_requests: Default::default(),
2470                            remove_outgoing_requests: Default::default(),
2471                        },
2472                    )
2473                    .trace_err();
2474            }
2475        }
2476    }
2477    Ok(())
2478}
2479
/// Removes the session's connection from its room, if it is in one, and sends
/// the resulting notifications.
///
/// Steps, in order:
/// 1. Notify each project the connection left (see `project_left`).
/// 2. Broadcast the updated room.
/// 3. Send `CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`.
/// 4. Refresh the contacts of the leaving user and of those users.
/// 5. Remove the participant from the LiveKit room, deleting the room when
///    `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately when the connection was not in a room. Send
/// and LiveKit failures are logged via `trace_err` and ignored. Database and
/// contact-refresh errors are propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the `if let` arm
    // below. The `else` arm returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // NOTE(review): under Rust 2021 temporary-scope rules, the value produced by
    // `session.db().await` in this scrutinee lives until the end of the whole
    // `if let`/`else`. Presumably that keeps the database locked while the
    // peers below are notified; confirm.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        // Move these out so they outlive `left_room`, which is dropped at the
        // end of this `if let`.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
    } else {
        return Ok(());
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2538
2539fn project_left(project: &db::LeftProject, session: &Session) {
2540    for connection_id in &project.connection_ids {
2541        if project.host_user_id == session.user_id {
2542            session
2543                .peer
2544                .send(
2545                    *connection_id,
2546                    proto::UnshareProject {
2547                        project_id: project.id.to_proto(),
2548                    },
2549                )
2550                .trace_err();
2551        } else {
2552            session
2553                .peer
2554                .send(
2555                    *connection_id,
2556                    proto::RemoveProjectCollaborator {
2557                        project_id: project.id.to_proto(),
2558                        peer_id: Some(session.connection_id.into()),
2559                    },
2560                )
2561                .trace_err();
2562        }
2563    }
2564}
2565
/// Extension trait for converting a fallible result into an `Option` while
/// reporting the error instead of propagating it.
///
/// The implementation for `Result<T, E>` in this file logs the error's `Debug`
/// form with `tracing::error!`.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Returns `Some(value)` on success. On failure, reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
2571
2572impl<T, E> ResultExt for Result<T, E>
2573where
2574    E: std::fmt::Debug,
2575{
2576    type Ok = T;
2577
2578    fn trace_err(self) -> Option<T> {
2579        match self {
2580            Ok(value) => Some(value),
2581            Err(error) => {
2582                tracing::error!("{:?}", error);
2583                None
2584            }
2585        }
2586    }
2587}