rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// The collaboration RPC server. It owns the peer, the registry of live connections, and the
/// message-handler table populated by [`Server::new`].
pub struct Server {
    /// Identifier of this server instance. It sits behind a mutex because the test-only `reset`
    /// replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections. Shared with every `Session` and with the background cleanup task.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown()`. Each `handle_connection` loop subscribes and returns when it
    /// fires.
    teardown: watch::Sender<()>,
}
 153
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. It therefore cannot be held across
/// an `.await` inside a future that must be `Send`. In test builds, dropping the guard runs the
/// pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 158
/// Serializable view of the server's state.
///
/// The connection pool is serialized through its `Deref` target (see [`serialize_deref`]). The
/// snapshot holds the pool's lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(get_channel_members)
 250            .add_request_handler(respond_to_channel_invite)
 251            .add_request_handler(join_channel)
 252            .add_request_handler(follow)
 253            .add_message_handler(unfollow)
 254            .add_message_handler(update_followers)
 255            .add_message_handler(update_diff_base)
 256            .add_request_handler(get_private_user_info);
 257
 258        Arc::new(server)
 259    }
 260
 261    pub async fn start(&self) -> Result<()> {
 262        let server_id = *self.id.lock();
 263        let app_state = self.app_state.clone();
 264        let peer = self.peer.clone();
 265        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 266        let pool = self.connection_pool.clone();
 267        let live_kit_client = self.app_state.live_kit_client.clone();
 268
 269        let span = info_span!("start server");
 270        self.executor.spawn_detached(
 271            async move {
 272                tracing::info!("waiting for cleanup timeout");
 273                timeout.await;
 274                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 275                if let Some(room_ids) = app_state
 276                    .db
 277                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 278                    .await
 279                    .trace_err()
 280                {
 281                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 282                    for room_id in room_ids {
 283                        let mut contacts_to_update = HashSet::default();
 284                        let mut canceled_calls_to_user_ids = Vec::new();
 285                        let mut live_kit_room = String::new();
 286                        let mut delete_live_kit_room = false;
 287
 288                        if let Some(mut refreshed_room) = app_state
 289                            .db
 290                            .refresh_room(room_id, server_id)
 291                            .await
 292                            .trace_err()
 293                        {
 294                            tracing::info!(
 295                                room_id = room_id.0,
 296                                new_participant_count = refreshed_room.room.participants.len(),
 297                                "refreshed room"
 298                            );
 299                            room_updated(&refreshed_room.room, &peer);
 300                            if let Some(channel_id) = refreshed_room.channel_id {
 301                                channel_updated(
 302                                    channel_id,
 303                                    &refreshed_room.room,
 304                                    &refreshed_room.channel_members,
 305                                    &peer,
 306                                    &*pool.lock(),
 307                                );
 308                            }
 309                            contacts_to_update
 310                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 311                            contacts_to_update
 312                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 313                            canceled_calls_to_user_ids =
 314                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 315                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 316                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 317                        }
 318
 319                        {
 320                            let pool = pool.lock();
 321                            for canceled_user_id in canceled_calls_to_user_ids {
 322                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 323                                    peer.send(
 324                                        connection_id,
 325                                        proto::CallCanceled {
 326                                            room_id: room_id.to_proto(),
 327                                        },
 328                                    )
 329                                    .trace_err();
 330                                }
 331                            }
 332                        }
 333
 334                        for user_id in contacts_to_update {
 335                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 336                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 337                            if let Some((busy, contacts)) = busy.zip(contacts) {
 338                                let pool = pool.lock();
 339                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 340                                for contact in contacts {
 341                                    if let db::Contact::Accepted {
 342                                        user_id: contact_user_id,
 343                                        ..
 344                                    } = contact
 345                                    {
 346                                        for contact_conn_id in
 347                                            pool.user_connection_ids(contact_user_id)
 348                                        {
 349                                            peer.send(
 350                                                contact_conn_id,
 351                                                proto::UpdateContacts {
 352                                                    contacts: vec![updated_contact.clone()],
 353                                                    remove_contacts: Default::default(),
 354                                                    incoming_requests: Default::default(),
 355                                                    remove_incoming_requests: Default::default(),
 356                                                    outgoing_requests: Default::default(),
 357                                                    remove_outgoing_requests: Default::default(),
 358                                                },
 359                                            )
 360                                            .trace_err();
 361                                        }
 362                                    }
 363                                }
 364                            }
 365                        }
 366
 367                        if let Some(live_kit) = live_kit_client.as_ref() {
 368                            if delete_live_kit_room {
 369                                live_kit.delete_room(live_kit_room).await.trace_err();
 370                            }
 371                        }
 372                    }
 373                }
 374
 375                app_state
 376                    .db
 377                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 378                    .await
 379                    .trace_err();
 380            }
 381            .instrument(span),
 382        );
 383        Ok(())
 384    }
 385
 386    pub fn teardown(&self) {
 387        self.peer.teardown();
 388        self.connection_pool.lock().reset();
 389        let _ = self.teardown.send(());
 390    }
 391
 392    #[cfg(test)]
 393    pub fn reset(&self, id: ServerId) {
 394        self.teardown();
 395        *self.id.lock() = id;
 396        self.peer.reset(id.0 as u32);
 397    }
 398
 399    #[cfg(test)]
 400    pub fn id(&self) -> ServerId {
 401        *self.id.lock()
 402    }
 403
 404    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 405    where
 406        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 407        Fut: 'static + Send + Future<Output = Result<()>>,
 408        M: EnvelopedMessage,
 409    {
 410        let prev_handler = self.handlers.insert(
 411            TypeId::of::<M>(),
 412            Box::new(move |envelope, session| {
 413                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 414                let span = info_span!(
 415                    "handle message",
 416                    payload_type = envelope.payload_type_name()
 417                );
 418                span.in_scope(|| {
 419                    tracing::info!(
 420                        payload_type = envelope.payload_type_name(),
 421                        "message received"
 422                    );
 423                });
 424                let start_time = Instant::now();
 425                let future = (handler)(*envelope, session);
 426                async move {
 427                    let result = future.await;
 428                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 429                    match result {
 430                        Err(error) => {
 431                            tracing::error!(%error, ?duration_ms, "error handling message")
 432                        }
 433                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 434                    }
 435                }
 436                .instrument(span)
 437                .boxed()
 438            }),
 439        );
 440        if prev_handler.is_some() {
 441            panic!("registered a handler for the same message twice");
 442        }
 443        self
 444    }
 445
 446    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 447    where
 448        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 449        Fut: 'static + Send + Future<Output = Result<()>>,
 450        M: EnvelopedMessage,
 451    {
 452        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 453        self
 454    }
 455
 456    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 457    where
 458        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 459        Fut: Send + Future<Output = Result<()>>,
 460        M: RequestMessage,
 461    {
 462        let handler = Arc::new(handler);
 463        self.add_handler(move |envelope, session| {
 464            let receipt = envelope.receipt();
 465            let handler = handler.clone();
 466            async move {
 467                let peer = session.peer.clone();
 468                let responded = Arc::new(AtomicBool::default());
 469                let response = Response {
 470                    peer: peer.clone(),
 471                    responded: responded.clone(),
 472                    receipt,
 473                };
 474                match (handler)(envelope.payload, response, session).await {
 475                    Ok(()) => {
 476                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 477                            Ok(())
 478                        } else {
 479                            Err(anyhow!("handler did not send a response"))?
 480                        }
 481                    }
 482                    Err(error) => {
 483                        peer.respond_with_error(
 484                            receipt,
 485                            proto::Error {
 486                                message: error.to_string(),
 487                            },
 488                        )?;
 489                        Err(error)
 490                    }
 491                }
 492            }
 493        })
 494    }
 495
 496    pub fn handle_connection(
 497        self: &Arc<Self>,
 498        connection: Connection,
 499        address: String,
 500        user: User,
 501        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 502        executor: Executor,
 503    ) -> impl Future<Output = Result<()>> {
 504        let this = self.clone();
 505        let user_id = user.id;
 506        let login = user.github_login;
 507        let span = info_span!("handle connection", %user_id, %login, %address);
 508        let mut teardown = self.teardown.subscribe();
 509        async move {
 510            let (connection_id, handle_io, mut incoming_rx) = this
 511                .peer
 512                .add_connection(connection, {
 513                    let executor = executor.clone();
 514                    move |duration| executor.sleep(duration)
 515                });
 516
 517            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 518            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 519            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 520
 521            if let Some(send_connection_id) = send_connection_id.take() {
 522                let _ = send_connection_id.send(connection_id);
 523            }
 524
 525            if !user.connected_once {
 526                this.peer.send(connection_id, proto::ShowContacts {})?;
 527                this.app_state.db.set_user_connected_once(user_id, true).await?;
 528            }
 529
 530            let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
 531                this.app_state.db.get_contacts(user_id),
 532                this.app_state.db.get_invite_code_for_user(user_id),
 533                this.app_state.db.get_channels_for_user(user_id),
 534                this.app_state.db.get_channel_invites_for_user(user_id)
 535            ).await?;
 536
 537            {
 538                let mut pool = this.connection_pool.lock();
 539                pool.add_connection(connection_id, user_id, user.admin);
 540                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 541                this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;
 542
 543                if let Some((code, count)) = invite_code {
 544                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 545                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 546                        count: count as u32,
 547                    })?;
 548                }
 549            }
 550
 551            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 552                this.peer.send(connection_id, incoming_call)?;
 553            }
 554
 555            let session = Session {
 556                user_id,
 557                connection_id,
 558                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 559                peer: this.peer.clone(),
 560                connection_pool: this.connection_pool.clone(),
 561                live_kit_client: this.app_state.live_kit_client.clone(),
 562                executor: executor.clone(),
 563            };
 564            update_user_contacts(user_id, &session).await?;
 565
 566            let handle_io = handle_io.fuse();
 567            futures::pin_mut!(handle_io);
 568
 569            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 570            // This prevents deadlocks when e.g., client A performs a request to client B and
 571            // client B performs a request to client A. If both clients stop processing further
 572            // messages until their respective request completes, they won't have a chance to
 573            // respond to the other client's request and cause a deadlock.
 574            //
 575            // This arrangement ensures we will attempt to process earlier messages first, but fall
 576            // back to processing messages arrived later in the spirit of making progress.
 577            let mut foreground_message_handlers = FuturesUnordered::new();
 578            let concurrent_handlers = Arc::new(Semaphore::new(256));
 579            loop {
 580                let next_message = async {
 581                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 582                    let message = incoming_rx.next().await;
 583                    (permit, message)
 584                }.fuse();
 585                futures::pin_mut!(next_message);
 586                futures::select_biased! {
 587                    _ = teardown.changed().fuse() => return Ok(()),
 588                    result = handle_io => {
 589                        if let Err(error) = result {
 590                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 591                        }
 592                        break;
 593                    }
 594                    _ = foreground_message_handlers.next() => {}
 595                    next_message = next_message => {
 596                        let (permit, message) = next_message;
 597                        if let Some(message) = message {
 598                            let type_name = message.payload_type_name();
 599                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 600                            let span_enter = span.enter();
 601                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 602                                let is_background = message.is_background();
 603                                let handle_message = (handler)(message, session.clone());
 604                                drop(span_enter);
 605
 606                                let handle_message = async move {
 607                                    handle_message.await;
 608                                    drop(permit);
 609                                }.instrument(span);
 610                                if is_background {
 611                                    executor.spawn_detached(handle_message);
 612                                } else {
 613                                    foreground_message_handlers.push(handle_message);
 614                                }
 615                            } else {
 616                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 617                            }
 618                        } else {
 619                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 620                            break;
 621                        }
 622                    }
 623                }
 624            }
 625
 626            drop(foreground_message_handlers);
 627            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 628            if let Err(error) = connection_lost(session, teardown, executor).await {
 629                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 630            }
 631
 632            Ok(())
 633        }.instrument(span)
 634    }
 635
 636    pub async fn invite_code_redeemed(
 637        self: &Arc<Self>,
 638        inviter_id: UserId,
 639        invitee_id: UserId,
 640    ) -> Result<()> {
 641        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 642            if let Some(code) = &user.invite_code {
 643                let pool = self.connection_pool.lock();
 644                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 645                for connection_id in pool.user_connection_ids(inviter_id) {
 646                    self.peer.send(
 647                        connection_id,
 648                        proto::UpdateContacts {
 649                            contacts: vec![invitee_contact.clone()],
 650                            ..Default::default()
 651                        },
 652                    )?;
 653                    self.peer.send(
 654                        connection_id,
 655                        proto::UpdateInviteInfo {
 656                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 657                            count: user.invite_count as u32,
 658                        },
 659                    )?;
 660                }
 661            }
 662        }
 663        Ok(())
 664    }
 665
 666    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 667        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 668            if let Some(invite_code) = &user.invite_code {
 669                let pool = self.connection_pool.lock();
 670                for connection_id in pool.user_connection_ids(user_id) {
 671                    self.peer.send(
 672                        connection_id,
 673                        proto::UpdateInviteInfo {
 674                            url: format!(
 675                                "{}{}",
 676                                self.app_state.config.invite_link_prefix, invite_code
 677                            ),
 678                            count: user.invite_count as u32,
 679                        },
 680                    )?;
 681                }
 682            }
 683        }
 684        Ok(())
 685    }
 686
 687    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 688        ServerSnapshot {
 689            connection_pool: ConnectionPoolGuard {
 690                guard: self.connection_pool.lock(),
 691                _not_send: PhantomData,
 692            },
 693            peer: &self.peer,
 694        }
 695    }
 696}
 697
 698impl<'a> Deref for ConnectionPoolGuard<'a> {
 699    type Target = ConnectionPool;
 700
 701    fn deref(&self) -> &Self::Target {
 702        &*self.guard
 703    }
 704}
 705
 706impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 707    fn deref_mut(&mut self) -> &mut Self::Target {
 708        &mut *self.guard
 709    }
 710}
 711
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// Runs when the guard is released. Because the guard also hands out `&mut
    /// ConnectionPool`, this is the point after which any mutation is complete.
    fn drop(&mut self) {
        // Only test builds validate the pool's invariants; release builds compile this out.
        #[cfg(test)]
        self.check_invariants();
    }
}
 718
 719fn broadcast<F>(
 720    sender_id: Option<ConnectionId>,
 721    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 722    mut f: F,
 723) where
 724    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 725{
 726    for receiver_id in receiver_ids {
 727        if Some(receiver_id) != sender_id {
 728            if let Err(error) = f(receiver_id) {
 729                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 730            }
 731        }
 732    }
 733}
 734
lazy_static! {
    // Name of the request header in which clients advertise their RPC protocol version.
    // `ProtocolVersion`'s `Header` impl returns it from `name()`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 738
 739pub struct ProtocolVersion(u32);
 740
 741impl Header for ProtocolVersion {
 742    fn name() -> &'static HeaderName {
 743        &ZED_PROTOCOL_VERSION
 744    }
 745
 746    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 747    where
 748        Self: Sized,
 749        I: Iterator<Item = &'i axum::http::HeaderValue>,
 750    {
 751        let version = values
 752            .next()
 753            .ok_or_else(axum::headers::Error::invalid)?
 754            .to_str()
 755            .map_err(|_| axum::headers::Error::invalid())?
 756            .parse()
 757            .map_err(|_| axum::headers::Error::invalid())?;
 758        Ok(Self(version))
 759    }
 760
 761    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 762        values.extend([self.0.to_string().parse().unwrap()]);
 763    }
 764}
 765
/// Builds the router for the collab RPC HTTP endpoints.
///
/// * `GET /rpc`: websocket upgrade handled by [`handle_websocket_request`]. It is wrapped
///   by the `AppState` extension and the `auth::validate_header` middleware.
/// * `GET /metrics`: Prometheus scrape endpoint handled by [`handle_metrics`]. It is
///   registered after that layer, so it is not covered by it. axum's `Router::layer`
///   only wraps routes that were added before the call.
///
/// The `Arc<Server>` extension is layered last and therefore reaches both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc`, the route registered above this call.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to both routes.
        .layer(Extension(server))
}
 777
/// Upgrades an `/rpc` HTTP request to a websocket and hands it to `Server::handle_connection`.
///
/// If the client's `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`, the
/// request is answered with `426 Upgrade Required` and no upgrade happens. The `User`
/// extension must be present on the request. Presumably `auth::validate_header` inserts it
/// (see [`routes`]); confirm there.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket message type to the tungstenite messages that
        // `Connection` works with, converting in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Errors from the connection's lifetime are logged, not returned.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 808
 809pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 810    let connections = server
 811        .connection_pool
 812        .lock()
 813        .connections()
 814        .filter(|connection| !connection.admin)
 815        .count();
 816
 817    METRIC_CONNECTIONS.set(connections as _);
 818
 819    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 820    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 821
 822    let encoder = prometheus::TextEncoder::new();
 823    let metric_families = prometheus::gather();
 824    let encoded_metrics = encoder
 825        .encode_to_string(&metric_families)
 826        .map_err(|err| anyhow!("{}", err))?;
 827    Ok(encoded_metrics)
 828}
 829
/// Cleans up server-side state for a connection that has gone away.
///
/// Immediately it disconnects the peer, removes the connection from the pool (errors
/// propagate) and reports the lost connection to the database (errors are only traced).
///
/// It then waits for one of two events. When both are ready, the timeout wins.
/// * [`RECONNECT_TIMEOUT`] elapses. The session leaves its room. If the user has no other
///   connection online, any pending call is declined and the affected room is broadcast.
///   Finally the user's contacts are refreshed.
/// * `teardown` changes first. The function returns without doing the delayed cleanup.
///   This is presumably signalled on server shutdown; confirm at the caller.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in declaration order, so the timeout branch wins ties.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline pending calls once the user has no live connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 870
 871async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 872    response.send(proto::Ack {})?;
 873    Ok(())
 874}
 875
/// Creates a new room for the calling user and replies with it.
///
/// A random 30-character LiveKit room name is generated. If a LiveKit client is configured,
/// the LiveKit room is created and a token for the caller is minted. Any failure in that
/// step is traced and only leaves `live_kit_connection_info` as `None`. The database room is
/// created with the generated name regardless. The caller's contacts are refreshed
/// afterwards.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    let live_kit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = {
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        // Each `?` inside the IIFE short-circuits to `None`, so LiveKit problems never
        // fail the request.
        util::async_iife!({
            let live_kit = live_kit?;

            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
 921
 922async fn join_room(
 923    request: proto::JoinRoom,
 924    response: Response<proto::JoinRoom>,
 925    session: Session,
 926) -> Result<()> {
 927    let room_id = RoomId::from_proto(request.id);
 928    let room = {
 929        let room = session
 930            .db()
 931            .await
 932            .join_room(room_id, session.user_id, None, session.connection_id)
 933            .await?;
 934        room_updated(&room.room, &session.peer);
 935        room.room.clone()
 936    };
 937
 938    for connection_id in session
 939        .connection_pool()
 940        .await
 941        .user_connection_ids(session.user_id)
 942    {
 943        session
 944            .peer
 945            .send(
 946                connection_id,
 947                proto::CallCanceled {
 948                    room_id: room_id.to_proto(),
 949                },
 950            )
 951            .trace_err();
 952    }
 953
 954    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 955        if let Some(token) = live_kit
 956            .room_token(&room.live_kit_room, &session.user_id.to_string())
 957            .trace_err()
 958        {
 959            Some(proto::LiveKitConnectionInfo {
 960                server_url: live_kit.url().into(),
 961                token,
 962            })
 963        } else {
 964            None
 965        }
 966    } else {
 967        None
 968    };
 969
 970    response.send(proto::JoinRoomResponse {
 971        room: Some(room),
 972        live_kit_connection_info,
 973    })?;
 974
 975    update_user_contacts(session.user_id, &session).await?;
 976    Ok(())
 977}
 978
 979async fn rejoin_room(
 980    request: proto::RejoinRoom,
 981    response: Response<proto::RejoinRoom>,
 982    session: Session,
 983) -> Result<()> {
 984    let room;
 985    let channel_id;
 986    let channel_members;
 987    {
 988        let mut rejoined_room = session
 989            .db()
 990            .await
 991            .rejoin_room(request, session.user_id, session.connection_id)
 992            .await?;
 993
 994        response.send(proto::RejoinRoomResponse {
 995            room: Some(rejoined_room.room.clone()),
 996            reshared_projects: rejoined_room
 997                .reshared_projects
 998                .iter()
 999                .map(|project| proto::ResharedProject {
1000                    id: project.id.to_proto(),
1001                    collaborators: project
1002                        .collaborators
1003                        .iter()
1004                        .map(|collaborator| collaborator.to_proto())
1005                        .collect(),
1006                })
1007                .collect(),
1008            rejoined_projects: rejoined_room
1009                .rejoined_projects
1010                .iter()
1011                .map(|rejoined_project| proto::RejoinedProject {
1012                    id: rejoined_project.id.to_proto(),
1013                    worktrees: rejoined_project
1014                        .worktrees
1015                        .iter()
1016                        .map(|worktree| proto::WorktreeMetadata {
1017                            id: worktree.id,
1018                            root_name: worktree.root_name.clone(),
1019                            visible: worktree.visible,
1020                            abs_path: worktree.abs_path.clone(),
1021                        })
1022                        .collect(),
1023                    collaborators: rejoined_project
1024                        .collaborators
1025                        .iter()
1026                        .map(|collaborator| collaborator.to_proto())
1027                        .collect(),
1028                    language_servers: rejoined_project.language_servers.clone(),
1029                })
1030                .collect(),
1031        })?;
1032        room_updated(&rejoined_room.room, &session.peer);
1033
1034        for project in &rejoined_room.reshared_projects {
1035            for collaborator in &project.collaborators {
1036                session
1037                    .peer
1038                    .send(
1039                        collaborator.connection_id,
1040                        proto::UpdateProjectCollaborator {
1041                            project_id: project.id.to_proto(),
1042                            old_peer_id: Some(project.old_connection_id.into()),
1043                            new_peer_id: Some(session.connection_id.into()),
1044                        },
1045                    )
1046                    .trace_err();
1047            }
1048
1049            broadcast(
1050                Some(session.connection_id),
1051                project
1052                    .collaborators
1053                    .iter()
1054                    .map(|collaborator| collaborator.connection_id),
1055                |connection_id| {
1056                    session.peer.forward_send(
1057                        session.connection_id,
1058                        connection_id,
1059                        proto::UpdateProject {
1060                            project_id: project.id.to_proto(),
1061                            worktrees: project.worktrees.clone(),
1062                        },
1063                    )
1064                },
1065            );
1066        }
1067
1068        for project in &rejoined_room.rejoined_projects {
1069            for collaborator in &project.collaborators {
1070                session
1071                    .peer
1072                    .send(
1073                        collaborator.connection_id,
1074                        proto::UpdateProjectCollaborator {
1075                            project_id: project.id.to_proto(),
1076                            old_peer_id: Some(project.old_connection_id.into()),
1077                            new_peer_id: Some(session.connection_id.into()),
1078                        },
1079                    )
1080                    .trace_err();
1081            }
1082        }
1083
1084        for project in &mut rejoined_room.rejoined_projects {
1085            for worktree in mem::take(&mut project.worktrees) {
1086                #[cfg(any(test, feature = "test-support"))]
1087                const MAX_CHUNK_SIZE: usize = 2;
1088                #[cfg(not(any(test, feature = "test-support")))]
1089                const MAX_CHUNK_SIZE: usize = 256;
1090
1091                // Stream this worktree's entries.
1092                let message = proto::UpdateWorktree {
1093                    project_id: project.id.to_proto(),
1094                    worktree_id: worktree.id,
1095                    abs_path: worktree.abs_path.clone(),
1096                    root_name: worktree.root_name,
1097                    updated_entries: worktree.updated_entries,
1098                    removed_entries: worktree.removed_entries,
1099                    scan_id: worktree.scan_id,
1100                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1101                    updated_repositories: worktree.updated_repositories,
1102                    removed_repositories: worktree.removed_repositories,
1103                };
1104                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1105                    session.peer.send(session.connection_id, update.clone())?;
1106                }
1107
1108                // Stream this worktree's diagnostics.
1109                for summary in worktree.diagnostic_summaries {
1110                    session.peer.send(
1111                        session.connection_id,
1112                        proto::UpdateDiagnosticSummary {
1113                            project_id: project.id.to_proto(),
1114                            worktree_id: worktree.id,
1115                            summary: Some(summary),
1116                        },
1117                    )?;
1118                }
1119
1120                for settings_file in worktree.settings_files {
1121                    session.peer.send(
1122                        session.connection_id,
1123                        proto::UpdateWorktreeSettings {
1124                            project_id: project.id.to_proto(),
1125                            worktree_id: worktree.id,
1126                            path: settings_file.path,
1127                            content: Some(settings_file.content),
1128                        },
1129                    )?;
1130                }
1131            }
1132
1133            for language_server in &project.language_servers {
1134                session.peer.send(
1135                    session.connection_id,
1136                    proto::UpdateLanguageServer {
1137                        project_id: project.id.to_proto(),
1138                        language_server_id: language_server.id,
1139                        variant: Some(
1140                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1141                                proto::LspDiskBasedDiagnosticsUpdated {},
1142                            ),
1143                        ),
1144                    },
1145                )?;
1146            }
1147        }
1148
1149        room = mem::take(&mut rejoined_room.room);
1150        channel_id = rejoined_room.channel_id;
1151        channel_members = mem::take(&mut rejoined_room.channel_members);
1152    }
1153
1154    //TODO: move this into the room guard
1155    if let Some(channel_id) = channel_id {
1156        channel_updated(
1157            channel_id,
1158            &room,
1159            &channel_members,
1160            &session.peer,
1161            &*session.connection_pool().await,
1162        );
1163    }
1164
1165    update_user_contacts(session.user_id, &session).await?;
1166    Ok(())
1167}
1168
1169async fn leave_room(
1170    _: proto::LeaveRoom,
1171    response: Response<proto::LeaveRoom>,
1172    session: Session,
1173) -> Result<()> {
1174    leave_room_for_session(&session).await?;
1175    response.send(proto::Ack {})?;
1176    Ok(())
1177}
1178
1179async fn call(
1180    request: proto::Call,
1181    response: Response<proto::Call>,
1182    session: Session,
1183) -> Result<()> {
1184    let room_id = RoomId::from_proto(request.room_id);
1185    let calling_user_id = session.user_id;
1186    let calling_connection_id = session.connection_id;
1187    let called_user_id = UserId::from_proto(request.called_user_id);
1188    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1189    if !session
1190        .db()
1191        .await
1192        .has_contact(calling_user_id, called_user_id)
1193        .await?
1194    {
1195        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1196    }
1197
1198    let incoming_call = {
1199        let (room, incoming_call) = &mut *session
1200            .db()
1201            .await
1202            .call(
1203                room_id,
1204                calling_user_id,
1205                calling_connection_id,
1206                called_user_id,
1207                initial_project_id,
1208            )
1209            .await?;
1210        room_updated(&room, &session.peer);
1211        mem::take(incoming_call)
1212    };
1213    update_user_contacts(called_user_id, &session).await?;
1214
1215    let mut calls = session
1216        .connection_pool()
1217        .await
1218        .user_connection_ids(called_user_id)
1219        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1220        .collect::<FuturesUnordered<_>>();
1221
1222    while let Some(call_response) = calls.next().await {
1223        match call_response.as_ref() {
1224            Ok(_) => {
1225                response.send(proto::Ack {})?;
1226                return Ok(());
1227            }
1228            Err(_) => {
1229                call_response.trace_err();
1230            }
1231        }
1232    }
1233
1234    {
1235        let room = session
1236            .db()
1237            .await
1238            .call_failed(room_id, called_user_id)
1239            .await?;
1240        room_updated(&room, &session.peer);
1241    }
1242    update_user_contacts(called_user_id, &session).await?;
1243
1244    Err(anyhow!("failed to ring user"))?
1245}
1246
1247async fn cancel_call(
1248    request: proto::CancelCall,
1249    response: Response<proto::CancelCall>,
1250    session: Session,
1251) -> Result<()> {
1252    let called_user_id = UserId::from_proto(request.called_user_id);
1253    let room_id = RoomId::from_proto(request.room_id);
1254    {
1255        let room = session
1256            .db()
1257            .await
1258            .cancel_call(room_id, session.connection_id, called_user_id)
1259            .await?;
1260        room_updated(&room, &session.peer);
1261    }
1262
1263    for connection_id in session
1264        .connection_pool()
1265        .await
1266        .user_connection_ids(called_user_id)
1267    {
1268        session
1269            .peer
1270            .send(
1271                connection_id,
1272                proto::CallCanceled {
1273                    room_id: room_id.to_proto(),
1274                },
1275            )
1276            .trace_err();
1277    }
1278    response.send(proto::Ack {})?;
1279
1280    update_user_contacts(called_user_id, &session).await?;
1281    Ok(())
1282}
1283
1284async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1285    let room_id = RoomId::from_proto(message.room_id);
1286    {
1287        let room = session
1288            .db()
1289            .await
1290            .decline_call(Some(room_id), session.user_id)
1291            .await?
1292            .ok_or_else(|| anyhow!("failed to decline call"))?;
1293        room_updated(&room, &session.peer);
1294    }
1295
1296    for connection_id in session
1297        .connection_pool()
1298        .await
1299        .user_connection_ids(session.user_id)
1300    {
1301        session
1302            .peer
1303            .send(
1304                connection_id,
1305                proto::CallCanceled {
1306                    room_id: room_id.to_proto(),
1307                },
1308            )
1309            .trace_err();
1310    }
1311    update_user_contacts(session.user_id, &session).await?;
1312    Ok(())
1313}
1314
1315async fn update_participant_location(
1316    request: proto::UpdateParticipantLocation,
1317    response: Response<proto::UpdateParticipantLocation>,
1318    session: Session,
1319) -> Result<()> {
1320    let room_id = RoomId::from_proto(request.room_id);
1321    let location = request
1322        .location
1323        .ok_or_else(|| anyhow!("invalid location"))?;
1324
1325    let db = session.db().await;
1326    let room = db
1327        .update_room_participant_location(room_id, session.connection_id, location)
1328        .await?;
1329
1330    room_updated(&room, &session.peer);
1331    response.send(proto::Ack {})?;
1332    Ok(())
1333}
1334
/// Shares a project from the calling connection in the room `request.room_id`.
///
/// Stores the project's worktrees in the database, replies with the new project id and
/// broadcasts the updated room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // Destructure through the returned value; it stays alive for the rest of the function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session.peer);

    Ok(())
}
1356
/// Unshares `message.project_id` on behalf of the calling connection (presumably its host).
///
/// Forwards the unshare message unchanged to the guest connections reported by the
/// database, skipping the sender, then broadcasts the updated room.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .unshare_project(project_id, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |conn_id| session.peer.send(conn_id, message.clone()),
    );
    room_updated(&room, &session.peer);

    Ok(())
}
1375
1376async fn join_project(
1377    request: proto::JoinProject,
1378    response: Response<proto::JoinProject>,
1379    session: Session,
1380) -> Result<()> {
1381    let project_id = ProjectId::from_proto(request.project_id);
1382    let guest_user_id = session.user_id;
1383
1384    tracing::info!(%project_id, "join project");
1385
1386    let (project, replica_id) = &mut *session
1387        .db()
1388        .await
1389        .join_project(project_id, session.connection_id)
1390        .await?;
1391
1392    let collaborators = project
1393        .collaborators
1394        .iter()
1395        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1396        .map(|collaborator| collaborator.to_proto())
1397        .collect::<Vec<_>>();
1398
1399    let worktrees = project
1400        .worktrees
1401        .iter()
1402        .map(|(id, worktree)| proto::WorktreeMetadata {
1403            id: *id,
1404            root_name: worktree.root_name.clone(),
1405            visible: worktree.visible,
1406            abs_path: worktree.abs_path.clone(),
1407        })
1408        .collect::<Vec<_>>();
1409
1410    for collaborator in &collaborators {
1411        session
1412            .peer
1413            .send(
1414                collaborator.peer_id.unwrap().into(),
1415                proto::AddProjectCollaborator {
1416                    project_id: project_id.to_proto(),
1417                    collaborator: Some(proto::Collaborator {
1418                        peer_id: Some(session.connection_id.into()),
1419                        replica_id: replica_id.0 as u32,
1420                        user_id: guest_user_id.to_proto(),
1421                    }),
1422                },
1423            )
1424            .trace_err();
1425    }
1426
1427    // First, we send the metadata associated with each worktree.
1428    response.send(proto::JoinProjectResponse {
1429        worktrees: worktrees.clone(),
1430        replica_id: replica_id.0 as u32,
1431        collaborators: collaborators.clone(),
1432        language_servers: project.language_servers.clone(),
1433    })?;
1434
1435    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1436        #[cfg(any(test, feature = "test-support"))]
1437        const MAX_CHUNK_SIZE: usize = 2;
1438        #[cfg(not(any(test, feature = "test-support")))]
1439        const MAX_CHUNK_SIZE: usize = 256;
1440
1441        // Stream this worktree's entries.
1442        let message = proto::UpdateWorktree {
1443            project_id: project_id.to_proto(),
1444            worktree_id,
1445            abs_path: worktree.abs_path.clone(),
1446            root_name: worktree.root_name,
1447            updated_entries: worktree.entries,
1448            removed_entries: Default::default(),
1449            scan_id: worktree.scan_id,
1450            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1451            updated_repositories: worktree.repository_entries.into_values().collect(),
1452            removed_repositories: Default::default(),
1453        };
1454        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1455            session.peer.send(session.connection_id, update.clone())?;
1456        }
1457
1458        // Stream this worktree's diagnostics.
1459        for summary in worktree.diagnostic_summaries {
1460            session.peer.send(
1461                session.connection_id,
1462                proto::UpdateDiagnosticSummary {
1463                    project_id: project_id.to_proto(),
1464                    worktree_id: worktree.id,
1465                    summary: Some(summary),
1466                },
1467            )?;
1468        }
1469
1470        for settings_file in worktree.settings_files {
1471            session.peer.send(
1472                session.connection_id,
1473                proto::UpdateWorktreeSettings {
1474                    project_id: project_id.to_proto(),
1475                    worktree_id: worktree.id,
1476                    path: settings_file.path,
1477                    content: Some(settings_file.content),
1478                },
1479            )?;
1480        }
1481    }
1482
1483    for language_server in &project.language_servers {
1484        session.peer.send(
1485            session.connection_id,
1486            proto::UpdateLanguageServer {
1487                project_id: project_id.to_proto(),
1488                language_server_id: language_server.id,
1489                variant: Some(
1490                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1491                        proto::LspDiskBasedDiagnosticsUpdated {},
1492                    ),
1493                ),
1494            },
1495        )?;
1496    }
1497
1498    Ok(())
1499}
1500
/// Removes the calling connection from `request.project_id`.
///
/// Logs the project's host and hands the departed project to `project_left`. That helper is
/// defined elsewhere and presumably notifies the remaining collaborators; verify there.
/// Finally the updated room is broadcast.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);

    let (room, project) = &*session
        .db()
        .await
        .leave_project(project_id, sender_id)
        .await?;
    tracing::info!(
        %project_id,
        host_user_id = %project.host_user_id,
        host_connection_id = %project.host_connection_id,
        "leave project"
    );

    project_left(&project, &session);
    room_updated(&room, &session.peer);

    Ok(())
}
1522
/// Replaces a shared project's worktree list with `request.worktrees`.
///
/// Forwards the request to the guest connections reported by the database (skipping the
/// sender), broadcasts the updated room and acknowledges the request.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;

    Ok(())
}
1548
1549async fn update_worktree(
1550    request: proto::UpdateWorktree,
1551    response: Response<proto::UpdateWorktree>,
1552    session: Session,
1553) -> Result<()> {
1554    let guest_connection_ids = session
1555        .db()
1556        .await
1557        .update_worktree(&request, session.connection_id)
1558        .await?;
1559
1560    broadcast(
1561        Some(session.connection_id),
1562        guest_connection_ids.iter().copied(),
1563        |connection_id| {
1564            session
1565                .peer
1566                .forward_send(session.connection_id, connection_id, request.clone())
1567        },
1568    );
1569    response.send(proto::Ack {})?;
1570    Ok(())
1571}
1572
1573async fn update_diagnostic_summary(
1574    message: proto::UpdateDiagnosticSummary,
1575    session: Session,
1576) -> Result<()> {
1577    let guest_connection_ids = session
1578        .db()
1579        .await
1580        .update_diagnostic_summary(&message, session.connection_id)
1581        .await?;
1582
1583    broadcast(
1584        Some(session.connection_id),
1585        guest_connection_ids.iter().copied(),
1586        |connection_id| {
1587            session
1588                .peer
1589                .forward_send(session.connection_id, connection_id, message.clone())
1590        },
1591    );
1592
1593    Ok(())
1594}
1595
1596async fn update_worktree_settings(
1597    message: proto::UpdateWorktreeSettings,
1598    session: Session,
1599) -> Result<()> {
1600    let guest_connection_ids = session
1601        .db()
1602        .await
1603        .update_worktree_settings(&message, session.connection_id)
1604        .await?;
1605
1606    broadcast(
1607        Some(session.connection_id),
1608        guest_connection_ids.iter().copied(),
1609        |connection_id| {
1610            session
1611                .peer
1612                .forward_send(session.connection_id, connection_id, message.clone())
1613        },
1614    );
1615
1616    Ok(())
1617}
1618
1619async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1620    broadcast_project_message(request.project_id, request, session).await
1621}
1622
1623async fn start_language_server(
1624    request: proto::StartLanguageServer,
1625    session: Session,
1626) -> Result<()> {
1627    let guest_connection_ids = session
1628        .db()
1629        .await
1630        .start_language_server(&request, session.connection_id)
1631        .await?;
1632
1633    broadcast(
1634        Some(session.connection_id),
1635        guest_connection_ids.iter().copied(),
1636        |connection_id| {
1637            session
1638                .peer
1639                .forward_send(session.connection_id, connection_id, request.clone())
1640        },
1641    );
1642    Ok(())
1643}
1644
1645async fn update_language_server(
1646    request: proto::UpdateLanguageServer,
1647    session: Session,
1648) -> Result<()> {
1649    session.executor.record_backtrace();
1650    let project_id = ProjectId::from_proto(request.project_id);
1651    let project_connection_ids = session
1652        .db()
1653        .await
1654        .project_connection_ids(project_id, session.connection_id)
1655        .await?;
1656    broadcast(
1657        Some(session.connection_id),
1658        project_connection_ids.iter().copied(),
1659        |connection_id| {
1660            session
1661                .peer
1662                .forward_send(session.connection_id, connection_id, request.clone())
1663        },
1664    );
1665    Ok(())
1666}
1667
1668async fn forward_project_request<T>(
1669    request: T,
1670    response: Response<T>,
1671    session: Session,
1672) -> Result<()>
1673where
1674    T: EntityMessage + RequestMessage,
1675{
1676    session.executor.record_backtrace();
1677    let project_id = ProjectId::from_proto(request.remote_entity_id());
1678    let host_connection_id = {
1679        let collaborators = session
1680            .db()
1681            .await
1682            .project_collaborators(project_id, session.connection_id)
1683            .await?;
1684        collaborators
1685            .iter()
1686            .find(|collaborator| collaborator.is_host)
1687            .ok_or_else(|| anyhow!("host not found"))?
1688            .connection_id
1689    };
1690
1691    let payload = session
1692        .peer
1693        .forward_request(session.connection_id, host_connection_id, request)
1694        .await?;
1695
1696    response.send(payload)?;
1697    Ok(())
1698}
1699
1700async fn create_buffer_for_peer(
1701    request: proto::CreateBufferForPeer,
1702    session: Session,
1703) -> Result<()> {
1704    session.executor.record_backtrace();
1705    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1706    session
1707        .peer
1708        .forward_send(session.connection_id, peer_id.into(), request)?;
1709    Ok(())
1710}
1711
1712async fn update_buffer(
1713    request: proto::UpdateBuffer,
1714    response: Response<proto::UpdateBuffer>,
1715    session: Session,
1716) -> Result<()> {
1717    session.executor.record_backtrace();
1718    let project_id = ProjectId::from_proto(request.project_id);
1719    let mut guest_connection_ids;
1720    let mut host_connection_id = None;
1721    {
1722        let collaborators = session
1723            .db()
1724            .await
1725            .project_collaborators(project_id, session.connection_id)
1726            .await?;
1727        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1728        for collaborator in collaborators.iter() {
1729            if collaborator.is_host {
1730                host_connection_id = Some(collaborator.connection_id);
1731            } else {
1732                guest_connection_ids.push(collaborator.connection_id);
1733            }
1734        }
1735    }
1736    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1737
1738    session.executor.record_backtrace();
1739    broadcast(
1740        Some(session.connection_id),
1741        guest_connection_ids,
1742        |connection_id| {
1743            session
1744                .peer
1745                .forward_send(session.connection_id, connection_id, request.clone())
1746        },
1747    );
1748    if host_connection_id != session.connection_id {
1749        session
1750            .peer
1751            .forward_request(session.connection_id, host_connection_id, request.clone())
1752            .await?;
1753    }
1754
1755    response.send(proto::Ack {})?;
1756    Ok(())
1757}
1758
1759async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1760    let project_id = ProjectId::from_proto(request.project_id);
1761    let project_connection_ids = session
1762        .db()
1763        .await
1764        .project_connection_ids(project_id, session.connection_id)
1765        .await?;
1766
1767    broadcast(
1768        Some(session.connection_id),
1769        project_connection_ids.iter().copied(),
1770        |connection_id| {
1771            session
1772                .peer
1773                .forward_send(session.connection_id, connection_id, request.clone())
1774        },
1775    );
1776    Ok(())
1777}
1778
1779async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1780    let project_id = ProjectId::from_proto(request.project_id);
1781    let project_connection_ids = session
1782        .db()
1783        .await
1784        .project_connection_ids(project_id, session.connection_id)
1785        .await?;
1786    broadcast(
1787        Some(session.connection_id),
1788        project_connection_ids.iter().copied(),
1789        |connection_id| {
1790            session
1791                .peer
1792                .forward_send(session.connection_id, connection_id, request.clone())
1793        },
1794    );
1795    Ok(())
1796}
1797
1798async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1799    broadcast_project_message(request.project_id, request, session).await
1800}
1801
1802async fn broadcast_project_message<T: EnvelopedMessage>(
1803    project_id: u64,
1804    request: T,
1805    session: Session,
1806) -> Result<()> {
1807    let project_id = ProjectId::from_proto(project_id);
1808    let project_connection_ids = session
1809        .db()
1810        .await
1811        .project_connection_ids(project_id, session.connection_id)
1812        .await?;
1813    broadcast(
1814        Some(session.connection_id),
1815        project_connection_ids.iter().copied(),
1816        |connection_id| {
1817            session
1818                .peer
1819                .forward_send(session.connection_id, connection_id, request.clone())
1820        },
1821    );
1822    Ok(())
1823}
1824
1825async fn follow(
1826    request: proto::Follow,
1827    response: Response<proto::Follow>,
1828    session: Session,
1829) -> Result<()> {
1830    let project_id = ProjectId::from_proto(request.project_id);
1831    let leader_id = request
1832        .leader_id
1833        .ok_or_else(|| anyhow!("invalid leader id"))?
1834        .into();
1835    let follower_id = session.connection_id;
1836
1837    {
1838        let project_connection_ids = session
1839            .db()
1840            .await
1841            .project_connection_ids(project_id, session.connection_id)
1842            .await?;
1843
1844        if !project_connection_ids.contains(&leader_id) {
1845            Err(anyhow!("no such peer"))?;
1846        }
1847    }
1848
1849    let mut response_payload = session
1850        .peer
1851        .forward_request(session.connection_id, leader_id, request)
1852        .await?;
1853    response_payload
1854        .views
1855        .retain(|view| view.leader_id != Some(follower_id.into()));
1856    response.send(response_payload)?;
1857
1858    let room = session
1859        .db()
1860        .await
1861        .follow(project_id, leader_id, follower_id)
1862        .await?;
1863    room_updated(&room, &session.peer);
1864
1865    Ok(())
1866}
1867
1868async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1869    let project_id = ProjectId::from_proto(request.project_id);
1870    let leader_id = request
1871        .leader_id
1872        .ok_or_else(|| anyhow!("invalid leader id"))?
1873        .into();
1874    let follower_id = session.connection_id;
1875
1876    if !session
1877        .db()
1878        .await
1879        .project_connection_ids(project_id, session.connection_id)
1880        .await?
1881        .contains(&leader_id)
1882    {
1883        Err(anyhow!("no such peer"))?;
1884    }
1885
1886    session
1887        .peer
1888        .forward_send(session.connection_id, leader_id, request)?;
1889
1890    let room = session
1891        .db()
1892        .await
1893        .unfollow(project_id, leader_id, follower_id)
1894        .await?;
1895    room_updated(&room, &session.peer);
1896
1897    Ok(())
1898}
1899
1900async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1901    let project_id = ProjectId::from_proto(request.project_id);
1902    let project_connection_ids = session
1903        .db
1904        .lock()
1905        .await
1906        .project_connection_ids(project_id, session.connection_id)
1907        .await?;
1908
1909    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1910        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1911        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1912        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1913    });
1914    for follower_peer_id in request.follower_ids.iter().copied() {
1915        let follower_connection_id = follower_peer_id.into();
1916        if project_connection_ids.contains(&follower_connection_id)
1917            && Some(follower_peer_id) != leader_id
1918        {
1919            session.peer.forward_send(
1920                session.connection_id,
1921                follower_connection_id,
1922                request.clone(),
1923            )?;
1924        }
1925    }
1926    Ok(())
1927}
1928
1929async fn get_users(
1930    request: proto::GetUsers,
1931    response: Response<proto::GetUsers>,
1932    session: Session,
1933) -> Result<()> {
1934    let user_ids = request
1935        .user_ids
1936        .into_iter()
1937        .map(UserId::from_proto)
1938        .collect();
1939    let users = session
1940        .db()
1941        .await
1942        .get_users_by_ids(user_ids)
1943        .await?
1944        .into_iter()
1945        .map(|user| proto::User {
1946            id: user.id.to_proto(),
1947            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1948            github_login: user.github_login,
1949        })
1950        .collect();
1951    response.send(proto::UsersResponse { users })?;
1952    Ok(())
1953}
1954
1955async fn fuzzy_search_users(
1956    request: proto::FuzzySearchUsers,
1957    response: Response<proto::FuzzySearchUsers>,
1958    session: Session,
1959) -> Result<()> {
1960    let query = request.query;
1961    let users = match query.len() {
1962        0 => vec![],
1963        1 | 2 => session
1964            .db()
1965            .await
1966            .get_user_by_github_login(&query)
1967            .await?
1968            .into_iter()
1969            .collect(),
1970        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1971    };
1972    let users = users
1973        .into_iter()
1974        .filter(|user| user.id != session.user_id)
1975        .map(|user| proto::User {
1976            id: user.id.to_proto(),
1977            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1978            github_login: user.github_login,
1979        })
1980        .collect();
1981    response.send(proto::UsersResponse { users })?;
1982    Ok(())
1983}
1984
1985async fn request_contact(
1986    request: proto::RequestContact,
1987    response: Response<proto::RequestContact>,
1988    session: Session,
1989) -> Result<()> {
1990    let requester_id = session.user_id;
1991    let responder_id = UserId::from_proto(request.responder_id);
1992    if requester_id == responder_id {
1993        return Err(anyhow!("cannot add yourself as a contact"))?;
1994    }
1995
1996    session
1997        .db()
1998        .await
1999        .send_contact_request(requester_id, responder_id)
2000        .await?;
2001
2002    // Update outgoing contact requests of requester
2003    let mut update = proto::UpdateContacts::default();
2004    update.outgoing_requests.push(responder_id.to_proto());
2005    for connection_id in session
2006        .connection_pool()
2007        .await
2008        .user_connection_ids(requester_id)
2009    {
2010        session.peer.send(connection_id, update.clone())?;
2011    }
2012
2013    // Update incoming contact requests of responder
2014    let mut update = proto::UpdateContacts::default();
2015    update
2016        .incoming_requests
2017        .push(proto::IncomingContactRequest {
2018            requester_id: requester_id.to_proto(),
2019            should_notify: true,
2020        });
2021    for connection_id in session
2022        .connection_pool()
2023        .await
2024        .user_connection_ids(responder_id)
2025    {
2026        session.peer.send(connection_id, update.clone())?;
2027    }
2028
2029    response.send(proto::Ack {})?;
2030    Ok(())
2031}
2032
2033async fn respond_to_contact_request(
2034    request: proto::RespondToContactRequest,
2035    response: Response<proto::RespondToContactRequest>,
2036    session: Session,
2037) -> Result<()> {
2038    let responder_id = session.user_id;
2039    let requester_id = UserId::from_proto(request.requester_id);
2040    let db = session.db().await;
2041    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2042        db.dismiss_contact_notification(responder_id, requester_id)
2043            .await?;
2044    } else {
2045        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2046
2047        db.respond_to_contact_request(responder_id, requester_id, accept)
2048            .await?;
2049        let requester_busy = db.is_user_busy(requester_id).await?;
2050        let responder_busy = db.is_user_busy(responder_id).await?;
2051
2052        let pool = session.connection_pool().await;
2053        // Update responder with new contact
2054        let mut update = proto::UpdateContacts::default();
2055        if accept {
2056            update
2057                .contacts
2058                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2059        }
2060        update
2061            .remove_incoming_requests
2062            .push(requester_id.to_proto());
2063        for connection_id in pool.user_connection_ids(responder_id) {
2064            session.peer.send(connection_id, update.clone())?;
2065        }
2066
2067        // Update requester with new contact
2068        let mut update = proto::UpdateContacts::default();
2069        if accept {
2070            update
2071                .contacts
2072                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2073        }
2074        update
2075            .remove_outgoing_requests
2076            .push(responder_id.to_proto());
2077        for connection_id in pool.user_connection_ids(requester_id) {
2078            session.peer.send(connection_id, update.clone())?;
2079        }
2080    }
2081
2082    response.send(proto::Ack {})?;
2083    Ok(())
2084}
2085
2086async fn remove_contact(
2087    request: proto::RemoveContact,
2088    response: Response<proto::RemoveContact>,
2089    session: Session,
2090) -> Result<()> {
2091    let requester_id = session.user_id;
2092    let responder_id = UserId::from_proto(request.user_id);
2093    let db = session.db().await;
2094    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2095
2096    let pool = session.connection_pool().await;
2097    // Update outgoing contact requests of requester
2098    let mut update = proto::UpdateContacts::default();
2099    if contact_accepted {
2100        update.remove_contacts.push(responder_id.to_proto());
2101    } else {
2102        update
2103            .remove_outgoing_requests
2104            .push(responder_id.to_proto());
2105    }
2106    for connection_id in pool.user_connection_ids(requester_id) {
2107        session.peer.send(connection_id, update.clone())?;
2108    }
2109
2110    // Update incoming contact requests of responder
2111    let mut update = proto::UpdateContacts::default();
2112    if contact_accepted {
2113        update.remove_contacts.push(requester_id.to_proto());
2114    } else {
2115        update
2116            .remove_incoming_requests
2117            .push(requester_id.to_proto());
2118    }
2119    for connection_id in pool.user_connection_ids(responder_id) {
2120        session.peer.send(connection_id, update.clone())?;
2121    }
2122
2123    response.send(proto::Ack {})?;
2124    Ok(())
2125}
2126
2127async fn create_channel(
2128    request: proto::CreateChannel,
2129    response: Response<proto::CreateChannel>,
2130    session: Session,
2131) -> Result<()> {
2132    let db = session.db().await;
2133    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2134
2135    if let Some(live_kit) = session.live_kit_client.as_ref() {
2136        live_kit.create_room(live_kit_room.clone()).await?;
2137    }
2138
2139    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2140    let id = db
2141        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2142        .await?;
2143
2144    response.send(proto::CreateChannelResponse {
2145        channel_id: id.to_proto(),
2146    })?;
2147
2148    let mut update = proto::UpdateChannels::default();
2149    update.channels.push(proto::Channel {
2150        id: id.to_proto(),
2151        name: request.name,
2152        parent_id: request.parent_id,
2153        user_is_admin: true,
2154    });
2155
2156    if let Some(parent_id) = parent_id {
2157        let member_ids = db.get_channel_members(parent_id).await?;
2158        let connection_pool = session.connection_pool().await;
2159        for member_id in member_ids {
2160            for connection_id in connection_pool.user_connection_ids(member_id) {
2161                session.peer.send(connection_id, update.clone())?;
2162            }
2163        }
2164    } else {
2165        session.peer.send(session.connection_id, update)?;
2166    }
2167
2168    Ok(())
2169}
2170
2171async fn remove_channel(
2172    request: proto::RemoveChannel,
2173    response: Response<proto::RemoveChannel>,
2174    session: Session,
2175) -> Result<()> {
2176    let db = session.db().await;
2177
2178    let channel_id = request.channel_id;
2179    let (removed_channels, member_ids) = db
2180        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2181        .await?;
2182    response.send(proto::Ack {})?;
2183
2184    // Notify members of removed channels
2185    let mut update = proto::UpdateChannels::default();
2186    update
2187        .remove_channels
2188        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2189
2190    let connection_pool = session.connection_pool().await;
2191    for member_id in member_ids {
2192        for connection_id in connection_pool.user_connection_ids(member_id) {
2193            session.peer.send(connection_id, update.clone())?;
2194        }
2195    }
2196
2197    Ok(())
2198}
2199
2200async fn invite_channel_member(
2201    request: proto::InviteChannelMember,
2202    response: Response<proto::InviteChannelMember>,
2203    session: Session,
2204) -> Result<()> {
2205    let db = session.db().await;
2206    let channel_id = ChannelId::from_proto(request.channel_id);
2207    let channel = db
2208        .get_channel(channel_id, session.user_id)
2209        .await?
2210        .ok_or_else(|| anyhow!("channel not found"))?;
2211    let invitee_id = UserId::from_proto(request.user_id);
2212    db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2213        .await?;
2214
2215    let mut update = proto::UpdateChannels::default();
2216    update.channel_invitations.push(proto::Channel {
2217        id: channel.id.to_proto(),
2218        name: channel.name,
2219        parent_id: None,
2220        user_is_admin: false,
2221    });
2222    for connection_id in session
2223        .connection_pool()
2224        .await
2225        .user_connection_ids(invitee_id)
2226    {
2227        session.peer.send(connection_id, update.clone())?;
2228    }
2229
2230    response.send(proto::Ack {})?;
2231    Ok(())
2232}
2233
2234async fn remove_channel_member(
2235    request: proto::RemoveChannelMember,
2236    response: Response<proto::RemoveChannelMember>,
2237    session: Session,
2238) -> Result<()> {
2239    let db = session.db().await;
2240    let channel_id = ChannelId::from_proto(request.channel_id);
2241    let member_id = UserId::from_proto(request.user_id);
2242    db.remove_channel_member(channel_id, member_id, session.user_id)
2243        .await?;
2244    response.send(proto::Ack {})?;
2245    Ok(())
2246}
2247
2248async fn get_channel_members(
2249    request: proto::GetChannelMembers,
2250    response: Response<proto::GetChannelMembers>,
2251    session: Session,
2252) -> Result<()> {
2253    let db = session.db().await;
2254    let channel_id = ChannelId::from_proto(request.channel_id);
2255    let members = db
2256        .get_channel_member_details(channel_id, session.user_id)
2257        .await?;
2258    response.send(proto::GetChannelMembersResponse { members })?;
2259    Ok(())
2260}
2261
2262async fn respond_to_channel_invite(
2263    request: proto::RespondToChannelInvite,
2264    response: Response<proto::RespondToChannelInvite>,
2265    session: Session,
2266) -> Result<()> {
2267    let db = session.db().await;
2268    let channel_id = ChannelId::from_proto(request.channel_id);
2269    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2270        .await?;
2271    let channel = db
2272        .get_channel(channel_id, session.user_id)
2273        .await?
2274        .ok_or_else(|| anyhow!("no such channel"))?;
2275
2276    let mut update = proto::UpdateChannels::default();
2277    update
2278        .remove_channel_invitations
2279        .push(channel_id.to_proto());
2280    if request.accept {
2281        update.channels.push(proto::Channel {
2282            id: channel.id.to_proto(),
2283            name: channel.name,
2284            user_is_admin: channel.user_is_admin,
2285            parent_id: None,
2286        });
2287    }
2288    session.peer.send(session.connection_id, update)?;
2289    response.send(proto::Ack {})?;
2290
2291    Ok(())
2292}
2293
/// Joins the session user to the room that belongs to `request.channel_id`.
///
/// Replies with the room snapshot and, when a LiveKit client is configured and
/// a token can be minted, the LiveKit connection info. Then it broadcasts the
/// updated room to its participants, pushes the channel's new participant list
/// to every connected channel member, and refreshes the user's contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Everything that needs the value returned by `join_room` happens inside
    // this block; only an owned clone of it escapes (see the end of the block).
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // A token error is turned into `None` by `trace_err()?`, so the join
        // still succeeds, just without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Return an owned copy; the value produced by `join_room` is dropped
        // when this block ends, which (per the TODO below) releases the room
        // lock. NOTE(review): do not replace with a move without checking
        // the lock implications.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2354
2355async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2356    let project_id = ProjectId::from_proto(request.project_id);
2357    let project_connection_ids = session
2358        .db()
2359        .await
2360        .project_connection_ids(project_id, session.connection_id)
2361        .await?;
2362    broadcast(
2363        Some(session.connection_id),
2364        project_connection_ids.iter().copied(),
2365        |connection_id| {
2366            session
2367                .peer
2368                .forward_send(session.connection_id, connection_id, request.clone())
2369        },
2370    );
2371    Ok(())
2372}
2373
2374async fn get_private_user_info(
2375    _request: proto::GetPrivateUserInfo,
2376    response: Response<proto::GetPrivateUserInfo>,
2377    session: Session,
2378) -> Result<()> {
2379    let metrics_id = session
2380        .db()
2381        .await
2382        .get_user_metrics_id(session.user_id)
2383        .await?;
2384    let user = session
2385        .db()
2386        .await
2387        .get_user_by_id(session.user_id)
2388        .await?
2389        .ok_or_else(|| anyhow!("user not found"))?;
2390    response.send(proto::GetPrivateUserInfoResponse {
2391        metrics_id,
2392        staff: user.admin,
2393    })?;
2394    Ok(())
2395}
2396
2397fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2398    match message {
2399        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2400        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2401        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2402        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2403        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2404            code: frame.code.into(),
2405            reason: frame.reason,
2406        })),
2407    }
2408}
2409
2410fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2411    match message {
2412        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2413        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2414        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2415        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2416        AxumMessage::Close(frame) => {
2417            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2418                code: frame.code.into(),
2419                reason: frame.reason,
2420            }))
2421        }
2422    }
2423}
2424
2425fn build_initial_channels_update(
2426    channels: Vec<db::Channel>,
2427    channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2428    channel_invites: Vec<db::Channel>,
2429) -> proto::UpdateChannels {
2430    let mut update = proto::UpdateChannels::default();
2431
2432    for channel in channels {
2433        update.channels.push(proto::Channel {
2434            id: channel.id.to_proto(),
2435            name: channel.name,
2436            user_is_admin: channel.user_is_admin,
2437            parent_id: channel.parent_id.map(|id| id.to_proto()),
2438        });
2439    }
2440
2441    for (channel_id, participants) in channel_participants {
2442        update
2443            .channel_participants
2444            .push(proto::ChannelParticipants {
2445                channel_id: channel_id.to_proto(),
2446                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2447            });
2448    }
2449
2450    for channel in channel_invites {
2451        update.channel_invitations.push(proto::Channel {
2452            id: channel.id.to_proto(),
2453            name: channel.name,
2454            user_is_admin: false,
2455            parent_id: None,
2456        });
2457    }
2458
2459    update
2460}
2461
2462fn build_initial_contacts_update(
2463    contacts: Vec<db::Contact>,
2464    pool: &ConnectionPool,
2465) -> proto::UpdateContacts {
2466    let mut update = proto::UpdateContacts::default();
2467
2468    for contact in contacts {
2469        match contact {
2470            db::Contact::Accepted {
2471                user_id,
2472                should_notify,
2473                busy,
2474            } => {
2475                update
2476                    .contacts
2477                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2478            }
2479            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2480            db::Contact::Incoming {
2481                user_id,
2482                should_notify,
2483            } => update
2484                .incoming_requests
2485                .push(proto::IncomingContactRequest {
2486                    requester_id: user_id.to_proto(),
2487                    should_notify,
2488                }),
2489        }
2490    }
2491
2492    update
2493}
2494
2495fn contact_for_user(
2496    user_id: UserId,
2497    should_notify: bool,
2498    busy: bool,
2499    pool: &ConnectionPool,
2500) -> proto::Contact {
2501    proto::Contact {
2502        user_id: user_id.to_proto(),
2503        online: pool.is_user_online(user_id),
2504        busy,
2505        should_notify,
2506    }
2507}
2508
2509fn room_updated(room: &proto::Room, peer: &Peer) {
2510    broadcast(
2511        None,
2512        room.participants
2513            .iter()
2514            .filter_map(|participant| Some(participant.peer_id?.into())),
2515        |peer_id| {
2516            peer.send(
2517                peer_id.into(),
2518                proto::RoomUpdated {
2519                    room: Some(room.clone()),
2520                },
2521            )
2522        },
2523    );
2524}
2525
2526fn channel_updated(
2527    channel_id: ChannelId,
2528    room: &proto::Room,
2529    channel_members: &[UserId],
2530    peer: &Peer,
2531    pool: &ConnectionPool,
2532) {
2533    let participants = room
2534        .participants
2535        .iter()
2536        .map(|p| p.user_id)
2537        .collect::<Vec<_>>();
2538
2539    broadcast(
2540        None,
2541        channel_members
2542            .iter()
2543            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2544        |peer_id| {
2545            peer.send(
2546                peer_id.into(),
2547                proto::UpdateChannels {
2548                    channel_participants: vec![proto::ChannelParticipants {
2549                        channel_id: channel_id.to_proto(),
2550                        participant_user_ids: participants.clone(),
2551                    }],
2552                    ..Default::default()
2553                },
2554            )
2555        },
2556    );
2557}
2558
2559async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2560    let db = session.db().await;
2561
2562    let contacts = db.get_contacts(user_id).await?;
2563    let busy = db.is_user_busy(user_id).await?;
2564
2565    let pool = session.connection_pool().await;
2566    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2567    for contact in contacts {
2568        if let db::Contact::Accepted {
2569            user_id: contact_user_id,
2570            ..
2571        } = contact
2572        {
2573            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2574                session
2575                    .peer
2576                    .send(
2577                        contact_conn_id,
2578                        proto::UpdateContacts {
2579                            contacts: vec![updated_contact.clone()],
2580                            remove_contacts: Default::default(),
2581                            incoming_requests: Default::default(),
2582                            remove_incoming_requests: Default::default(),
2583                            outgoing_requests: Default::default(),
2584                            remove_outgoing_requests: Default::default(),
2585                        },
2586                    )
2587                    .trace_err();
2588            }
2589        }
2590    }
2591    Ok(())
2592}
2593
2594async fn leave_room_for_session(session: &Session) -> Result<()> {
2595    let mut contacts_to_update = HashSet::default();
2596
2597    let room_id;
2598    let canceled_calls_to_user_ids;
2599    let live_kit_room;
2600    let delete_live_kit_room;
2601    let room;
2602    let channel_members;
2603    let channel_id;
2604
2605    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2606        contacts_to_update.insert(session.user_id);
2607
2608        for project in left_room.left_projects.values() {
2609            project_left(project, session);
2610        }
2611
2612        room_id = RoomId::from_proto(left_room.room.id);
2613        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2614        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2615        delete_live_kit_room = left_room.deleted;
2616        room = mem::take(&mut left_room.room);
2617        channel_members = mem::take(&mut left_room.channel_members);
2618        channel_id = left_room.channel_id;
2619
2620        room_updated(&room, &session.peer);
2621    } else {
2622        return Ok(());
2623    }
2624
2625    // TODO - do this while holding the room guard.
2626    if let Some(channel_id) = channel_id {
2627        channel_updated(
2628            channel_id,
2629            &room,
2630            &channel_members,
2631            &session.peer,
2632            &*session.connection_pool().await,
2633        );
2634    }
2635
2636    {
2637        let pool = session.connection_pool().await;
2638        for canceled_user_id in canceled_calls_to_user_ids {
2639            for connection_id in pool.user_connection_ids(canceled_user_id) {
2640                session
2641                    .peer
2642                    .send(
2643                        connection_id,
2644                        proto::CallCanceled {
2645                            room_id: room_id.to_proto(),
2646                        },
2647                    )
2648                    .trace_err();
2649            }
2650            contacts_to_update.insert(canceled_user_id);
2651        }
2652    }
2653
2654    for contact_user_id in contacts_to_update {
2655        update_user_contacts(contact_user_id, &session).await?;
2656    }
2657
2658    if let Some(live_kit) = session.live_kit_client.as_ref() {
2659        live_kit
2660            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2661            .await
2662            .trace_err();
2663
2664        if delete_live_kit_room {
2665            live_kit.delete_room(live_kit_room).await.trace_err();
2666        }
2667    }
2668
2669    Ok(())
2670}
2671
2672fn project_left(project: &db::LeftProject, session: &Session) {
2673    for connection_id in &project.connection_ids {
2674        if project.host_user_id == session.user_id {
2675            session
2676                .peer
2677                .send(
2678                    *connection_id,
2679                    proto::UnshareProject {
2680                        project_id: project.id.to_proto(),
2681                    },
2682                )
2683                .trace_err();
2684        } else {
2685            session
2686                .peer
2687                .send(
2688                    *connection_id,
2689                    proto::RemoveProjectCollaborator {
2690                        project_id: project.id.to_proto(),
2691                        peer_id: Some(session.connection_id.into()),
2692                    },
2693                )
2694                .trace_err();
2695        }
2696    }
2697}
2698
2699pub trait ResultExt {
2700    type Ok;
2701
2702    fn trace_err(self) -> Option<Self::Ok>;
2703}
2704
2705impl<T, E> ResultExt for Result<T, E>
2706where
2707    E: std::fmt::Debug,
2708{
2709    type Ok = T;
2710
2711    fn trace_err(self) -> Option<T> {
2712        match self {
2713            Ok(value) => Some(value),
2714            Err(error) => {
2715                tracing::error!("{:?}", error);
2716                None
2717            }
2718        }
2719    }
2720}