rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
/// Reconnection grace period.
///
/// NOTE(review): not referenced in this part of the file — presumably how long a dropped
/// client may take to reconnect before its state is torn down; confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before cleaning up the rooms returned by `stale_room_ids`
/// and deleting stale server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
// Prometheus gauges. `lazy_static` registers each gauge on first access; the `unwrap`s
// panic at that point if registration fails (e.g. the metric name is already registered).
lazy_static! {
    /// Gauge registered as `connections` ("number of connections").
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge registered as `shared_projects` ("number of open projects with one or more guests").
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  75
/// Type-erased message handler stored in `Server::handlers`, keyed by the payload's `TypeId`.
///
/// It receives the boxed envelope plus the sender's [`Session`] and returns a boxed
/// `'static` future that performs the handling (see `Server::add_handler`).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
/// Reply handle passed to request handlers registered with `Server::add_request_handler`.
///
/// `send` consumes the handle, so a handler can reply at most once through it.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set by `send`; shared with the dispatcher in `add_request_handler` so it can tell
    /// whether the handler replied.
    responded: Arc<AtomicBool>,
}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
/// Per-connection context handed (by clone) to every message handler for that connection.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. It is created once per connection in
    /// `handle_connection` and shared by all clones of this session.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to clients.
    peer: Arc<Peer>,
    /// The server-wide connection pool (the same `Arc` held by `Server`).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client copied from `AppState`, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor used for spawning and timers on this connection.
    executor: Executor,
}
 103
impl Session {
    /// Locks this session's database handle and returns the guard.
    ///
    /// In test builds the task yields both before and after acquiring the lock.
    /// NOTE(review): presumably these extra scheduling points help tests exercise more
    /// interleavings — confirm.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and wraps the guard in a [`ConnectionPoolGuard`].
    ///
    /// The lock is a blocking `parking_lot` mutex. The wrapper carries `PhantomData<Rc<()>>`,
    /// which makes it `!Send`. A handler future that keeps it alive across an `.await`
    /// therefore cannot meet the `Send` bound that `add_handler` places on handler futures.
    /// In test builds the task yields once before locking.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
/// Newtype around the shared [`Database`]; dereferences to `Database` (see its `Deref` impl).
struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// The RPC server: owns the peer, the connection pool, and the handler table.
pub struct Server {
    /// This server instance's id. It sits behind a mutex so the test-only `reset` can
    /// replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns all client connections.
    peer: Arc<Peer>,
    /// Connection pool shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for detached tasks and timers.
    executor: Executor,
    /// Handler per message payload `TypeId`, filled in by `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each connection subscribes in `handle_connection` and exits
    /// its message loop when this fires.
    teardown: watch::Sender<()>,
}
 153
/// Lock guard over the shared [`ConnectionPool`] that dereferences to the pool.
///
/// In test builds its `Drop` impl checks the pool's invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, so this marker makes the guard `!Send` as well.
    _not_send: PhantomData<Rc<()>>,
}
 158
/// Serializable view of the server's peer and (locked) connection pool, produced by
/// `Server::snapshot`. The pool lock stays held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Request handlers (`add_request_handler`) must reply through their `Response`.
    /// Message handlers (`add_message_handler`) are fire-and-forget. Many project requests
    /// are routed through the generic `forward_project_request::<T>` handler.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `id.0 as u32` truncates silently if the id ever exceeds u32.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects and worktrees.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests handled by the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels.
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }
 261
    /// Schedules cleanup of state left over by earlier server instances, then returns `Ok`
    /// right away.
    ///
    /// The work runs on a detached task. After [`CLEANUP_TIMEOUT`] it does the following for
    /// each room returned by `stale_room_ids` (for this environment and server id):
    ///
    /// 1. Refresh the room and broadcast the room update (plus a channel update if the room
    ///    belongs to a channel).
    /// 2. Send `CallCanceled` to every connection of users whose calls were canceled.
    /// 3. Push a contact update for each affected user to their accepted contacts.
    /// 4. Delete the LiveKit room if the refreshed room has no participants.
    ///
    /// Finally it calls `delete_stale_servers`. Database and send failures are passed
    /// through `trace_err`, so they are reported there and the rest of the cleanup continues.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The timer is created now; the detached task below awaits it.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Gathered from the refreshed room, then acted on below. They stay
                        // empty / false if the refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only rooms left with no participants get their LiveKit room deleted.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either query failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 386
 387    pub fn teardown(&self) {
 388        self.peer.teardown();
 389        self.connection_pool.lock().reset();
 390        let _ = self.teardown.send(());
 391    }
 392
 393    #[cfg(test)]
 394    pub fn reset(&self, id: ServerId) {
 395        self.teardown();
 396        *self.id.lock() = id;
 397        self.peer.reset(id.0 as u32);
 398    }
 399
 400    #[cfg(test)]
 401    pub fn id(&self) -> ServerId {
 402        *self.id.lock()
 403    }
 404
 405    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 406    where
 407        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 408        Fut: 'static + Send + Future<Output = Result<()>>,
 409        M: EnvelopedMessage,
 410    {
 411        let prev_handler = self.handlers.insert(
 412            TypeId::of::<M>(),
 413            Box::new(move |envelope, session| {
 414                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 415                let span = info_span!(
 416                    "handle message",
 417                    payload_type = envelope.payload_type_name()
 418                );
 419                span.in_scope(|| {
 420                    tracing::info!(
 421                        payload_type = envelope.payload_type_name(),
 422                        "message received"
 423                    );
 424                });
 425                let start_time = Instant::now();
 426                let future = (handler)(*envelope, session);
 427                async move {
 428                    let result = future.await;
 429                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 430                    match result {
 431                        Err(error) => {
 432                            tracing::error!(%error, ?duration_ms, "error handling message")
 433                        }
 434                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 435                    }
 436                }
 437                .instrument(span)
 438                .boxed()
 439            }),
 440        );
 441        if prev_handler.is_some() {
 442            panic!("registered a handler for the same message twice");
 443        }
 444        self
 445    }
 446
 447    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 448    where
 449        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 450        Fut: 'static + Send + Future<Output = Result<()>>,
 451        M: EnvelopedMessage,
 452    {
 453        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 454        self
 455    }
 456
 457    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 458    where
 459        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 460        Fut: Send + Future<Output = Result<()>>,
 461        M: RequestMessage,
 462    {
 463        let handler = Arc::new(handler);
 464        self.add_handler(move |envelope, session| {
 465            let receipt = envelope.receipt();
 466            let handler = handler.clone();
 467            async move {
 468                let peer = session.peer.clone();
 469                let responded = Arc::new(AtomicBool::default());
 470                let response = Response {
 471                    peer: peer.clone(),
 472                    responded: responded.clone(),
 473                    receipt,
 474                };
 475                match (handler)(envelope.payload, response, session).await {
 476                    Ok(()) => {
 477                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 478                            Ok(())
 479                        } else {
 480                            Err(anyhow!("handler did not send a response"))?
 481                        }
 482                    }
 483                    Err(error) => {
 484                        peer.respond_with_error(
 485                            receipt,
 486                            proto::Error {
 487                                message: error.to_string(),
 488                            },
 489                        )?;
 490                        Err(error)
 491                    }
 492                }
 493            }
 494        })
 495    }
 496
    /// Drives one authenticated client connection until it closes or the server tears down.
    ///
    /// Setup, in order:
    ///
    /// 1. Register the connection with the peer and send `Hello`.
    /// 2. Report the new `ConnectionId` through `send_connection_id`, if one was given.
    /// 3. On a user's first connection, send `ShowContacts`.
    /// 4. Load contacts, invite code, channels and channel invites concurrently.
    /// 5. Add the connection to the pool and send the initial contact, channel and invite
    ///    updates.
    /// 6. Send any pending incoming call and update the user's contacts.
    ///
    /// Message loop: incoming messages are dispatched to `handlers`, with at most 256 handler
    /// futures in flight (bounded by the semaphore). Background messages run on detached
    /// tasks; foreground ones are polled here via `foreground_message_handlers`.
    ///
    /// The loop ends when the I/O future completes or the incoming stream ends; then
    /// `connection_lost` runs. If the teardown watch fires instead, the future returns
    /// `Ok(())` immediately without calling `connection_lost`.
    ///
    /// NOTE(review): a `?` failure during setup (including after `pool.add_connection`)
    /// also returns before `connection_lost` — confirm the pool entry is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future starts so a teardown signalled afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; that is not an error for this connection.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is held only for this synchronous block (no awaits inside).
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken before reading each message and released when that
                // message's handler finishes, which caps in-flight handlers at 256.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms top to bottom: teardown, then I/O, then
                // running foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before building the async wrapper; the wrapper
                                // re-enters it via `.instrument(span)` whenever it is polled.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Abandon any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 636
 637    pub async fn invite_code_redeemed(
 638        self: &Arc<Self>,
 639        inviter_id: UserId,
 640        invitee_id: UserId,
 641    ) -> Result<()> {
 642        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 643            if let Some(code) = &user.invite_code {
 644                let pool = self.connection_pool.lock();
 645                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 646                for connection_id in pool.user_connection_ids(inviter_id) {
 647                    self.peer.send(
 648                        connection_id,
 649                        proto::UpdateContacts {
 650                            contacts: vec![invitee_contact.clone()],
 651                            ..Default::default()
 652                        },
 653                    )?;
 654                    self.peer.send(
 655                        connection_id,
 656                        proto::UpdateInviteInfo {
 657                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 658                            count: user.invite_count as u32,
 659                        },
 660                    )?;
 661                }
 662            }
 663        }
 664        Ok(())
 665    }
 666
 667    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 668        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 669            if let Some(invite_code) = &user.invite_code {
 670                let pool = self.connection_pool.lock();
 671                for connection_id in pool.user_connection_ids(user_id) {
 672                    self.peer.send(
 673                        connection_id,
 674                        proto::UpdateInviteInfo {
 675                            url: format!(
 676                                "{}{}",
 677                                self.app_state.config.invite_link_prefix, invite_code
 678                            ),
 679                            count: user.invite_count as u32,
 680                        },
 681                    )?;
 682                }
 683            }
 684        }
 685        Ok(())
 686    }
 687
 688    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 689        ServerSnapshot {
 690            connection_pool: ConnectionPoolGuard {
 691                guard: self.connection_pool.lock(),
 692                _not_send: PhantomData,
 693            },
 694            peer: &self.peer,
 695        }
 696    }
 697}
 698
 699impl<'a> Deref for ConnectionPoolGuard<'a> {
 700    type Target = ConnectionPool;
 701
 702    fn deref(&self) -> &Self::Target {
 703        &*self.guard
 704    }
 705}
 706
 707impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 708    fn deref_mut(&mut self) -> &mut Self::Target {
 709        &mut *self.guard
 710    }
 711}
 712
 713impl<'a> Drop for ConnectionPoolGuard<'a> {
 714    fn drop(&mut self) {
 715        #[cfg(test)]
 716        self.check_invariants();
 717    }
 718}
 719
 720fn broadcast<F>(
 721    sender_id: Option<ConnectionId>,
 722    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 723    mut f: F,
 724) where
 725    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 726{
 727    for receiver_id in receiver_ids {
 728        if Some(receiver_id) != sender_id {
 729            if let Err(error) = f(receiver_id) {
 730                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 731            }
 732        }
 733    }
 734}
 735
lazy_static! {
    // Name of the request header in which clients send their RPC protocol
    // version. It is returned by `ProtocolVersion`'s `Header::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 739
 740pub struct ProtocolVersion(u32);
 741
 742impl Header for ProtocolVersion {
 743    fn name() -> &'static HeaderName {
 744        &ZED_PROTOCOL_VERSION
 745    }
 746
 747    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 748    where
 749        Self: Sized,
 750        I: Iterator<Item = &'i axum::http::HeaderValue>,
 751    {
 752        let version = values
 753            .next()
 754            .ok_or_else(axum::headers::Error::invalid)?
 755            .to_str()
 756            .map_err(|_| axum::headers::Error::invalid())?
 757            .parse()
 758            .map_err(|_| axum::headers::Error::invalid())?;
 759        Ok(Self(version))
 760    }
 761
 762    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 763        values.extend([self.0.to_string().parse().unwrap()]);
 764    }
 765}
 766
 767pub fn routes(server: Arc<Server>) -> Router<Body> {
 768    Router::new()
 769        .route("/rpc", get(handle_websocket_request))
 770        .layer(
 771            ServiceBuilder::new()
 772                .layer(Extension(server.app_state.clone()))
 773                .layer(middleware::from_fn(auth::validate_header)),
 774        )
 775        .route("/metrics", get(handle_metrics))
 776        .layer(Extension(server))
 777}
 778
 779pub async fn handle_websocket_request(
 780    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 781    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 782    Extension(server): Extension<Arc<Server>>,
 783    Extension(user): Extension<User>,
 784    ws: WebSocketUpgrade,
 785) -> axum::response::Response {
 786    if protocol_version != rpc::PROTOCOL_VERSION {
 787        return (
 788            StatusCode::UPGRADE_REQUIRED,
 789            "client must be upgraded".to_string(),
 790        )
 791            .into_response();
 792    }
 793    let socket_address = socket_address.to_string();
 794    ws.on_upgrade(move |socket| {
 795        use util::ResultExt;
 796        let socket = socket
 797            .map_ok(to_tungstenite_message)
 798            .err_into()
 799            .with(|message| async move { Ok(to_axum_message(message)) });
 800        let connection = Connection::new(Box::pin(socket));
 801        async move {
 802            server
 803                .handle_connection(connection, socket_address, user, None, Executor::Production)
 804                .await
 805                .log_err();
 806        }
 807    })
 808}
 809
 810pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 811    let connections = server
 812        .connection_pool
 813        .lock()
 814        .connections()
 815        .filter(|connection| !connection.admin)
 816        .count();
 817
 818    METRIC_CONNECTIONS.set(connections as _);
 819
 820    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 821    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 822
 823    let encoder = prometheus::TextEncoder::new();
 824    let metric_families = prometheus::gather();
 825    let encoded_metrics = encoder
 826        .encode_to_string(&metric_families)
 827        .map_err(|err| anyhow!("{}", err))?;
 828    Ok(encoded_metrics)
 829}
 830
/// Cleans up after the connection behind `session` has gone away.
///
/// It runs in two phases. The first runs immediately:
///
/// * disconnects the peer;
/// * removes the connection from the pool (errors are propagated);
/// * reports the lost connection to the database (errors are only logged via
///   `trace_err`).
///
/// It then waits for whichever of these happens first:
///
/// * `RECONNECT_TIMEOUT` elapses on `executor`. The function then calls
///   `leave_room_for_session` (errors logged). If the user has no connection
///   still online, it calls `decline_call(None, user_id)` and broadcasts the
///   returned room, if any. Finally it refreshes the user's contacts.
/// * `teardown.changed()` resolves. The function returns without that cleanup.
///
/// `select_biased!` polls its branches in order, so the timeout branch wins when
/// both are ready at once.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Presumably a grace period for the client to reconnect (judging by the
        // constant's name) before its room state is torn down.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline pending calls if no other connection of this user is
            // still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 871
 872async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 873    response.send(proto::Ack {})?;
 874    Ok(())
 875}
 876
 877async fn create_room(
 878    _request: proto::CreateRoom,
 879    response: Response<proto::CreateRoom>,
 880    session: Session,
 881) -> Result<()> {
 882    let live_kit_room = nanoid::nanoid!(30);
 883
 884    let live_kit_connection_info = {
 885        let live_kit_room = live_kit_room.clone();
 886        let live_kit = session.live_kit_client.as_ref();
 887
 888        util::async_iife!({
 889            let live_kit = live_kit?;
 890
 891            live_kit
 892                .create_room(live_kit_room.clone())
 893                .await
 894                .trace_err()?;
 895
 896            let token = live_kit
 897                .room_token(&live_kit_room, &session.user_id.to_string())
 898                .trace_err()?;
 899
 900            Some(proto::LiveKitConnectionInfo {
 901                server_url: live_kit.url().into(),
 902                token,
 903            })
 904        })
 905    }
 906    .await;
 907
 908    let room = session
 909        .db()
 910        .await
 911        .create_room(session.user_id, session.connection_id, &live_kit_room)
 912        .await?;
 913
 914    response.send(proto::CreateRoomResponse {
 915        room: Some(room.clone()),
 916        live_kit_connection_info,
 917    })?;
 918
 919    update_user_contacts(session.user_id, &session).await?;
 920    Ok(())
 921}
 922
 923async fn join_room(
 924    request: proto::JoinRoom,
 925    response: Response<proto::JoinRoom>,
 926    session: Session,
 927) -> Result<()> {
 928    let room_id = RoomId::from_proto(request.id);
 929    let room = {
 930        let room = session
 931            .db()
 932            .await
 933            .join_room(room_id, session.user_id, None, session.connection_id)
 934            .await?;
 935        room_updated(&room.room, &session.peer);
 936        room.room.clone()
 937    };
 938
 939    for connection_id in session
 940        .connection_pool()
 941        .await
 942        .user_connection_ids(session.user_id)
 943    {
 944        session
 945            .peer
 946            .send(
 947                connection_id,
 948                proto::CallCanceled {
 949                    room_id: room_id.to_proto(),
 950                },
 951            )
 952            .trace_err();
 953    }
 954
 955    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 956        if let Some(token) = live_kit
 957            .room_token(&room.live_kit_room, &session.user_id.to_string())
 958            .trace_err()
 959        {
 960            Some(proto::LiveKitConnectionInfo {
 961                server_url: live_kit.url().into(),
 962                token,
 963            })
 964        } else {
 965            None
 966        }
 967    } else {
 968        None
 969    };
 970
 971    response.send(proto::JoinRoomResponse {
 972        room: Some(room),
 973        live_kit_connection_info,
 974    })?;
 975
 976    update_user_contacts(session.user_id, &session).await?;
 977    Ok(())
 978}
 979
 980async fn rejoin_room(
 981    request: proto::RejoinRoom,
 982    response: Response<proto::RejoinRoom>,
 983    session: Session,
 984) -> Result<()> {
 985    let room;
 986    let channel_id;
 987    let channel_members;
 988    {
 989        let mut rejoined_room = session
 990            .db()
 991            .await
 992            .rejoin_room(request, session.user_id, session.connection_id)
 993            .await?;
 994
 995        response.send(proto::RejoinRoomResponse {
 996            room: Some(rejoined_room.room.clone()),
 997            reshared_projects: rejoined_room
 998                .reshared_projects
 999                .iter()
1000                .map(|project| proto::ResharedProject {
1001                    id: project.id.to_proto(),
1002                    collaborators: project
1003                        .collaborators
1004                        .iter()
1005                        .map(|collaborator| collaborator.to_proto())
1006                        .collect(),
1007                })
1008                .collect(),
1009            rejoined_projects: rejoined_room
1010                .rejoined_projects
1011                .iter()
1012                .map(|rejoined_project| proto::RejoinedProject {
1013                    id: rejoined_project.id.to_proto(),
1014                    worktrees: rejoined_project
1015                        .worktrees
1016                        .iter()
1017                        .map(|worktree| proto::WorktreeMetadata {
1018                            id: worktree.id,
1019                            root_name: worktree.root_name.clone(),
1020                            visible: worktree.visible,
1021                            abs_path: worktree.abs_path.clone(),
1022                        })
1023                        .collect(),
1024                    collaborators: rejoined_project
1025                        .collaborators
1026                        .iter()
1027                        .map(|collaborator| collaborator.to_proto())
1028                        .collect(),
1029                    language_servers: rejoined_project.language_servers.clone(),
1030                })
1031                .collect(),
1032        })?;
1033        room_updated(&rejoined_room.room, &session.peer);
1034
1035        for project in &rejoined_room.reshared_projects {
1036            for collaborator in &project.collaborators {
1037                session
1038                    .peer
1039                    .send(
1040                        collaborator.connection_id,
1041                        proto::UpdateProjectCollaborator {
1042                            project_id: project.id.to_proto(),
1043                            old_peer_id: Some(project.old_connection_id.into()),
1044                            new_peer_id: Some(session.connection_id.into()),
1045                        },
1046                    )
1047                    .trace_err();
1048            }
1049
1050            broadcast(
1051                Some(session.connection_id),
1052                project
1053                    .collaborators
1054                    .iter()
1055                    .map(|collaborator| collaborator.connection_id),
1056                |connection_id| {
1057                    session.peer.forward_send(
1058                        session.connection_id,
1059                        connection_id,
1060                        proto::UpdateProject {
1061                            project_id: project.id.to_proto(),
1062                            worktrees: project.worktrees.clone(),
1063                        },
1064                    )
1065                },
1066            );
1067        }
1068
1069        for project in &rejoined_room.rejoined_projects {
1070            for collaborator in &project.collaborators {
1071                session
1072                    .peer
1073                    .send(
1074                        collaborator.connection_id,
1075                        proto::UpdateProjectCollaborator {
1076                            project_id: project.id.to_proto(),
1077                            old_peer_id: Some(project.old_connection_id.into()),
1078                            new_peer_id: Some(session.connection_id.into()),
1079                        },
1080                    )
1081                    .trace_err();
1082            }
1083        }
1084
1085        for project in &mut rejoined_room.rejoined_projects {
1086            for worktree in mem::take(&mut project.worktrees) {
1087                #[cfg(any(test, feature = "test-support"))]
1088                const MAX_CHUNK_SIZE: usize = 2;
1089                #[cfg(not(any(test, feature = "test-support")))]
1090                const MAX_CHUNK_SIZE: usize = 256;
1091
1092                // Stream this worktree's entries.
1093                let message = proto::UpdateWorktree {
1094                    project_id: project.id.to_proto(),
1095                    worktree_id: worktree.id,
1096                    abs_path: worktree.abs_path.clone(),
1097                    root_name: worktree.root_name,
1098                    updated_entries: worktree.updated_entries,
1099                    removed_entries: worktree.removed_entries,
1100                    scan_id: worktree.scan_id,
1101                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1102                    updated_repositories: worktree.updated_repositories,
1103                    removed_repositories: worktree.removed_repositories,
1104                };
1105                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1106                    session.peer.send(session.connection_id, update.clone())?;
1107                }
1108
1109                // Stream this worktree's diagnostics.
1110                for summary in worktree.diagnostic_summaries {
1111                    session.peer.send(
1112                        session.connection_id,
1113                        proto::UpdateDiagnosticSummary {
1114                            project_id: project.id.to_proto(),
1115                            worktree_id: worktree.id,
1116                            summary: Some(summary),
1117                        },
1118                    )?;
1119                }
1120
1121                for settings_file in worktree.settings_files {
1122                    session.peer.send(
1123                        session.connection_id,
1124                        proto::UpdateWorktreeSettings {
1125                            project_id: project.id.to_proto(),
1126                            worktree_id: worktree.id,
1127                            path: settings_file.path,
1128                            content: Some(settings_file.content),
1129                        },
1130                    )?;
1131                }
1132            }
1133
1134            for language_server in &project.language_servers {
1135                session.peer.send(
1136                    session.connection_id,
1137                    proto::UpdateLanguageServer {
1138                        project_id: project.id.to_proto(),
1139                        language_server_id: language_server.id,
1140                        variant: Some(
1141                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1142                                proto::LspDiskBasedDiagnosticsUpdated {},
1143                            ),
1144                        ),
1145                    },
1146                )?;
1147            }
1148        }
1149
1150        room = mem::take(&mut rejoined_room.room);
1151        channel_id = rejoined_room.channel_id;
1152        channel_members = mem::take(&mut rejoined_room.channel_members);
1153    }
1154
1155    //TODO: move this into the room guard
1156    if let Some(channel_id) = channel_id {
1157        channel_updated(
1158            channel_id,
1159            &room,
1160            &channel_members,
1161            &session.peer,
1162            &*session.connection_pool().await,
1163        );
1164    }
1165
1166    update_user_contacts(session.user_id, &session).await?;
1167    Ok(())
1168}
1169
1170async fn leave_room(
1171    _: proto::LeaveRoom,
1172    response: Response<proto::LeaveRoom>,
1173    session: Session,
1174) -> Result<()> {
1175    leave_room_for_session(&session).await?;
1176    response.send(proto::Ack {})?;
1177    Ok(())
1178}
1179
1180async fn call(
1181    request: proto::Call,
1182    response: Response<proto::Call>,
1183    session: Session,
1184) -> Result<()> {
1185    let room_id = RoomId::from_proto(request.room_id);
1186    let calling_user_id = session.user_id;
1187    let calling_connection_id = session.connection_id;
1188    let called_user_id = UserId::from_proto(request.called_user_id);
1189    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1190    if !session
1191        .db()
1192        .await
1193        .has_contact(calling_user_id, called_user_id)
1194        .await?
1195    {
1196        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1197    }
1198
1199    let incoming_call = {
1200        let (room, incoming_call) = &mut *session
1201            .db()
1202            .await
1203            .call(
1204                room_id,
1205                calling_user_id,
1206                calling_connection_id,
1207                called_user_id,
1208                initial_project_id,
1209            )
1210            .await?;
1211        room_updated(&room, &session.peer);
1212        mem::take(incoming_call)
1213    };
1214    update_user_contacts(called_user_id, &session).await?;
1215
1216    let mut calls = session
1217        .connection_pool()
1218        .await
1219        .user_connection_ids(called_user_id)
1220        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1221        .collect::<FuturesUnordered<_>>();
1222
1223    while let Some(call_response) = calls.next().await {
1224        match call_response.as_ref() {
1225            Ok(_) => {
1226                response.send(proto::Ack {})?;
1227                return Ok(());
1228            }
1229            Err(_) => {
1230                call_response.trace_err();
1231            }
1232        }
1233    }
1234
1235    {
1236        let room = session
1237            .db()
1238            .await
1239            .call_failed(room_id, called_user_id)
1240            .await?;
1241        room_updated(&room, &session.peer);
1242    }
1243    update_user_contacts(called_user_id, &session).await?;
1244
1245    Err(anyhow!("failed to ring user"))?
1246}
1247
1248async fn cancel_call(
1249    request: proto::CancelCall,
1250    response: Response<proto::CancelCall>,
1251    session: Session,
1252) -> Result<()> {
1253    let called_user_id = UserId::from_proto(request.called_user_id);
1254    let room_id = RoomId::from_proto(request.room_id);
1255    {
1256        let room = session
1257            .db()
1258            .await
1259            .cancel_call(room_id, session.connection_id, called_user_id)
1260            .await?;
1261        room_updated(&room, &session.peer);
1262    }
1263
1264    for connection_id in session
1265        .connection_pool()
1266        .await
1267        .user_connection_ids(called_user_id)
1268    {
1269        session
1270            .peer
1271            .send(
1272                connection_id,
1273                proto::CallCanceled {
1274                    room_id: room_id.to_proto(),
1275                },
1276            )
1277            .trace_err();
1278    }
1279    response.send(proto::Ack {})?;
1280
1281    update_user_contacts(called_user_id, &session).await?;
1282    Ok(())
1283}
1284
1285async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1286    let room_id = RoomId::from_proto(message.room_id);
1287    {
1288        let room = session
1289            .db()
1290            .await
1291            .decline_call(Some(room_id), session.user_id)
1292            .await?
1293            .ok_or_else(|| anyhow!("failed to decline call"))?;
1294        room_updated(&room, &session.peer);
1295    }
1296
1297    for connection_id in session
1298        .connection_pool()
1299        .await
1300        .user_connection_ids(session.user_id)
1301    {
1302        session
1303            .peer
1304            .send(
1305                connection_id,
1306                proto::CallCanceled {
1307                    room_id: room_id.to_proto(),
1308                },
1309            )
1310            .trace_err();
1311    }
1312    update_user_contacts(session.user_id, &session).await?;
1313    Ok(())
1314}
1315
1316async fn update_participant_location(
1317    request: proto::UpdateParticipantLocation,
1318    response: Response<proto::UpdateParticipantLocation>,
1319    session: Session,
1320) -> Result<()> {
1321    let room_id = RoomId::from_proto(request.room_id);
1322    let location = request
1323        .location
1324        .ok_or_else(|| anyhow!("invalid location"))?;
1325
1326    let db = session.db().await;
1327    let room = db
1328        .update_room_participant_location(room_id, session.connection_id, location)
1329        .await?;
1330
1331    room_updated(&room, &session.peer);
1332    response.send(proto::Ack {})?;
1333    Ok(())
1334}
1335
1336async fn share_project(
1337    request: proto::ShareProject,
1338    response: Response<proto::ShareProject>,
1339    session: Session,
1340) -> Result<()> {
1341    let (project_id, room) = &*session
1342        .db()
1343        .await
1344        .share_project(
1345            RoomId::from_proto(request.room_id),
1346            session.connection_id,
1347            &request.worktrees,
1348        )
1349        .await?;
1350    response.send(proto::ShareProjectResponse {
1351        project_id: project_id.to_proto(),
1352    })?;
1353    room_updated(&room, &session.peer);
1354
1355    Ok(())
1356}
1357
1358async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1359    let project_id = ProjectId::from_proto(message.project_id);
1360
1361    let (room, guest_connection_ids) = &*session
1362        .db()
1363        .await
1364        .unshare_project(project_id, session.connection_id)
1365        .await?;
1366
1367    broadcast(
1368        Some(session.connection_id),
1369        guest_connection_ids.iter().copied(),
1370        |conn_id| session.peer.send(conn_id, message.clone()),
1371    );
1372    room_updated(&room, &session.peer);
1373
1374    Ok(())
1375}
1376
1377async fn join_project(
1378    request: proto::JoinProject,
1379    response: Response<proto::JoinProject>,
1380    session: Session,
1381) -> Result<()> {
1382    let project_id = ProjectId::from_proto(request.project_id);
1383    let guest_user_id = session.user_id;
1384
1385    tracing::info!(%project_id, "join project");
1386
1387    let (project, replica_id) = &mut *session
1388        .db()
1389        .await
1390        .join_project(project_id, session.connection_id)
1391        .await?;
1392
1393    let collaborators = project
1394        .collaborators
1395        .iter()
1396        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1397        .map(|collaborator| collaborator.to_proto())
1398        .collect::<Vec<_>>();
1399
1400    let worktrees = project
1401        .worktrees
1402        .iter()
1403        .map(|(id, worktree)| proto::WorktreeMetadata {
1404            id: *id,
1405            root_name: worktree.root_name.clone(),
1406            visible: worktree.visible,
1407            abs_path: worktree.abs_path.clone(),
1408        })
1409        .collect::<Vec<_>>();
1410
1411    for collaborator in &collaborators {
1412        session
1413            .peer
1414            .send(
1415                collaborator.peer_id.unwrap().into(),
1416                proto::AddProjectCollaborator {
1417                    project_id: project_id.to_proto(),
1418                    collaborator: Some(proto::Collaborator {
1419                        peer_id: Some(session.connection_id.into()),
1420                        replica_id: replica_id.0 as u32,
1421                        user_id: guest_user_id.to_proto(),
1422                    }),
1423                },
1424            )
1425            .trace_err();
1426    }
1427
1428    // First, we send the metadata associated with each worktree.
1429    response.send(proto::JoinProjectResponse {
1430        worktrees: worktrees.clone(),
1431        replica_id: replica_id.0 as u32,
1432        collaborators: collaborators.clone(),
1433        language_servers: project.language_servers.clone(),
1434    })?;
1435
1436    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1437        #[cfg(any(test, feature = "test-support"))]
1438        const MAX_CHUNK_SIZE: usize = 2;
1439        #[cfg(not(any(test, feature = "test-support")))]
1440        const MAX_CHUNK_SIZE: usize = 256;
1441
1442        // Stream this worktree's entries.
1443        let message = proto::UpdateWorktree {
1444            project_id: project_id.to_proto(),
1445            worktree_id,
1446            abs_path: worktree.abs_path.clone(),
1447            root_name: worktree.root_name,
1448            updated_entries: worktree.entries,
1449            removed_entries: Default::default(),
1450            scan_id: worktree.scan_id,
1451            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1452            updated_repositories: worktree.repository_entries.into_values().collect(),
1453            removed_repositories: Default::default(),
1454        };
1455        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1456            session.peer.send(session.connection_id, update.clone())?;
1457        }
1458
1459        // Stream this worktree's diagnostics.
1460        for summary in worktree.diagnostic_summaries {
1461            session.peer.send(
1462                session.connection_id,
1463                proto::UpdateDiagnosticSummary {
1464                    project_id: project_id.to_proto(),
1465                    worktree_id: worktree.id,
1466                    summary: Some(summary),
1467                },
1468            )?;
1469        }
1470
1471        for settings_file in worktree.settings_files {
1472            session.peer.send(
1473                session.connection_id,
1474                proto::UpdateWorktreeSettings {
1475                    project_id: project_id.to_proto(),
1476                    worktree_id: worktree.id,
1477                    path: settings_file.path,
1478                    content: Some(settings_file.content),
1479                },
1480            )?;
1481        }
1482    }
1483
1484    for language_server in &project.language_servers {
1485        session.peer.send(
1486            session.connection_id,
1487            proto::UpdateLanguageServer {
1488                project_id: project_id.to_proto(),
1489                language_server_id: language_server.id,
1490                variant: Some(
1491                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1492                        proto::LspDiskBasedDiagnosticsUpdated {},
1493                    ),
1494                ),
1495            },
1496        )?;
1497    }
1498
1499    Ok(())
1500}
1501
1502async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1503    let sender_id = session.connection_id;
1504    let project_id = ProjectId::from_proto(request.project_id);
1505
1506    let (room, project) = &*session
1507        .db()
1508        .await
1509        .leave_project(project_id, sender_id)
1510        .await?;
1511    tracing::info!(
1512        %project_id,
1513        host_user_id = %project.host_user_id,
1514        host_connection_id = %project.host_connection_id,
1515        "leave project"
1516    );
1517
1518    project_left(&project, &session);
1519    room_updated(&room, &session.peer);
1520
1521    Ok(())
1522}
1523
1524async fn update_project(
1525    request: proto::UpdateProject,
1526    response: Response<proto::UpdateProject>,
1527    session: Session,
1528) -> Result<()> {
1529    let project_id = ProjectId::from_proto(request.project_id);
1530    let (room, guest_connection_ids) = &*session
1531        .db()
1532        .await
1533        .update_project(project_id, session.connection_id, &request.worktrees)
1534        .await?;
1535    broadcast(
1536        Some(session.connection_id),
1537        guest_connection_ids.iter().copied(),
1538        |connection_id| {
1539            session
1540                .peer
1541                .forward_send(session.connection_id, connection_id, request.clone())
1542        },
1543    );
1544    room_updated(&room, &session.peer);
1545    response.send(proto::Ack {})?;
1546
1547    Ok(())
1548}
1549
1550async fn update_worktree(
1551    request: proto::UpdateWorktree,
1552    response: Response<proto::UpdateWorktree>,
1553    session: Session,
1554) -> Result<()> {
1555    let guest_connection_ids = session
1556        .db()
1557        .await
1558        .update_worktree(&request, session.connection_id)
1559        .await?;
1560
1561    broadcast(
1562        Some(session.connection_id),
1563        guest_connection_ids.iter().copied(),
1564        |connection_id| {
1565            session
1566                .peer
1567                .forward_send(session.connection_id, connection_id, request.clone())
1568        },
1569    );
1570    response.send(proto::Ack {})?;
1571    Ok(())
1572}
1573
1574async fn update_diagnostic_summary(
1575    message: proto::UpdateDiagnosticSummary,
1576    session: Session,
1577) -> Result<()> {
1578    let guest_connection_ids = session
1579        .db()
1580        .await
1581        .update_diagnostic_summary(&message, session.connection_id)
1582        .await?;
1583
1584    broadcast(
1585        Some(session.connection_id),
1586        guest_connection_ids.iter().copied(),
1587        |connection_id| {
1588            session
1589                .peer
1590                .forward_send(session.connection_id, connection_id, message.clone())
1591        },
1592    );
1593
1594    Ok(())
1595}
1596
1597async fn update_worktree_settings(
1598    message: proto::UpdateWorktreeSettings,
1599    session: Session,
1600) -> Result<()> {
1601    let guest_connection_ids = session
1602        .db()
1603        .await
1604        .update_worktree_settings(&message, session.connection_id)
1605        .await?;
1606
1607    broadcast(
1608        Some(session.connection_id),
1609        guest_connection_ids.iter().copied(),
1610        |connection_id| {
1611            session
1612                .peer
1613                .forward_send(session.connection_id, connection_id, message.clone())
1614        },
1615    );
1616
1617    Ok(())
1618}
1619
1620async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1621    broadcast_project_message(request.project_id, request, session).await
1622}
1623
1624async fn start_language_server(
1625    request: proto::StartLanguageServer,
1626    session: Session,
1627) -> Result<()> {
1628    let guest_connection_ids = session
1629        .db()
1630        .await
1631        .start_language_server(&request, session.connection_id)
1632        .await?;
1633
1634    broadcast(
1635        Some(session.connection_id),
1636        guest_connection_ids.iter().copied(),
1637        |connection_id| {
1638            session
1639                .peer
1640                .forward_send(session.connection_id, connection_id, request.clone())
1641        },
1642    );
1643    Ok(())
1644}
1645
1646async fn update_language_server(
1647    request: proto::UpdateLanguageServer,
1648    session: Session,
1649) -> Result<()> {
1650    session.executor.record_backtrace();
1651    let project_id = ProjectId::from_proto(request.project_id);
1652    let project_connection_ids = session
1653        .db()
1654        .await
1655        .project_connection_ids(project_id, session.connection_id)
1656        .await?;
1657    broadcast(
1658        Some(session.connection_id),
1659        project_connection_ids.iter().copied(),
1660        |connection_id| {
1661            session
1662                .peer
1663                .forward_send(session.connection_id, connection_id, request.clone())
1664        },
1665    );
1666    Ok(())
1667}
1668
1669async fn forward_project_request<T>(
1670    request: T,
1671    response: Response<T>,
1672    session: Session,
1673) -> Result<()>
1674where
1675    T: EntityMessage + RequestMessage,
1676{
1677    session.executor.record_backtrace();
1678    let project_id = ProjectId::from_proto(request.remote_entity_id());
1679    let host_connection_id = {
1680        let collaborators = session
1681            .db()
1682            .await
1683            .project_collaborators(project_id, session.connection_id)
1684            .await?;
1685        collaborators
1686            .iter()
1687            .find(|collaborator| collaborator.is_host)
1688            .ok_or_else(|| anyhow!("host not found"))?
1689            .connection_id
1690    };
1691
1692    let payload = session
1693        .peer
1694        .forward_request(session.connection_id, host_connection_id, request)
1695        .await?;
1696
1697    response.send(payload)?;
1698    Ok(())
1699}
1700
1701async fn create_buffer_for_peer(
1702    request: proto::CreateBufferForPeer,
1703    session: Session,
1704) -> Result<()> {
1705    session.executor.record_backtrace();
1706    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1707    session
1708        .peer
1709        .forward_send(session.connection_id, peer_id.into(), request)?;
1710    Ok(())
1711}
1712
1713async fn update_buffer(
1714    request: proto::UpdateBuffer,
1715    response: Response<proto::UpdateBuffer>,
1716    session: Session,
1717) -> Result<()> {
1718    session.executor.record_backtrace();
1719    let project_id = ProjectId::from_proto(request.project_id);
1720    let mut guest_connection_ids;
1721    let mut host_connection_id = None;
1722    {
1723        let collaborators = session
1724            .db()
1725            .await
1726            .project_collaborators(project_id, session.connection_id)
1727            .await?;
1728        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1729        for collaborator in collaborators.iter() {
1730            if collaborator.is_host {
1731                host_connection_id = Some(collaborator.connection_id);
1732            } else {
1733                guest_connection_ids.push(collaborator.connection_id);
1734            }
1735        }
1736    }
1737    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1738
1739    session.executor.record_backtrace();
1740    broadcast(
1741        Some(session.connection_id),
1742        guest_connection_ids,
1743        |connection_id| {
1744            session
1745                .peer
1746                .forward_send(session.connection_id, connection_id, request.clone())
1747        },
1748    );
1749    if host_connection_id != session.connection_id {
1750        session
1751            .peer
1752            .forward_request(session.connection_id, host_connection_id, request.clone())
1753            .await?;
1754    }
1755
1756    response.send(proto::Ack {})?;
1757    Ok(())
1758}
1759
1760async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1761    let project_id = ProjectId::from_proto(request.project_id);
1762    let project_connection_ids = session
1763        .db()
1764        .await
1765        .project_connection_ids(project_id, session.connection_id)
1766        .await?;
1767
1768    broadcast(
1769        Some(session.connection_id),
1770        project_connection_ids.iter().copied(),
1771        |connection_id| {
1772            session
1773                .peer
1774                .forward_send(session.connection_id, connection_id, request.clone())
1775        },
1776    );
1777    Ok(())
1778}
1779
1780async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1781    let project_id = ProjectId::from_proto(request.project_id);
1782    let project_connection_ids = session
1783        .db()
1784        .await
1785        .project_connection_ids(project_id, session.connection_id)
1786        .await?;
1787    broadcast(
1788        Some(session.connection_id),
1789        project_connection_ids.iter().copied(),
1790        |connection_id| {
1791            session
1792                .peer
1793                .forward_send(session.connection_id, connection_id, request.clone())
1794        },
1795    );
1796    Ok(())
1797}
1798
1799async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1800    broadcast_project_message(request.project_id, request, session).await
1801}
1802
1803async fn broadcast_project_message<T: EnvelopedMessage>(
1804    project_id: u64,
1805    request: T,
1806    session: Session,
1807) -> Result<()> {
1808    let project_id = ProjectId::from_proto(project_id);
1809    let project_connection_ids = session
1810        .db()
1811        .await
1812        .project_connection_ids(project_id, session.connection_id)
1813        .await?;
1814    broadcast(
1815        Some(session.connection_id),
1816        project_connection_ids.iter().copied(),
1817        |connection_id| {
1818            session
1819                .peer
1820                .forward_send(session.connection_id, connection_id, request.clone())
1821        },
1822    );
1823    Ok(())
1824}
1825
1826async fn follow(
1827    request: proto::Follow,
1828    response: Response<proto::Follow>,
1829    session: Session,
1830) -> Result<()> {
1831    let project_id = ProjectId::from_proto(request.project_id);
1832    let leader_id = request
1833        .leader_id
1834        .ok_or_else(|| anyhow!("invalid leader id"))?
1835        .into();
1836    let follower_id = session.connection_id;
1837
1838    {
1839        let project_connection_ids = session
1840            .db()
1841            .await
1842            .project_connection_ids(project_id, session.connection_id)
1843            .await?;
1844
1845        if !project_connection_ids.contains(&leader_id) {
1846            Err(anyhow!("no such peer"))?;
1847        }
1848    }
1849
1850    let mut response_payload = session
1851        .peer
1852        .forward_request(session.connection_id, leader_id, request)
1853        .await?;
1854    response_payload
1855        .views
1856        .retain(|view| view.leader_id != Some(follower_id.into()));
1857    response.send(response_payload)?;
1858
1859    let room = session
1860        .db()
1861        .await
1862        .follow(project_id, leader_id, follower_id)
1863        .await?;
1864    room_updated(&room, &session.peer);
1865
1866    Ok(())
1867}
1868
1869async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1870    let project_id = ProjectId::from_proto(request.project_id);
1871    let leader_id = request
1872        .leader_id
1873        .ok_or_else(|| anyhow!("invalid leader id"))?
1874        .into();
1875    let follower_id = session.connection_id;
1876
1877    if !session
1878        .db()
1879        .await
1880        .project_connection_ids(project_id, session.connection_id)
1881        .await?
1882        .contains(&leader_id)
1883    {
1884        Err(anyhow!("no such peer"))?;
1885    }
1886
1887    session
1888        .peer
1889        .forward_send(session.connection_id, leader_id, request)?;
1890
1891    let room = session
1892        .db()
1893        .await
1894        .unfollow(project_id, leader_id, follower_id)
1895        .await?;
1896    room_updated(&room, &session.peer);
1897
1898    Ok(())
1899}
1900
1901async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1902    let project_id = ProjectId::from_proto(request.project_id);
1903    let project_connection_ids = session
1904        .db
1905        .lock()
1906        .await
1907        .project_connection_ids(project_id, session.connection_id)
1908        .await?;
1909
1910    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1911        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1912        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1913        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1914    });
1915    for follower_peer_id in request.follower_ids.iter().copied() {
1916        let follower_connection_id = follower_peer_id.into();
1917        if project_connection_ids.contains(&follower_connection_id)
1918            && Some(follower_peer_id) != leader_id
1919        {
1920            session.peer.forward_send(
1921                session.connection_id,
1922                follower_connection_id,
1923                request.clone(),
1924            )?;
1925        }
1926    }
1927    Ok(())
1928}
1929
1930async fn get_users(
1931    request: proto::GetUsers,
1932    response: Response<proto::GetUsers>,
1933    session: Session,
1934) -> Result<()> {
1935    let user_ids = request
1936        .user_ids
1937        .into_iter()
1938        .map(UserId::from_proto)
1939        .collect();
1940    let users = session
1941        .db()
1942        .await
1943        .get_users_by_ids(user_ids)
1944        .await?
1945        .into_iter()
1946        .map(|user| proto::User {
1947            id: user.id.to_proto(),
1948            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1949            github_login: user.github_login,
1950        })
1951        .collect();
1952    response.send(proto::UsersResponse { users })?;
1953    Ok(())
1954}
1955
1956async fn fuzzy_search_users(
1957    request: proto::FuzzySearchUsers,
1958    response: Response<proto::FuzzySearchUsers>,
1959    session: Session,
1960) -> Result<()> {
1961    let query = request.query;
1962    let users = match query.len() {
1963        0 => vec![],
1964        1 | 2 => session
1965            .db()
1966            .await
1967            .get_user_by_github_login(&query)
1968            .await?
1969            .into_iter()
1970            .collect(),
1971        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1972    };
1973    let users = users
1974        .into_iter()
1975        .filter(|user| user.id != session.user_id)
1976        .map(|user| proto::User {
1977            id: user.id.to_proto(),
1978            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1979            github_login: user.github_login,
1980        })
1981        .collect();
1982    response.send(proto::UsersResponse { users })?;
1983    Ok(())
1984}
1985
1986async fn request_contact(
1987    request: proto::RequestContact,
1988    response: Response<proto::RequestContact>,
1989    session: Session,
1990) -> Result<()> {
1991    let requester_id = session.user_id;
1992    let responder_id = UserId::from_proto(request.responder_id);
1993    if requester_id == responder_id {
1994        return Err(anyhow!("cannot add yourself as a contact"))?;
1995    }
1996
1997    session
1998        .db()
1999        .await
2000        .send_contact_request(requester_id, responder_id)
2001        .await?;
2002
2003    // Update outgoing contact requests of requester
2004    let mut update = proto::UpdateContacts::default();
2005    update.outgoing_requests.push(responder_id.to_proto());
2006    for connection_id in session
2007        .connection_pool()
2008        .await
2009        .user_connection_ids(requester_id)
2010    {
2011        session.peer.send(connection_id, update.clone())?;
2012    }
2013
2014    // Update incoming contact requests of responder
2015    let mut update = proto::UpdateContacts::default();
2016    update
2017        .incoming_requests
2018        .push(proto::IncomingContactRequest {
2019            requester_id: requester_id.to_proto(),
2020            should_notify: true,
2021        });
2022    for connection_id in session
2023        .connection_pool()
2024        .await
2025        .user_connection_ids(responder_id)
2026    {
2027        session.peer.send(connection_id, update.clone())?;
2028    }
2029
2030    response.send(proto::Ack {})?;
2031    Ok(())
2032}
2033
2034async fn respond_to_contact_request(
2035    request: proto::RespondToContactRequest,
2036    response: Response<proto::RespondToContactRequest>,
2037    session: Session,
2038) -> Result<()> {
2039    let responder_id = session.user_id;
2040    let requester_id = UserId::from_proto(request.requester_id);
2041    let db = session.db().await;
2042    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2043        db.dismiss_contact_notification(responder_id, requester_id)
2044            .await?;
2045    } else {
2046        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2047
2048        db.respond_to_contact_request(responder_id, requester_id, accept)
2049            .await?;
2050        let requester_busy = db.is_user_busy(requester_id).await?;
2051        let responder_busy = db.is_user_busy(responder_id).await?;
2052
2053        let pool = session.connection_pool().await;
2054        // Update responder with new contact
2055        let mut update = proto::UpdateContacts::default();
2056        if accept {
2057            update
2058                .contacts
2059                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2060        }
2061        update
2062            .remove_incoming_requests
2063            .push(requester_id.to_proto());
2064        for connection_id in pool.user_connection_ids(responder_id) {
2065            session.peer.send(connection_id, update.clone())?;
2066        }
2067
2068        // Update requester with new contact
2069        let mut update = proto::UpdateContacts::default();
2070        if accept {
2071            update
2072                .contacts
2073                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2074        }
2075        update
2076            .remove_outgoing_requests
2077            .push(responder_id.to_proto());
2078        for connection_id in pool.user_connection_ids(requester_id) {
2079            session.peer.send(connection_id, update.clone())?;
2080        }
2081    }
2082
2083    response.send(proto::Ack {})?;
2084    Ok(())
2085}
2086
2087async fn remove_contact(
2088    request: proto::RemoveContact,
2089    response: Response<proto::RemoveContact>,
2090    session: Session,
2091) -> Result<()> {
2092    let requester_id = session.user_id;
2093    let responder_id = UserId::from_proto(request.user_id);
2094    let db = session.db().await;
2095    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2096
2097    let pool = session.connection_pool().await;
2098    // Update outgoing contact requests of requester
2099    let mut update = proto::UpdateContacts::default();
2100    if contact_accepted {
2101        update.remove_contacts.push(responder_id.to_proto());
2102    } else {
2103        update
2104            .remove_outgoing_requests
2105            .push(responder_id.to_proto());
2106    }
2107    for connection_id in pool.user_connection_ids(requester_id) {
2108        session.peer.send(connection_id, update.clone())?;
2109    }
2110
2111    // Update incoming contact requests of responder
2112    let mut update = proto::UpdateContacts::default();
2113    if contact_accepted {
2114        update.remove_contacts.push(requester_id.to_proto());
2115    } else {
2116        update
2117            .remove_incoming_requests
2118            .push(requester_id.to_proto());
2119    }
2120    for connection_id in pool.user_connection_ids(responder_id) {
2121        session.peer.send(connection_id, update.clone())?;
2122    }
2123
2124    response.send(proto::Ack {})?;
2125    Ok(())
2126}
2127
2128async fn create_channel(
2129    request: proto::CreateChannel,
2130    response: Response<proto::CreateChannel>,
2131    session: Session,
2132) -> Result<()> {
2133    let db = session.db().await;
2134    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2135
2136    if let Some(live_kit) = session.live_kit_client.as_ref() {
2137        live_kit.create_room(live_kit_room.clone()).await?;
2138    }
2139
2140    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2141    let id = db
2142        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2143        .await?;
2144
2145    response.send(proto::CreateChannelResponse {
2146        channel_id: id.to_proto(),
2147    })?;
2148
2149    let mut update = proto::UpdateChannels::default();
2150    update.channels.push(proto::Channel {
2151        id: id.to_proto(),
2152        name: request.name,
2153        parent_id: request.parent_id,
2154        user_is_admin: false,
2155    });
2156
2157    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2158        db.get_channel_members(parent_id).await?
2159    } else {
2160        vec![session.user_id]
2161    };
2162
2163    let connection_pool = session.connection_pool().await;
2164    for user_id in user_ids_to_notify {
2165        for connection_id in connection_pool.user_connection_ids(user_id) {
2166            let mut update = update.clone();
2167            if user_id == session.user_id {
2168                update.channels[0].user_is_admin = true;
2169            }
2170            session.peer.send(connection_id, update)?;
2171        }
2172    }
2173
2174    Ok(())
2175}
2176
2177async fn remove_channel(
2178    request: proto::RemoveChannel,
2179    response: Response<proto::RemoveChannel>,
2180    session: Session,
2181) -> Result<()> {
2182    let db = session.db().await;
2183
2184    let channel_id = request.channel_id;
2185    let (removed_channels, member_ids) = db
2186        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2187        .await?;
2188    response.send(proto::Ack {})?;
2189
2190    // Notify members of removed channels
2191    let mut update = proto::UpdateChannels::default();
2192    update
2193        .remove_channels
2194        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2195
2196    let connection_pool = session.connection_pool().await;
2197    for member_id in member_ids {
2198        for connection_id in connection_pool.user_connection_ids(member_id) {
2199            session.peer.send(connection_id, update.clone())?;
2200        }
2201    }
2202
2203    Ok(())
2204}
2205
2206async fn invite_channel_member(
2207    request: proto::InviteChannelMember,
2208    response: Response<proto::InviteChannelMember>,
2209    session: Session,
2210) -> Result<()> {
2211    let db = session.db().await;
2212    let channel_id = ChannelId::from_proto(request.channel_id);
2213    let channel = db
2214        .get_channel(channel_id, session.user_id)
2215        .await?
2216        .ok_or_else(|| anyhow!("channel not found"))?;
2217    let invitee_id = UserId::from_proto(request.user_id);
2218    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2219        .await?;
2220
2221    let mut update = proto::UpdateChannels::default();
2222    update.channel_invitations.push(proto::Channel {
2223        id: channel.id.to_proto(),
2224        name: channel.name,
2225        parent_id: None,
2226        user_is_admin: false,
2227    });
2228    for connection_id in session
2229        .connection_pool()
2230        .await
2231        .user_connection_ids(invitee_id)
2232    {
2233        session.peer.send(connection_id, update.clone())?;
2234    }
2235
2236    response.send(proto::Ack {})?;
2237    Ok(())
2238}
2239
2240async fn remove_channel_member(
2241    request: proto::RemoveChannelMember,
2242    response: Response<proto::RemoveChannelMember>,
2243    session: Session,
2244) -> Result<()> {
2245    let db = session.db().await;
2246    let channel_id = ChannelId::from_proto(request.channel_id);
2247    let member_id = UserId::from_proto(request.user_id);
2248
2249    db.remove_channel_member(channel_id, member_id, session.user_id)
2250        .await?;
2251
2252    let mut update = proto::UpdateChannels::default();
2253    update.remove_channels.push(channel_id.to_proto());
2254
2255    for connection_id in session
2256        .connection_pool()
2257        .await
2258        .user_connection_ids(member_id)
2259    {
2260        session.peer.send(connection_id, update.clone())?;
2261    }
2262
2263    response.send(proto::Ack {})?;
2264    Ok(())
2265}
2266
2267async fn set_channel_member_admin(
2268    request: proto::SetChannelMemberAdmin,
2269    response: Response<proto::SetChannelMemberAdmin>,
2270    session: Session,
2271) -> Result<()> {
2272    let db = session.db().await;
2273    let channel_id = ChannelId::from_proto(request.channel_id);
2274    let member_id = UserId::from_proto(request.user_id);
2275    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2276        .await?;
2277
2278    let channel = db
2279        .get_channel(channel_id, member_id)
2280        .await?
2281        .ok_or_else(|| anyhow!("channel not found"))?;
2282
2283    let mut update = proto::UpdateChannels::default();
2284    update.channels.push(proto::Channel {
2285        id: channel.id.to_proto(),
2286        name: channel.name,
2287        parent_id: None,
2288        user_is_admin: request.admin,
2289    });
2290
2291    for connection_id in session
2292        .connection_pool()
2293        .await
2294        .user_connection_ids(member_id)
2295    {
2296        session.peer.send(connection_id, update.clone())?;
2297    }
2298
2299    response.send(proto::Ack {})?;
2300    Ok(())
2301}
2302
2303async fn get_channel_members(
2304    request: proto::GetChannelMembers,
2305    response: Response<proto::GetChannelMembers>,
2306    session: Session,
2307) -> Result<()> {
2308    let db = session.db().await;
2309    let channel_id = ChannelId::from_proto(request.channel_id);
2310    let members = db
2311        .get_channel_member_details(channel_id, session.user_id)
2312        .await?;
2313    response.send(proto::GetChannelMembersResponse { members })?;
2314    Ok(())
2315}
2316
2317async fn respond_to_channel_invite(
2318    request: proto::RespondToChannelInvite,
2319    response: Response<proto::RespondToChannelInvite>,
2320    session: Session,
2321) -> Result<()> {
2322    let db = session.db().await;
2323    let channel_id = ChannelId::from_proto(request.channel_id);
2324    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2325        .await?;
2326
2327    let mut update = proto::UpdateChannels::default();
2328    update
2329        .remove_channel_invitations
2330        .push(channel_id.to_proto());
2331    if request.accept {
2332        let (channels, participants) = db.get_channels_for_user(session.user_id).await?;
2333        update
2334            .channels
2335            .extend(channels.into_iter().map(|channel| proto::Channel {
2336                id: channel.id.to_proto(),
2337                name: channel.name,
2338                user_is_admin: channel.user_is_admin,
2339                parent_id: channel.parent_id.map(ChannelId::to_proto),
2340            }));
2341        update
2342            .channel_participants
2343            .extend(participants.into_iter().map(|(channel_id, user_ids)| {
2344                proto::ChannelParticipants {
2345                    channel_id: channel_id.to_proto(),
2346                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2347                }
2348            }));
2349    }
2350    session.peer.send(session.connection_id, update)?;
2351    response.send(proto::Ack {})?;
2352
2353    Ok(())
2354}
2355
/// Joins the session's user to the room backing `request.channel_id`.
///
/// The handler looks up the channel's room and joins it with this session's
/// user and connection. It replies with the room state, plus LiveKit
/// connection info when a LiveKit client is configured and a token could be
/// issued, and broadcasts the updated room to its participants. Finally it
/// pushes the new participant list to the channel's members and refreshes the
/// user's contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // All database work happens inside this block. It evaluates to a clone of
    // the join result, so `db` and the value returned by `join_room` (per the
    // TODO below, presumably a room guard) are dropped before the channel
    // fan-out further down.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // A token error goes through `trace_err` (presumably logged) and makes
        // this `None`; it does not fail the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        // Reply to the joiner before telling the other room participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2416
2417async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2418    let project_id = ProjectId::from_proto(request.project_id);
2419    let project_connection_ids = session
2420        .db()
2421        .await
2422        .project_connection_ids(project_id, session.connection_id)
2423        .await?;
2424    broadcast(
2425        Some(session.connection_id),
2426        project_connection_ids.iter().copied(),
2427        |connection_id| {
2428            session
2429                .peer
2430                .forward_send(session.connection_id, connection_id, request.clone())
2431        },
2432    );
2433    Ok(())
2434}
2435
2436async fn get_private_user_info(
2437    _request: proto::GetPrivateUserInfo,
2438    response: Response<proto::GetPrivateUserInfo>,
2439    session: Session,
2440) -> Result<()> {
2441    let metrics_id = session
2442        .db()
2443        .await
2444        .get_user_metrics_id(session.user_id)
2445        .await?;
2446    let user = session
2447        .db()
2448        .await
2449        .get_user_by_id(session.user_id)
2450        .await?
2451        .ok_or_else(|| anyhow!("user not found"))?;
2452    response.send(proto::GetPrivateUserInfoResponse {
2453        metrics_id,
2454        staff: user.admin,
2455    })?;
2456    Ok(())
2457}
2458
2459fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2460    match message {
2461        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2462        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2463        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2464        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2465        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2466            code: frame.code.into(),
2467            reason: frame.reason,
2468        })),
2469    }
2470}
2471
2472fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2473    match message {
2474        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2475        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2476        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2477        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2478        AxumMessage::Close(frame) => {
2479            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2480                code: frame.code.into(),
2481                reason: frame.reason,
2482            }))
2483        }
2484    }
2485}
2486
2487fn build_initial_channels_update(
2488    channels: Vec<db::Channel>,
2489    channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2490    channel_invites: Vec<db::Channel>,
2491) -> proto::UpdateChannels {
2492    let mut update = proto::UpdateChannels::default();
2493
2494    for channel in channels {
2495        update.channels.push(proto::Channel {
2496            id: channel.id.to_proto(),
2497            name: channel.name,
2498            user_is_admin: channel.user_is_admin,
2499            parent_id: channel.parent_id.map(|id| id.to_proto()),
2500        });
2501    }
2502
2503    for (channel_id, participants) in channel_participants {
2504        update
2505            .channel_participants
2506            .push(proto::ChannelParticipants {
2507                channel_id: channel_id.to_proto(),
2508                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2509            });
2510    }
2511
2512    for channel in channel_invites {
2513        update.channel_invitations.push(proto::Channel {
2514            id: channel.id.to_proto(),
2515            name: channel.name,
2516            user_is_admin: false,
2517            parent_id: None,
2518        });
2519    }
2520
2521    update
2522}
2523
2524fn build_initial_contacts_update(
2525    contacts: Vec<db::Contact>,
2526    pool: &ConnectionPool,
2527) -> proto::UpdateContacts {
2528    let mut update = proto::UpdateContacts::default();
2529
2530    for contact in contacts {
2531        match contact {
2532            db::Contact::Accepted {
2533                user_id,
2534                should_notify,
2535                busy,
2536            } => {
2537                update
2538                    .contacts
2539                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2540            }
2541            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2542            db::Contact::Incoming {
2543                user_id,
2544                should_notify,
2545            } => update
2546                .incoming_requests
2547                .push(proto::IncomingContactRequest {
2548                    requester_id: user_id.to_proto(),
2549                    should_notify,
2550                }),
2551        }
2552    }
2553
2554    update
2555}
2556
2557fn contact_for_user(
2558    user_id: UserId,
2559    should_notify: bool,
2560    busy: bool,
2561    pool: &ConnectionPool,
2562) -> proto::Contact {
2563    proto::Contact {
2564        user_id: user_id.to_proto(),
2565        online: pool.is_user_online(user_id),
2566        busy,
2567        should_notify,
2568    }
2569}
2570
2571fn room_updated(room: &proto::Room, peer: &Peer) {
2572    broadcast(
2573        None,
2574        room.participants
2575            .iter()
2576            .filter_map(|participant| Some(participant.peer_id?.into())),
2577        |peer_id| {
2578            peer.send(
2579                peer_id.into(),
2580                proto::RoomUpdated {
2581                    room: Some(room.clone()),
2582                },
2583            )
2584        },
2585    );
2586}
2587
2588fn channel_updated(
2589    channel_id: ChannelId,
2590    room: &proto::Room,
2591    channel_members: &[UserId],
2592    peer: &Peer,
2593    pool: &ConnectionPool,
2594) {
2595    let participants = room
2596        .participants
2597        .iter()
2598        .map(|p| p.user_id)
2599        .collect::<Vec<_>>();
2600
2601    broadcast(
2602        None,
2603        channel_members
2604            .iter()
2605            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2606        |peer_id| {
2607            peer.send(
2608                peer_id.into(),
2609                proto::UpdateChannels {
2610                    channel_participants: vec![proto::ChannelParticipants {
2611                        channel_id: channel_id.to_proto(),
2612                        participant_user_ids: participants.clone(),
2613                    }],
2614                    ..Default::default()
2615                },
2616            )
2617        },
2618    );
2619}
2620
2621async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2622    let db = session.db().await;
2623
2624    let contacts = db.get_contacts(user_id).await?;
2625    let busy = db.is_user_busy(user_id).await?;
2626
2627    let pool = session.connection_pool().await;
2628    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2629    for contact in contacts {
2630        if let db::Contact::Accepted {
2631            user_id: contact_user_id,
2632            ..
2633        } = contact
2634        {
2635            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2636                session
2637                    .peer
2638                    .send(
2639                        contact_conn_id,
2640                        proto::UpdateContacts {
2641                            contacts: vec![updated_contact.clone()],
2642                            remove_contacts: Default::default(),
2643                            incoming_requests: Default::default(),
2644                            remove_incoming_requests: Default::default(),
2645                            outgoing_requests: Default::default(),
2646                            remove_outgoing_requests: Default::default(),
2647                        },
2648                    )
2649                    .trace_err();
2650            }
2651        }
2652    }
2653    Ok(())
2654}
2655
2656async fn leave_room_for_session(session: &Session) -> Result<()> {
2657    let mut contacts_to_update = HashSet::default();
2658
2659    let room_id;
2660    let canceled_calls_to_user_ids;
2661    let live_kit_room;
2662    let delete_live_kit_room;
2663    let room;
2664    let channel_members;
2665    let channel_id;
2666
2667    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2668        contacts_to_update.insert(session.user_id);
2669
2670        for project in left_room.left_projects.values() {
2671            project_left(project, session);
2672        }
2673
2674        room_id = RoomId::from_proto(left_room.room.id);
2675        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2676        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2677        delete_live_kit_room = left_room.deleted;
2678        room = mem::take(&mut left_room.room);
2679        channel_members = mem::take(&mut left_room.channel_members);
2680        channel_id = left_room.channel_id;
2681
2682        room_updated(&room, &session.peer);
2683    } else {
2684        return Ok(());
2685    }
2686
2687    // TODO - do this while holding the room guard.
2688    if let Some(channel_id) = channel_id {
2689        channel_updated(
2690            channel_id,
2691            &room,
2692            &channel_members,
2693            &session.peer,
2694            &*session.connection_pool().await,
2695        );
2696    }
2697
2698    {
2699        let pool = session.connection_pool().await;
2700        for canceled_user_id in canceled_calls_to_user_ids {
2701            for connection_id in pool.user_connection_ids(canceled_user_id) {
2702                session
2703                    .peer
2704                    .send(
2705                        connection_id,
2706                        proto::CallCanceled {
2707                            room_id: room_id.to_proto(),
2708                        },
2709                    )
2710                    .trace_err();
2711            }
2712            contacts_to_update.insert(canceled_user_id);
2713        }
2714    }
2715
2716    for contact_user_id in contacts_to_update {
2717        update_user_contacts(contact_user_id, &session).await?;
2718    }
2719
2720    if let Some(live_kit) = session.live_kit_client.as_ref() {
2721        live_kit
2722            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2723            .await
2724            .trace_err();
2725
2726        if delete_live_kit_room {
2727            live_kit.delete_room(live_kit_room).await.trace_err();
2728        }
2729    }
2730
2731    Ok(())
2732}
2733
2734fn project_left(project: &db::LeftProject, session: &Session) {
2735    for connection_id in &project.connection_ids {
2736        if project.host_user_id == session.user_id {
2737            session
2738                .peer
2739                .send(
2740                    *connection_id,
2741                    proto::UnshareProject {
2742                        project_id: project.id.to_proto(),
2743                    },
2744                )
2745                .trace_err();
2746        } else {
2747            session
2748                .peer
2749                .send(
2750                    *connection_id,
2751                    proto::RemoveProjectCollaborator {
2752                        project_id: project.id.to_proto(),
2753                        peer_id: Some(session.connection_id.into()),
2754                    },
2755                )
2756                .trace_err();
2757        }
2758    }
2759}
2760
2761pub trait ResultExt {
2762    type Ok;
2763
2764    fn trace_err(self) -> Option<Self::Ok>;
2765}
2766
2767impl<T, E> ResultExt for Result<T, E>
2768where
2769    E: std::fmt::Debug,
2770{
2771    type Ok = T;
2772
2773    fn trace_err(self) -> Option<T> {
2774        match self {
2775            Ok(value) => Some(value),
2776            Err(error) => {
2777                tracing::error!("{:?}", error);
2778                None
2779            }
2780        }
2781    }
2782}