rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::{Mutex, MutexGuard};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
/// How long `sign_out` waits after a connection is lost before leaving the user's room
/// (presumably to give the client a window to reconnect). Equal to `rpc::RECEIVE_TIMEOUT`.
pub const RECONNECTION_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;

lazy_static! {
    /// Prometheus gauge of open non-admin connections; refreshed by `handle_metrics`.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Prometheus gauge set by `handle_metrics` from `project_count_excluding_admins()`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}

/// Type-erased message handler stored in `Server::handlers`, keyed by payload `TypeId`.
///
/// Receives the boxed envelope and the sender's `Session`, and returns a boxed `'static`
/// future whose output is `()`: errors are logged inside the future (see `add_handler`).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
  75struct Response<R> {
  76    peer: Arc<Peer>,
  77    receipt: Receipt<R>,
  78    responded: Arc<AtomicBool>,
  79}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
  89#[derive(Clone)]
  90struct Session {
  91    user_id: UserId,
  92    connection_id: ConnectionId,
  93    db: Arc<Mutex<DbHandle>>,
  94    peer: Arc<Peer>,
  95    connection_pool: Arc<Mutex<ConnectionPool>>,
  96    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  97}
  98
  99impl Session {
 100    async fn db(&self) -> MutexGuard<DbHandle> {
 101        #[cfg(test)]
 102        tokio::task::yield_now().await;
 103        let guard = self.db.lock().await;
 104        #[cfg(test)]
 105        tokio::task::yield_now().await;
 106        guard
 107    }
 108
 109    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 110        #[cfg(test)]
 111        tokio::task::yield_now().await;
 112        let guard = self.connection_pool.lock().await;
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        ConnectionPoolGuard {
 116            guard,
 117            _not_send: PhantomData,
 118        }
 119    }
 120}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
 131struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
/// The collab RPC server: the `Peer`, the pool of live connections, the shared
/// application state, and the handler table populated in `Server::new`.
pub struct Server {
    peer: Arc<Peer>,
    /// Live connections; the same `Arc` is cloned into every `Session`.
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Dispatch table keyed by the `TypeId` of each message payload type.
    handlers: HashMap<TypeId, MessageHandler>,
}

/// Lock guard over the connection pool.
///
/// `_not_send` (`PhantomData<Rc<()>>`) makes the guard `!Send`, so a `Send` future cannot
/// hold it across an `.await`. Its `Drop` impl checks the pool's invariants in test builds.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server: the peer plus the connection pool, serialized through
/// the held pool lock via `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 159
 160pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 161where
 162    S: Serializer,
 163    T: Deref<Target = U>,
 164    U: Serialize,
 165{
 166    Serialize::serialize(value.deref(), serializer)
 167}
 168
 169impl Server {
 170    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 171        let mut server = Self {
 172            peer: Peer::new(),
 173            app_state,
 174            connection_pool: Default::default(),
 175            handlers: Default::default(),
 176        };
 177
 178        server
 179            .add_request_handler(ping)
 180            .add_request_handler(create_room)
 181            .add_request_handler(join_room)
 182            .add_message_handler(leave_room)
 183            .add_request_handler(call)
 184            .add_request_handler(cancel_call)
 185            .add_message_handler(decline_call)
 186            .add_request_handler(update_participant_location)
 187            .add_request_handler(share_project)
 188            .add_message_handler(unshare_project)
 189            .add_request_handler(join_project)
 190            .add_message_handler(leave_project)
 191            .add_request_handler(update_project)
 192            .add_request_handler(update_worktree)
 193            .add_message_handler(start_language_server)
 194            .add_message_handler(update_language_server)
 195            .add_message_handler(update_diagnostic_summary)
 196            .add_request_handler(forward_project_request::<proto::GetHover>)
 197            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 198            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 199            .add_request_handler(forward_project_request::<proto::GetReferences>)
 200            .add_request_handler(forward_project_request::<proto::SearchProject>)
 201            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 202            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 203            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 204            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 205            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 206            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 207            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 208            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 209            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 210            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 211            .add_request_handler(forward_project_request::<proto::PerformRename>)
 212            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 213            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 214            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 215            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 216            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 217            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 218            .add_message_handler(create_buffer_for_peer)
 219            .add_request_handler(update_buffer)
 220            .add_message_handler(update_buffer_file)
 221            .add_message_handler(buffer_reloaded)
 222            .add_message_handler(buffer_saved)
 223            .add_request_handler(save_buffer)
 224            .add_request_handler(get_users)
 225            .add_request_handler(fuzzy_search_users)
 226            .add_request_handler(request_contact)
 227            .add_request_handler(remove_contact)
 228            .add_request_handler(respond_to_contact_request)
 229            .add_request_handler(follow)
 230            .add_message_handler(unfollow)
 231            .add_message_handler(update_followers)
 232            .add_message_handler(update_diff_base)
 233            .add_request_handler(get_private_user_info);
 234
 235        Arc::new(server)
 236    }
 237
 238    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 239    where
 240        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 241        Fut: 'static + Send + Future<Output = Result<()>>,
 242        M: EnvelopedMessage,
 243    {
 244        let prev_handler = self.handlers.insert(
 245            TypeId::of::<M>(),
 246            Box::new(move |envelope, session| {
 247                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 248                let span = info_span!(
 249                    "handle message",
 250                    payload_type = envelope.payload_type_name()
 251                );
 252                span.in_scope(|| {
 253                    tracing::info!(
 254                        payload_type = envelope.payload_type_name(),
 255                        "message received"
 256                    );
 257                });
 258                let future = (handler)(*envelope, session);
 259                async move {
 260                    if let Err(error) = future.await {
 261                        tracing::error!(%error, "error handling message");
 262                    }
 263                }
 264                .instrument(span)
 265                .boxed()
 266            }),
 267        );
 268        if prev_handler.is_some() {
 269            panic!("registered a handler for the same message twice");
 270        }
 271        self
 272    }
 273
 274    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 275    where
 276        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 277        Fut: 'static + Send + Future<Output = Result<()>>,
 278        M: EnvelopedMessage,
 279    {
 280        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 281        self
 282    }
 283
 284    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 285    where
 286        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 287        Fut: Send + Future<Output = Result<()>>,
 288        M: RequestMessage,
 289    {
 290        let handler = Arc::new(handler);
 291        self.add_handler(move |envelope, session| {
 292            let receipt = envelope.receipt();
 293            let handler = handler.clone();
 294            async move {
 295                let peer = session.peer.clone();
 296                let responded = Arc::new(AtomicBool::default());
 297                let response = Response {
 298                    peer: peer.clone(),
 299                    responded: responded.clone(),
 300                    receipt,
 301                };
 302                match (handler)(envelope.payload, response, session).await {
 303                    Ok(()) => {
 304                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 305                            Ok(())
 306                        } else {
 307                            Err(anyhow!("handler did not send a response"))?
 308                        }
 309                    }
 310                    Err(error) => {
 311                        peer.respond_with_error(
 312                            receipt,
 313                            proto::Error {
 314                                message: error.to_string(),
 315                            },
 316                        )?;
 317                        Err(error)
 318                    }
 319                }
 320            }
 321        })
 322    }
 323
 324    pub fn handle_connection(
 325        self: &Arc<Self>,
 326        connection: Connection,
 327        address: String,
 328        user: User,
 329        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 330        executor: Executor,
 331    ) -> impl Future<Output = Result<()>> {
 332        let this = self.clone();
 333        let user_id = user.id;
 334        let login = user.github_login;
 335        let span = info_span!("handle connection", %user_id, %login, %address);
 336        async move {
 337            let (connection_id, handle_io, mut incoming_rx) = this
 338                .peer
 339                .add_connection(connection, {
 340                    let executor = executor.clone();
 341                    move |duration| executor.sleep(duration)
 342                });
 343
 344            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 345            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 346            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 347
 348            if let Some(send_connection_id) = send_connection_id.take() {
 349                let _ = send_connection_id.send(connection_id);
 350            }
 351
 352            if !user.connected_once {
 353                this.peer.send(connection_id, proto::ShowContacts {})?;
 354                this.app_state.db.set_user_connected_once(user_id, true).await?;
 355            }
 356
 357            let (contacts, invite_code) = future::try_join(
 358                this.app_state.db.get_contacts(user_id),
 359                this.app_state.db.get_invite_code_for_user(user_id)
 360            ).await?;
 361
 362            {
 363                let mut pool = this.connection_pool.lock().await;
 364                pool.add_connection(connection_id, user_id, user.admin);
 365                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 366
 367                if let Some((code, count)) = invite_code {
 368                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 369                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 370                        count: count as u32,
 371                    })?;
 372                }
 373            }
 374
 375            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 376                this.peer.send(connection_id, incoming_call)?;
 377            }
 378
 379            let session = Session {
 380                user_id,
 381                connection_id,
 382                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
 383                peer: this.peer.clone(),
 384                connection_pool: this.connection_pool.clone(),
 385                live_kit_client: this.app_state.live_kit_client.clone()
 386            };
 387            update_user_contacts(user_id, &session).await?;
 388
 389            let handle_io = handle_io.fuse();
 390            futures::pin_mut!(handle_io);
 391
 392            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 393            // This prevents deadlocks when e.g., client A performs a request to client B and
 394            // client B performs a request to client A. If both clients stop processing further
 395            // messages until their respective request completes, they won't have a chance to
 396            // respond to the other client's request and cause a deadlock.
 397            //
 398            // This arrangement ensures we will attempt to process earlier messages first, but fall
 399            // back to processing messages arrived later in the spirit of making progress.
 400            let mut foreground_message_handlers = FuturesUnordered::new();
 401            loop {
 402                let next_message = incoming_rx.next().fuse();
 403                futures::pin_mut!(next_message);
 404                futures::select_biased! {
 405                    result = handle_io => {
 406                        if let Err(error) = result {
 407                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 408                        }
 409                        break;
 410                    }
 411                    _ = foreground_message_handlers.next() => {}
 412                    message = next_message => {
 413                        if let Some(message) = message {
 414                            let type_name = message.payload_type_name();
 415                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 416                            let span_enter = span.enter();
 417                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 418                                let is_background = message.is_background();
 419                                let handle_message = (handler)(message, session.clone());
 420                                drop(span_enter);
 421
 422                                let handle_message = handle_message.instrument(span);
 423                                if is_background {
 424                                    executor.spawn_detached(handle_message);
 425                                } else {
 426                                    foreground_message_handlers.push(handle_message);
 427                                }
 428                            } else {
 429                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 430                            }
 431                        } else {
 432                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 433                            break;
 434                        }
 435                    }
 436                }
 437            }
 438
 439            drop(foreground_message_handlers);
 440            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 441            if let Err(error) = sign_out(session, executor).await {
 442                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 443            }
 444
 445            Ok(())
 446        }.instrument(span)
 447    }
 448
 449    pub async fn invite_code_redeemed(
 450        self: &Arc<Self>,
 451        inviter_id: UserId,
 452        invitee_id: UserId,
 453    ) -> Result<()> {
 454        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 455            if let Some(code) = &user.invite_code {
 456                let pool = self.connection_pool.lock().await;
 457                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 458                for connection_id in pool.user_connection_ids(inviter_id) {
 459                    self.peer.send(
 460                        connection_id,
 461                        proto::UpdateContacts {
 462                            contacts: vec![invitee_contact.clone()],
 463                            ..Default::default()
 464                        },
 465                    )?;
 466                    self.peer.send(
 467                        connection_id,
 468                        proto::UpdateInviteInfo {
 469                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 470                            count: user.invite_count as u32,
 471                        },
 472                    )?;
 473                }
 474            }
 475        }
 476        Ok(())
 477    }
 478
 479    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 480        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 481            if let Some(invite_code) = &user.invite_code {
 482                let pool = self.connection_pool.lock().await;
 483                for connection_id in pool.user_connection_ids(user_id) {
 484                    self.peer.send(
 485                        connection_id,
 486                        proto::UpdateInviteInfo {
 487                            url: format!(
 488                                "{}{}",
 489                                self.app_state.config.invite_link_prefix, invite_code
 490                            ),
 491                            count: user.invite_count as u32,
 492                        },
 493                    )?;
 494                }
 495            }
 496        }
 497        Ok(())
 498    }
 499
 500    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 501        ServerSnapshot {
 502            connection_pool: ConnectionPoolGuard {
 503                guard: self.connection_pool.lock().await,
 504                _not_send: PhantomData,
 505            },
 506            peer: &self.peer,
 507        }
 508    }
 509}
 510
 511impl<'a> Deref for ConnectionPoolGuard<'a> {
 512    type Target = ConnectionPool;
 513
 514    fn deref(&self) -> &Self::Target {
 515        &*self.guard
 516    }
 517}
 518
 519impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 520    fn deref_mut(&mut self) -> &mut Self::Target {
 521        &mut *self.guard
 522    }
 523}
 524
 525impl<'a> Drop for ConnectionPoolGuard<'a> {
 526    fn drop(&mut self) {
 527        #[cfg(test)]
 528        self.check_invariants();
 529    }
 530}
 531
 532fn broadcast<F>(
 533    sender_id: ConnectionId,
 534    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 535    mut f: F,
 536) where
 537    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 538{
 539    for receiver_id in receiver_ids {
 540        if receiver_id != sender_id {
 541            f(receiver_id).trace_err();
 542        }
 543    }
 544}
 545
 546lazy_static! {
 547    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 548}
 549
 550pub struct ProtocolVersion(u32);
 551
 552impl Header for ProtocolVersion {
 553    fn name() -> &'static HeaderName {
 554        &ZED_PROTOCOL_VERSION
 555    }
 556
 557    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 558    where
 559        Self: Sized,
 560        I: Iterator<Item = &'i axum::http::HeaderValue>,
 561    {
 562        let version = values
 563            .next()
 564            .ok_or_else(axum::headers::Error::invalid)?
 565            .to_str()
 566            .map_err(|_| axum::headers::Error::invalid())?
 567            .parse()
 568            .map_err(|_| axum::headers::Error::invalid())?;
 569        Ok(Self(version))
 570    }
 571
 572    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 573        values.extend([self.0.to_string().parse().unwrap()]);
 574    }
 575}
 576
/// Builds the router for the collab RPC endpoints.
///
/// `/rpc` is added before the `ServiceBuilder` layer, so it is wrapped by the `AppState`
/// extension and `auth::validate_header`. `/metrics` is added after that layer; axum's
/// `Router::layer` only wraps routes that already exist, so `/metrics` is not subject to it.
/// NOTE(review): confirm that serving metrics without `validate_header` is intended.
/// The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 588
 589pub async fn handle_websocket_request(
 590    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 591    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 592    Extension(server): Extension<Arc<Server>>,
 593    Extension(user): Extension<User>,
 594    ws: WebSocketUpgrade,
 595) -> axum::response::Response {
 596    if protocol_version != rpc::PROTOCOL_VERSION {
 597        return (
 598            StatusCode::UPGRADE_REQUIRED,
 599            "client must be upgraded".to_string(),
 600        )
 601            .into_response();
 602    }
 603    let socket_address = socket_address.to_string();
 604    ws.on_upgrade(move |socket| {
 605        use util::ResultExt;
 606        let socket = socket
 607            .map_ok(to_tungstenite_message)
 608            .err_into()
 609            .with(|message| async move { Ok(to_axum_message(message)) });
 610        let connection = Connection::new(Box::pin(socket));
 611        async move {
 612            server
 613                .handle_connection(connection, socket_address, user, None, Executor::Production)
 614                .await
 615                .log_err();
 616        }
 617    })
 618}
 619
 620pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 621    let connections = server
 622        .connection_pool
 623        .lock()
 624        .await
 625        .connections()
 626        .filter(|connection| !connection.admin)
 627        .count();
 628
 629    METRIC_CONNECTIONS.set(connections as _);
 630
 631    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 632    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 633
 634    let encoder = prometheus::TextEncoder::new();
 635    let metric_families = prometheus::gather();
 636    let encoded_metrics = encoder
 637        .encode_to_string(&metric_families)
 638        .map_err(|err| anyhow!("{}", err))?;
 639    Ok(encoded_metrics)
 640}
 641
 642#[instrument(err, skip(executor))]
 643async fn sign_out(session: Session, executor: Executor) -> Result<()> {
 644    session.peer.disconnect(session.connection_id);
 645    session
 646        .connection_pool()
 647        .await
 648        .remove_connection(session.connection_id)?;
 649
 650    if let Ok(mut left_projects) = session
 651        .db()
 652        .await
 653        .connection_lost(session.connection_id)
 654        .await
 655    {
 656        for left_project in mem::take(&mut *left_projects) {
 657            project_left(&left_project, &session);
 658        }
 659    }
 660
 661    executor.sleep(RECONNECTION_TIMEOUT).await;
 662    leave_room_for_session(&session).await.trace_err();
 663
 664    if !session
 665        .connection_pool()
 666        .await
 667        .is_user_online(session.user_id)
 668    {
 669        let db = session.db().await;
 670        if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
 671            room_updated(&room, &session);
 672        }
 673    }
 674    update_user_contacts(session.user_id, &session).await?;
 675
 676    Ok(())
 677}
 678
 679async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 680    response.send(proto::Ack {})?;
 681    Ok(())
 682}
 683
 684async fn create_room(
 685    _request: proto::CreateRoom,
 686    response: Response<proto::CreateRoom>,
 687    session: Session,
 688) -> Result<()> {
 689    let live_kit_room = nanoid::nanoid!(30);
 690    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 691        if let Some(_) = live_kit
 692            .create_room(live_kit_room.clone())
 693            .await
 694            .trace_err()
 695        {
 696            if let Some(token) = live_kit
 697                .room_token(&live_kit_room, &session.connection_id.to_string())
 698                .trace_err()
 699            {
 700                Some(proto::LiveKitConnectionInfo {
 701                    server_url: live_kit.url().into(),
 702                    token,
 703                })
 704            } else {
 705                None
 706            }
 707        } else {
 708            None
 709        }
 710    } else {
 711        None
 712    };
 713
 714    {
 715        let room = session
 716            .db()
 717            .await
 718            .create_room(session.user_id, session.connection_id, &live_kit_room)
 719            .await?;
 720
 721        response.send(proto::CreateRoomResponse {
 722            room: Some(room.clone()),
 723            live_kit_connection_info,
 724        })?;
 725    }
 726
 727    update_user_contacts(session.user_id, &session).await?;
 728    Ok(())
 729}
 730
 731async fn join_room(
 732    request: proto::JoinRoom,
 733    response: Response<proto::JoinRoom>,
 734    session: Session,
 735) -> Result<()> {
 736    let room = {
 737        let room = session
 738            .db()
 739            .await
 740            .join_room(
 741                RoomId::from_proto(request.id),
 742                session.user_id,
 743                session.connection_id,
 744            )
 745            .await?;
 746        room_updated(&room, &session);
 747        room.clone()
 748    };
 749
 750    for connection_id in session
 751        .connection_pool()
 752        .await
 753        .user_connection_ids(session.user_id)
 754    {
 755        session
 756            .peer
 757            .send(connection_id, proto::CallCanceled {})
 758            .trace_err();
 759    }
 760
 761    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 762        if let Some(token) = live_kit
 763            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 764            .trace_err()
 765        {
 766            Some(proto::LiveKitConnectionInfo {
 767                server_url: live_kit.url().into(),
 768                token,
 769            })
 770        } else {
 771            None
 772        }
 773    } else {
 774        None
 775    };
 776
 777    response.send(proto::JoinRoomResponse {
 778        room: Some(room),
 779        live_kit_connection_info,
 780    })?;
 781
 782    update_user_contacts(session.user_id, &session).await?;
 783    Ok(())
 784}
 785
/// Handles a `LeaveRoom` message by delegating to `leave_room_for_session` for the sending
/// connection.
async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
    leave_room_for_session(&session).await
}
 789
 790async fn call(
 791    request: proto::Call,
 792    response: Response<proto::Call>,
 793    session: Session,
 794) -> Result<()> {
 795    let room_id = RoomId::from_proto(request.room_id);
 796    let calling_user_id = session.user_id;
 797    let calling_connection_id = session.connection_id;
 798    let called_user_id = UserId::from_proto(request.called_user_id);
 799    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 800    if !session
 801        .db()
 802        .await
 803        .has_contact(calling_user_id, called_user_id)
 804        .await?
 805    {
 806        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 807    }
 808
 809    let incoming_call = {
 810        let (room, incoming_call) = &mut *session
 811            .db()
 812            .await
 813            .call(
 814                room_id,
 815                calling_user_id,
 816                calling_connection_id,
 817                called_user_id,
 818                initial_project_id,
 819            )
 820            .await?;
 821        room_updated(&room, &session);
 822        mem::take(incoming_call)
 823    };
 824    update_user_contacts(called_user_id, &session).await?;
 825
 826    let mut calls = session
 827        .connection_pool()
 828        .await
 829        .user_connection_ids(called_user_id)
 830        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 831        .collect::<FuturesUnordered<_>>();
 832
 833    while let Some(call_response) = calls.next().await {
 834        match call_response.as_ref() {
 835            Ok(_) => {
 836                response.send(proto::Ack {})?;
 837                return Ok(());
 838            }
 839            Err(_) => {
 840                call_response.trace_err();
 841            }
 842        }
 843    }
 844
 845    {
 846        let room = session
 847            .db()
 848            .await
 849            .call_failed(room_id, called_user_id)
 850            .await?;
 851        room_updated(&room, &session);
 852    }
 853    update_user_contacts(called_user_id, &session).await?;
 854
 855    Err(anyhow!("failed to ring user"))?
 856}
 857
 858async fn cancel_call(
 859    request: proto::CancelCall,
 860    response: Response<proto::CancelCall>,
 861    session: Session,
 862) -> Result<()> {
 863    let called_user_id = UserId::from_proto(request.called_user_id);
 864    let room_id = RoomId::from_proto(request.room_id);
 865    {
 866        let room = session
 867            .db()
 868            .await
 869            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 870            .await?;
 871        room_updated(&room, &session);
 872    }
 873
 874    for connection_id in session
 875        .connection_pool()
 876        .await
 877        .user_connection_ids(called_user_id)
 878    {
 879        session
 880            .peer
 881            .send(connection_id, proto::CallCanceled {})
 882            .trace_err();
 883    }
 884    response.send(proto::Ack {})?;
 885
 886    update_user_contacts(called_user_id, &session).await?;
 887    Ok(())
 888}
 889
 890async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
 891    let room_id = RoomId::from_proto(message.room_id);
 892    {
 893        let room = session
 894            .db()
 895            .await
 896            .decline_call(Some(room_id), session.user_id)
 897            .await?;
 898        room_updated(&room, &session);
 899    }
 900
 901    for connection_id in session
 902        .connection_pool()
 903        .await
 904        .user_connection_ids(session.user_id)
 905    {
 906        session
 907            .peer
 908            .send(connection_id, proto::CallCanceled {})
 909            .trace_err();
 910    }
 911    update_user_contacts(session.user_id, &session).await?;
 912    Ok(())
 913}
 914
 915async fn update_participant_location(
 916    request: proto::UpdateParticipantLocation,
 917    response: Response<proto::UpdateParticipantLocation>,
 918    session: Session,
 919) -> Result<()> {
 920    let room_id = RoomId::from_proto(request.room_id);
 921    let location = request
 922        .location
 923        .ok_or_else(|| anyhow!("invalid location"))?;
 924    let room = session
 925        .db()
 926        .await
 927        .update_room_participant_location(room_id, session.connection_id, location)
 928        .await?;
 929    room_updated(&room, &session);
 930    response.send(proto::Ack {})?;
 931    Ok(())
 932}
 933
 934async fn share_project(
 935    request: proto::ShareProject,
 936    response: Response<proto::ShareProject>,
 937    session: Session,
 938) -> Result<()> {
 939    let (project_id, room) = &*session
 940        .db()
 941        .await
 942        .share_project(
 943            RoomId::from_proto(request.room_id),
 944            session.connection_id,
 945            &request.worktrees,
 946        )
 947        .await?;
 948    response.send(proto::ShareProjectResponse {
 949        project_id: project_id.to_proto(),
 950    })?;
 951    room_updated(&room, &session);
 952
 953    Ok(())
 954}
 955
 956async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
 957    let project_id = ProjectId::from_proto(message.project_id);
 958
 959    let (room, guest_connection_ids) = &*session
 960        .db()
 961        .await
 962        .unshare_project(project_id, session.connection_id)
 963        .await?;
 964
 965    broadcast(
 966        session.connection_id,
 967        guest_connection_ids.iter().copied(),
 968        |conn_id| session.peer.send(conn_id, message.clone()),
 969    );
 970    room_updated(&room, &session);
 971
 972    Ok(())
 973}
 974
 975async fn join_project(
 976    request: proto::JoinProject,
 977    response: Response<proto::JoinProject>,
 978    session: Session,
 979) -> Result<()> {
 980    let project_id = ProjectId::from_proto(request.project_id);
 981    let guest_user_id = session.user_id;
 982
 983    tracing::info!(%project_id, "join project");
 984
 985    let (project, replica_id) = &mut *session
 986        .db()
 987        .await
 988        .join_project(project_id, session.connection_id)
 989        .await?;
 990
 991    let collaborators = project
 992        .collaborators
 993        .iter()
 994        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
 995        .map(|collaborator| proto::Collaborator {
 996            peer_id: collaborator.connection_id as u32,
 997            replica_id: collaborator.replica_id.0 as u32,
 998            user_id: collaborator.user_id.to_proto(),
 999        })
1000        .collect::<Vec<_>>();
1001    let worktrees = project
1002        .worktrees
1003        .iter()
1004        .map(|(id, worktree)| proto::WorktreeMetadata {
1005            id: *id,
1006            root_name: worktree.root_name.clone(),
1007            visible: worktree.visible,
1008            abs_path: worktree.abs_path.clone(),
1009        })
1010        .collect::<Vec<_>>();
1011
1012    for collaborator in &collaborators {
1013        session
1014            .peer
1015            .send(
1016                ConnectionId(collaborator.peer_id),
1017                proto::AddProjectCollaborator {
1018                    project_id: project_id.to_proto(),
1019                    collaborator: Some(proto::Collaborator {
1020                        peer_id: session.connection_id.0,
1021                        replica_id: replica_id.0 as u32,
1022                        user_id: guest_user_id.to_proto(),
1023                    }),
1024                },
1025            )
1026            .trace_err();
1027    }
1028
1029    // First, we send the metadata associated with each worktree.
1030    response.send(proto::JoinProjectResponse {
1031        worktrees: worktrees.clone(),
1032        replica_id: replica_id.0 as u32,
1033        collaborators: collaborators.clone(),
1034        language_servers: project.language_servers.clone(),
1035    })?;
1036
1037    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1038        #[cfg(any(test, feature = "test-support"))]
1039        const MAX_CHUNK_SIZE: usize = 2;
1040        #[cfg(not(any(test, feature = "test-support")))]
1041        const MAX_CHUNK_SIZE: usize = 256;
1042
1043        // Stream this worktree's entries.
1044        let message = proto::UpdateWorktree {
1045            project_id: project_id.to_proto(),
1046            worktree_id,
1047            abs_path: worktree.abs_path.clone(),
1048            root_name: worktree.root_name,
1049            updated_entries: worktree.entries,
1050            removed_entries: Default::default(),
1051            scan_id: worktree.scan_id,
1052            is_last_update: worktree.is_complete,
1053        };
1054        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1055            session.peer.send(session.connection_id, update.clone())?;
1056        }
1057
1058        // Stream this worktree's diagnostics.
1059        for summary in worktree.diagnostic_summaries {
1060            session.peer.send(
1061                session.connection_id,
1062                proto::UpdateDiagnosticSummary {
1063                    project_id: project_id.to_proto(),
1064                    worktree_id: worktree.id,
1065                    summary: Some(summary),
1066                },
1067            )?;
1068        }
1069    }
1070
1071    for language_server in &project.language_servers {
1072        session.peer.send(
1073            session.connection_id,
1074            proto::UpdateLanguageServer {
1075                project_id: project_id.to_proto(),
1076                language_server_id: language_server.id,
1077                variant: Some(
1078                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1079                        proto::LspDiskBasedDiagnosticsUpdated {},
1080                    ),
1081                ),
1082            },
1083        )?;
1084    }
1085
1086    Ok(())
1087}
1088
1089async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1090    let sender_id = session.connection_id;
1091    let project_id = ProjectId::from_proto(request.project_id);
1092
1093    let project = session
1094        .db()
1095        .await
1096        .leave_project(project_id, sender_id)
1097        .await?;
1098    tracing::info!(
1099        %project_id,
1100        host_user_id = %project.host_user_id,
1101        host_connection_id = %project.host_connection_id,
1102        "leave project"
1103    );
1104    project_left(&project, &session);
1105
1106    Ok(())
1107}
1108
1109async fn update_project(
1110    request: proto::UpdateProject,
1111    response: Response<proto::UpdateProject>,
1112    session: Session,
1113) -> Result<()> {
1114    let project_id = ProjectId::from_proto(request.project_id);
1115    let (room, guest_connection_ids) = &*session
1116        .db()
1117        .await
1118        .update_project(project_id, session.connection_id, &request.worktrees)
1119        .await?;
1120    broadcast(
1121        session.connection_id,
1122        guest_connection_ids.iter().copied(),
1123        |connection_id| {
1124            session
1125                .peer
1126                .forward_send(session.connection_id, connection_id, request.clone())
1127        },
1128    );
1129    room_updated(&room, &session);
1130    response.send(proto::Ack {})?;
1131
1132    Ok(())
1133}
1134
1135async fn update_worktree(
1136    request: proto::UpdateWorktree,
1137    response: Response<proto::UpdateWorktree>,
1138    session: Session,
1139) -> Result<()> {
1140    let guest_connection_ids = session
1141        .db()
1142        .await
1143        .update_worktree(&request, session.connection_id)
1144        .await?;
1145
1146    broadcast(
1147        session.connection_id,
1148        guest_connection_ids.iter().copied(),
1149        |connection_id| {
1150            session
1151                .peer
1152                .forward_send(session.connection_id, connection_id, request.clone())
1153        },
1154    );
1155    response.send(proto::Ack {})?;
1156    Ok(())
1157}
1158
1159async fn update_diagnostic_summary(
1160    message: proto::UpdateDiagnosticSummary,
1161    session: Session,
1162) -> Result<()> {
1163    let guest_connection_ids = session
1164        .db()
1165        .await
1166        .update_diagnostic_summary(&message, session.connection_id)
1167        .await?;
1168
1169    broadcast(
1170        session.connection_id,
1171        guest_connection_ids.iter().copied(),
1172        |connection_id| {
1173            session
1174                .peer
1175                .forward_send(session.connection_id, connection_id, message.clone())
1176        },
1177    );
1178
1179    Ok(())
1180}
1181
1182async fn start_language_server(
1183    request: proto::StartLanguageServer,
1184    session: Session,
1185) -> Result<()> {
1186    let guest_connection_ids = session
1187        .db()
1188        .await
1189        .start_language_server(&request, session.connection_id)
1190        .await?;
1191
1192    broadcast(
1193        session.connection_id,
1194        guest_connection_ids.iter().copied(),
1195        |connection_id| {
1196            session
1197                .peer
1198                .forward_send(session.connection_id, connection_id, request.clone())
1199        },
1200    );
1201    Ok(())
1202}
1203
1204async fn update_language_server(
1205    request: proto::UpdateLanguageServer,
1206    session: Session,
1207) -> Result<()> {
1208    let project_id = ProjectId::from_proto(request.project_id);
1209    let project_connection_ids = session
1210        .db()
1211        .await
1212        .project_connection_ids(project_id, session.connection_id)
1213        .await?;
1214    broadcast(
1215        session.connection_id,
1216        project_connection_ids.iter().copied(),
1217        |connection_id| {
1218            session
1219                .peer
1220                .forward_send(session.connection_id, connection_id, request.clone())
1221        },
1222    );
1223    Ok(())
1224}
1225
1226async fn forward_project_request<T>(
1227    request: T,
1228    response: Response<T>,
1229    session: Session,
1230) -> Result<()>
1231where
1232    T: EntityMessage + RequestMessage,
1233{
1234    let project_id = ProjectId::from_proto(request.remote_entity_id());
1235    let host_connection_id = {
1236        let collaborators = session
1237            .db()
1238            .await
1239            .project_collaborators(project_id, session.connection_id)
1240            .await?;
1241        ConnectionId(
1242            collaborators
1243                .iter()
1244                .find(|collaborator| collaborator.is_host)
1245                .ok_or_else(|| anyhow!("host not found"))?
1246                .connection_id as u32,
1247        )
1248    };
1249
1250    let payload = session
1251        .peer
1252        .forward_request(session.connection_id, host_connection_id, request)
1253        .await?;
1254
1255    response.send(payload)?;
1256    Ok(())
1257}
1258
1259async fn save_buffer(
1260    request: proto::SaveBuffer,
1261    response: Response<proto::SaveBuffer>,
1262    session: Session,
1263) -> Result<()> {
1264    let project_id = ProjectId::from_proto(request.project_id);
1265    let host_connection_id = {
1266        let collaborators = session
1267            .db()
1268            .await
1269            .project_collaborators(project_id, session.connection_id)
1270            .await?;
1271        let host = collaborators
1272            .iter()
1273            .find(|collaborator| collaborator.is_host)
1274            .ok_or_else(|| anyhow!("host not found"))?;
1275        ConnectionId(host.connection_id as u32)
1276    };
1277    let response_payload = session
1278        .peer
1279        .forward_request(session.connection_id, host_connection_id, request.clone())
1280        .await?;
1281
1282    let mut collaborators = session
1283        .db()
1284        .await
1285        .project_collaborators(project_id, session.connection_id)
1286        .await?;
1287    collaborators
1288        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1289    let project_connection_ids = collaborators
1290        .iter()
1291        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1292    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1293        session
1294            .peer
1295            .forward_send(host_connection_id, conn_id, response_payload.clone())
1296    });
1297    response.send(response_payload)?;
1298    Ok(())
1299}
1300
1301async fn create_buffer_for_peer(
1302    request: proto::CreateBufferForPeer,
1303    session: Session,
1304) -> Result<()> {
1305    session.peer.forward_send(
1306        session.connection_id,
1307        ConnectionId(request.peer_id),
1308        request,
1309    )?;
1310    Ok(())
1311}
1312
1313async fn update_buffer(
1314    request: proto::UpdateBuffer,
1315    response: Response<proto::UpdateBuffer>,
1316    session: Session,
1317) -> Result<()> {
1318    let project_id = ProjectId::from_proto(request.project_id);
1319    let project_connection_ids = session
1320        .db()
1321        .await
1322        .project_connection_ids(project_id, session.connection_id)
1323        .await?;
1324
1325    broadcast(
1326        session.connection_id,
1327        project_connection_ids.iter().copied(),
1328        |connection_id| {
1329            session
1330                .peer
1331                .forward_send(session.connection_id, connection_id, request.clone())
1332        },
1333    );
1334    response.send(proto::Ack {})?;
1335    Ok(())
1336}
1337
1338async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1339    let project_id = ProjectId::from_proto(request.project_id);
1340    let project_connection_ids = session
1341        .db()
1342        .await
1343        .project_connection_ids(project_id, session.connection_id)
1344        .await?;
1345
1346    broadcast(
1347        session.connection_id,
1348        project_connection_ids.iter().copied(),
1349        |connection_id| {
1350            session
1351                .peer
1352                .forward_send(session.connection_id, connection_id, request.clone())
1353        },
1354    );
1355    Ok(())
1356}
1357
1358async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1359    let project_id = ProjectId::from_proto(request.project_id);
1360    let project_connection_ids = session
1361        .db()
1362        .await
1363        .project_connection_ids(project_id, session.connection_id)
1364        .await?;
1365    broadcast(
1366        session.connection_id,
1367        project_connection_ids.iter().copied(),
1368        |connection_id| {
1369            session
1370                .peer
1371                .forward_send(session.connection_id, connection_id, request.clone())
1372        },
1373    );
1374    Ok(())
1375}
1376
1377async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1378    let project_id = ProjectId::from_proto(request.project_id);
1379    let project_connection_ids = session
1380        .db()
1381        .await
1382        .project_connection_ids(project_id, session.connection_id)
1383        .await?;
1384    broadcast(
1385        session.connection_id,
1386        project_connection_ids.iter().copied(),
1387        |connection_id| {
1388            session
1389                .peer
1390                .forward_send(session.connection_id, connection_id, request.clone())
1391        },
1392    );
1393    Ok(())
1394}
1395
1396async fn follow(
1397    request: proto::Follow,
1398    response: Response<proto::Follow>,
1399    session: Session,
1400) -> Result<()> {
1401    let project_id = ProjectId::from_proto(request.project_id);
1402    let leader_id = ConnectionId(request.leader_id);
1403    let follower_id = session.connection_id;
1404    {
1405        let project_connection_ids = session
1406            .db()
1407            .await
1408            .project_connection_ids(project_id, session.connection_id)
1409            .await?;
1410
1411        if !project_connection_ids.contains(&leader_id) {
1412            Err(anyhow!("no such peer"))?;
1413        }
1414    }
1415
1416    let mut response_payload = session
1417        .peer
1418        .forward_request(session.connection_id, leader_id, request)
1419        .await?;
1420    response_payload
1421        .views
1422        .retain(|view| view.leader_id != Some(follower_id.0));
1423    response.send(response_payload)?;
1424    Ok(())
1425}
1426
1427async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1428    let project_id = ProjectId::from_proto(request.project_id);
1429    let leader_id = ConnectionId(request.leader_id);
1430    let project_connection_ids = session
1431        .db()
1432        .await
1433        .project_connection_ids(project_id, session.connection_id)
1434        .await?;
1435    if !project_connection_ids.contains(&leader_id) {
1436        Err(anyhow!("no such peer"))?;
1437    }
1438    session
1439        .peer
1440        .forward_send(session.connection_id, leader_id, request)?;
1441    Ok(())
1442}
1443
1444async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1445    let project_id = ProjectId::from_proto(request.project_id);
1446    let project_connection_ids = session
1447        .db
1448        .lock()
1449        .await
1450        .project_connection_ids(project_id, session.connection_id)
1451        .await?;
1452
1453    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1454        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1455        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1456        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1457    });
1458    for follower_id in &request.follower_ids {
1459        let follower_id = ConnectionId(*follower_id);
1460        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1461            session
1462                .peer
1463                .forward_send(session.connection_id, follower_id, request.clone())?;
1464        }
1465    }
1466    Ok(())
1467}
1468
1469async fn get_users(
1470    request: proto::GetUsers,
1471    response: Response<proto::GetUsers>,
1472    session: Session,
1473) -> Result<()> {
1474    let user_ids = request
1475        .user_ids
1476        .into_iter()
1477        .map(UserId::from_proto)
1478        .collect();
1479    let users = session
1480        .db()
1481        .await
1482        .get_users_by_ids(user_ids)
1483        .await?
1484        .into_iter()
1485        .map(|user| proto::User {
1486            id: user.id.to_proto(),
1487            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1488            github_login: user.github_login,
1489        })
1490        .collect();
1491    response.send(proto::UsersResponse { users })?;
1492    Ok(())
1493}
1494
1495async fn fuzzy_search_users(
1496    request: proto::FuzzySearchUsers,
1497    response: Response<proto::FuzzySearchUsers>,
1498    session: Session,
1499) -> Result<()> {
1500    let query = request.query;
1501    let users = match query.len() {
1502        0 => vec![],
1503        1 | 2 => session
1504            .db()
1505            .await
1506            .get_user_by_github_account(&query, None)
1507            .await?
1508            .into_iter()
1509            .collect(),
1510        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1511    };
1512    let users = users
1513        .into_iter()
1514        .filter(|user| user.id != session.user_id)
1515        .map(|user| proto::User {
1516            id: user.id.to_proto(),
1517            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1518            github_login: user.github_login,
1519        })
1520        .collect();
1521    response.send(proto::UsersResponse { users })?;
1522    Ok(())
1523}
1524
1525async fn request_contact(
1526    request: proto::RequestContact,
1527    response: Response<proto::RequestContact>,
1528    session: Session,
1529) -> Result<()> {
1530    let requester_id = session.user_id;
1531    let responder_id = UserId::from_proto(request.responder_id);
1532    if requester_id == responder_id {
1533        return Err(anyhow!("cannot add yourself as a contact"))?;
1534    }
1535
1536    session
1537        .db()
1538        .await
1539        .send_contact_request(requester_id, responder_id)
1540        .await?;
1541
1542    // Update outgoing contact requests of requester
1543    let mut update = proto::UpdateContacts::default();
1544    update.outgoing_requests.push(responder_id.to_proto());
1545    for connection_id in session
1546        .connection_pool()
1547        .await
1548        .user_connection_ids(requester_id)
1549    {
1550        session.peer.send(connection_id, update.clone())?;
1551    }
1552
1553    // Update incoming contact requests of responder
1554    let mut update = proto::UpdateContacts::default();
1555    update
1556        .incoming_requests
1557        .push(proto::IncomingContactRequest {
1558            requester_id: requester_id.to_proto(),
1559            should_notify: true,
1560        });
1561    for connection_id in session
1562        .connection_pool()
1563        .await
1564        .user_connection_ids(responder_id)
1565    {
1566        session.peer.send(connection_id, update.clone())?;
1567    }
1568
1569    response.send(proto::Ack {})?;
1570    Ok(())
1571}
1572
1573async fn respond_to_contact_request(
1574    request: proto::RespondToContactRequest,
1575    response: Response<proto::RespondToContactRequest>,
1576    session: Session,
1577) -> Result<()> {
1578    let responder_id = session.user_id;
1579    let requester_id = UserId::from_proto(request.requester_id);
1580    let db = session.db().await;
1581    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1582        db.dismiss_contact_notification(responder_id, requester_id)
1583            .await?;
1584    } else {
1585        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1586
1587        db.respond_to_contact_request(responder_id, requester_id, accept)
1588            .await?;
1589        let requester_busy = db.is_user_busy(requester_id).await?;
1590        let responder_busy = db.is_user_busy(responder_id).await?;
1591
1592        let pool = session.connection_pool().await;
1593        // Update responder with new contact
1594        let mut update = proto::UpdateContacts::default();
1595        if accept {
1596            update
1597                .contacts
1598                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1599        }
1600        update
1601            .remove_incoming_requests
1602            .push(requester_id.to_proto());
1603        for connection_id in pool.user_connection_ids(responder_id) {
1604            session.peer.send(connection_id, update.clone())?;
1605        }
1606
1607        // Update requester with new contact
1608        let mut update = proto::UpdateContacts::default();
1609        if accept {
1610            update
1611                .contacts
1612                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1613        }
1614        update
1615            .remove_outgoing_requests
1616            .push(responder_id.to_proto());
1617        for connection_id in pool.user_connection_ids(requester_id) {
1618            session.peer.send(connection_id, update.clone())?;
1619        }
1620    }
1621
1622    response.send(proto::Ack {})?;
1623    Ok(())
1624}
1625
1626async fn remove_contact(
1627    request: proto::RemoveContact,
1628    response: Response<proto::RemoveContact>,
1629    session: Session,
1630) -> Result<()> {
1631    let requester_id = session.user_id;
1632    let responder_id = UserId::from_proto(request.user_id);
1633    let db = session.db().await;
1634    db.remove_contact(requester_id, responder_id).await?;
1635
1636    let pool = session.connection_pool().await;
1637    // Update outgoing contact requests of requester
1638    let mut update = proto::UpdateContacts::default();
1639    update
1640        .remove_outgoing_requests
1641        .push(responder_id.to_proto());
1642    for connection_id in pool.user_connection_ids(requester_id) {
1643        session.peer.send(connection_id, update.clone())?;
1644    }
1645
1646    // Update incoming contact requests of responder
1647    let mut update = proto::UpdateContacts::default();
1648    update
1649        .remove_incoming_requests
1650        .push(requester_id.to_proto());
1651    for connection_id in pool.user_connection_ids(responder_id) {
1652        session.peer.send(connection_id, update.clone())?;
1653    }
1654
1655    response.send(proto::Ack {})?;
1656    Ok(())
1657}
1658
1659async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1660    let project_id = ProjectId::from_proto(request.project_id);
1661    let project_connection_ids = session
1662        .db()
1663        .await
1664        .project_connection_ids(project_id, session.connection_id)
1665        .await?;
1666    broadcast(
1667        session.connection_id,
1668        project_connection_ids.iter().copied(),
1669        |connection_id| {
1670            session
1671                .peer
1672                .forward_send(session.connection_id, connection_id, request.clone())
1673        },
1674    );
1675    Ok(())
1676}
1677
1678async fn get_private_user_info(
1679    _request: proto::GetPrivateUserInfo,
1680    response: Response<proto::GetPrivateUserInfo>,
1681    session: Session,
1682) -> Result<()> {
1683    let metrics_id = session
1684        .db()
1685        .await
1686        .get_user_metrics_id(session.user_id)
1687        .await?;
1688    let user = session
1689        .db()
1690        .await
1691        .get_user_by_id(session.user_id)
1692        .await?
1693        .ok_or_else(|| anyhow!("user not found"))?;
1694    response.send(proto::GetPrivateUserInfoResponse {
1695        metrics_id,
1696        staff: user.admin,
1697    })?;
1698    Ok(())
1699}
1700
1701fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1702    match message {
1703        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1704        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1705        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1706        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1707        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1708            code: frame.code.into(),
1709            reason: frame.reason,
1710        })),
1711    }
1712}
1713
1714fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1715    match message {
1716        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1717        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1718        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1719        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1720        AxumMessage::Close(frame) => {
1721            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1722                code: frame.code.into(),
1723                reason: frame.reason,
1724            }))
1725        }
1726    }
1727}
1728
1729fn build_initial_contacts_update(
1730    contacts: Vec<db::Contact>,
1731    pool: &ConnectionPool,
1732) -> proto::UpdateContacts {
1733    let mut update = proto::UpdateContacts::default();
1734
1735    for contact in contacts {
1736        match contact {
1737            db::Contact::Accepted {
1738                user_id,
1739                should_notify,
1740                busy,
1741            } => {
1742                update
1743                    .contacts
1744                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1745            }
1746            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1747            db::Contact::Incoming {
1748                user_id,
1749                should_notify,
1750            } => update
1751                .incoming_requests
1752                .push(proto::IncomingContactRequest {
1753                    requester_id: user_id.to_proto(),
1754                    should_notify,
1755                }),
1756        }
1757    }
1758
1759    update
1760}
1761
1762fn contact_for_user(
1763    user_id: UserId,
1764    should_notify: bool,
1765    busy: bool,
1766    pool: &ConnectionPool,
1767) -> proto::Contact {
1768    proto::Contact {
1769        user_id: user_id.to_proto(),
1770        online: pool.is_user_online(user_id),
1771        busy,
1772        should_notify,
1773    }
1774}
1775
1776fn room_updated(room: &proto::Room, session: &Session) {
1777    for participant in &room.participants {
1778        session
1779            .peer
1780            .send(
1781                ConnectionId(participant.peer_id),
1782                proto::RoomUpdated {
1783                    room: Some(room.clone()),
1784                },
1785            )
1786            .trace_err();
1787    }
1788}
1789
1790async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1791    let db = session.db().await;
1792    let contacts = db.get_contacts(user_id).await?;
1793    let busy = db.is_user_busy(user_id).await?;
1794
1795    let pool = session.connection_pool().await;
1796    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1797    for contact in contacts {
1798        if let db::Contact::Accepted {
1799            user_id: contact_user_id,
1800            ..
1801        } = contact
1802        {
1803            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1804                session
1805                    .peer
1806                    .send(
1807                        contact_conn_id,
1808                        proto::UpdateContacts {
1809                            contacts: vec![updated_contact.clone()],
1810                            remove_contacts: Default::default(),
1811                            incoming_requests: Default::default(),
1812                            remove_incoming_requests: Default::default(),
1813                            outgoing_requests: Default::default(),
1814                            remove_outgoing_requests: Default::default(),
1815                        },
1816                    )
1817                    .trace_err();
1818            }
1819        }
1820    }
1821    Ok(())
1822}
1823
1824async fn leave_room_for_session(session: &Session) -> Result<()> {
1825    let mut contacts_to_update = HashSet::default();
1826
1827    let canceled_calls_to_user_ids;
1828    let live_kit_room;
1829    let delete_live_kit_room;
1830    {
1831        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1832        contacts_to_update.insert(session.user_id);
1833
1834        for project in left_room.left_projects.values() {
1835            project_left(project, session);
1836        }
1837
1838        room_updated(&left_room.room, &session);
1839        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1840        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1841        delete_live_kit_room = left_room.room.participants.is_empty();
1842    }
1843
1844    {
1845        let pool = session.connection_pool().await;
1846        for canceled_user_id in canceled_calls_to_user_ids {
1847            for connection_id in pool.user_connection_ids(canceled_user_id) {
1848                session
1849                    .peer
1850                    .send(connection_id, proto::CallCanceled {})
1851                    .trace_err();
1852            }
1853            contacts_to_update.insert(canceled_user_id);
1854        }
1855    }
1856
1857    for contact_user_id in contacts_to_update {
1858        update_user_contacts(contact_user_id, &session).await?;
1859    }
1860
1861    if let Some(live_kit) = session.live_kit_client.as_ref() {
1862        live_kit
1863            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1864            .await
1865            .trace_err();
1866
1867        if delete_live_kit_room {
1868            live_kit.delete_room(live_kit_room).await.trace_err();
1869        }
1870    }
1871
1872    Ok(())
1873}
1874
1875fn project_left(project: &db::LeftProject, session: &Session) {
1876    for connection_id in &project.connection_ids {
1877        if project.host_user_id == session.user_id {
1878            session
1879                .peer
1880                .send(
1881                    *connection_id,
1882                    proto::UnshareProject {
1883                        project_id: project.id.to_proto(),
1884                    },
1885                )
1886                .trace_err();
1887        } else {
1888            session
1889                .peer
1890                .send(
1891                    *connection_id,
1892                    proto::RemoveProjectCollaborator {
1893                        project_id: project.id.to_proto(),
1894                        peer_id: session.connection_id.0,
1895                    },
1896                )
1897                .trace_err();
1898        }
1899    }
1900
1901    session
1902        .peer
1903        .send(
1904            session.connection_id,
1905            proto::UnshareProject {
1906                project_id: project.id.to_proto(),
1907            },
1908        )
1909        .trace_err();
1910}
1911
1912pub trait ResultExt {
1913    type Ok;
1914
1915    fn trace_err(self) -> Option<Self::Ok>;
1916}
1917
1918impl<T, E> ResultExt for Result<T, E>
1919where
1920    E: std::fmt::Debug,
1921{
1922    type Ok = T;
1923
1924    fn trace_err(self) -> Option<T> {
1925        match self {
1926            Ok(value) => Some(value),
1927            Err(error) => {
1928                tracing::error!("{:?}", error);
1929                None
1930            }
1931        }
1932    }
1933}