// rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
 144pub struct Server {
 145    id: parking_lot::Mutex<ServerId>,
 146    peer: Arc<Peer>,
 147    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 148    app_state: Arc<AppState>,
 149    executor: Executor,
 150    handlers: HashMap<TypeId, MessageHandler>,
 151    teardown: watch::Sender<()>,
 152}
 153
 154pub(crate) struct ConnectionPoolGuard<'a> {
 155    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 156    _not_send: PhantomData<Rc<()>>,
 157}
 158
 159#[derive(Serialize)]
 160pub struct ServerSnapshot<'a> {
 161    peer: &'a Peer,
 162    #[serde(serialize_with = "serialize_deref")]
 163    connection_pool: ConnectionPoolGuard<'a>,
 164}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(set_channel_member_admin)
 250            .add_request_handler(rename_channel)
 251            .add_request_handler(get_channel_members)
 252            .add_request_handler(respond_to_channel_invite)
 253            .add_request_handler(join_channel)
 254            .add_request_handler(follow)
 255            .add_message_handler(unfollow)
 256            .add_message_handler(update_followers)
 257            .add_message_handler(update_diff_base)
 258            .add_request_handler(get_private_user_info);
 259
 260        Arc::new(server)
 261    }
 262
 263    pub async fn start(&self) -> Result<()> {
 264        let server_id = *self.id.lock();
 265        let app_state = self.app_state.clone();
 266        let peer = self.peer.clone();
 267        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 268        let pool = self.connection_pool.clone();
 269        let live_kit_client = self.app_state.live_kit_client.clone();
 270
 271        let span = info_span!("start server");
 272        self.executor.spawn_detached(
 273            async move {
 274                tracing::info!("waiting for cleanup timeout");
 275                timeout.await;
 276                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 277                if let Some(room_ids) = app_state
 278                    .db
 279                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 280                    .await
 281                    .trace_err()
 282                {
 283                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 284                    for room_id in room_ids {
 285                        let mut contacts_to_update = HashSet::default();
 286                        let mut canceled_calls_to_user_ids = Vec::new();
 287                        let mut live_kit_room = String::new();
 288                        let mut delete_live_kit_room = false;
 289
 290                        if let Some(mut refreshed_room) = app_state
 291                            .db
 292                            .refresh_room(room_id, server_id)
 293                            .await
 294                            .trace_err()
 295                        {
 296                            tracing::info!(
 297                                room_id = room_id.0,
 298                                new_participant_count = refreshed_room.room.participants.len(),
 299                                "refreshed room"
 300                            );
 301                            room_updated(&refreshed_room.room, &peer);
 302                            if let Some(channel_id) = refreshed_room.channel_id {
 303                                channel_updated(
 304                                    channel_id,
 305                                    &refreshed_room.room,
 306                                    &refreshed_room.channel_members,
 307                                    &peer,
 308                                    &*pool.lock(),
 309                                );
 310                            }
 311                            contacts_to_update
 312                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 313                            contacts_to_update
 314                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 315                            canceled_calls_to_user_ids =
 316                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 317                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 318                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 319                        }
 320
 321                        {
 322                            let pool = pool.lock();
 323                            for canceled_user_id in canceled_calls_to_user_ids {
 324                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 325                                    peer.send(
 326                                        connection_id,
 327                                        proto::CallCanceled {
 328                                            room_id: room_id.to_proto(),
 329                                        },
 330                                    )
 331                                    .trace_err();
 332                                }
 333                            }
 334                        }
 335
 336                        for user_id in contacts_to_update {
 337                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 338                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 339                            if let Some((busy, contacts)) = busy.zip(contacts) {
 340                                let pool = pool.lock();
 341                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 342                                for contact in contacts {
 343                                    if let db::Contact::Accepted {
 344                                        user_id: contact_user_id,
 345                                        ..
 346                                    } = contact
 347                                    {
 348                                        for contact_conn_id in
 349                                            pool.user_connection_ids(contact_user_id)
 350                                        {
 351                                            peer.send(
 352                                                contact_conn_id,
 353                                                proto::UpdateContacts {
 354                                                    contacts: vec![updated_contact.clone()],
 355                                                    remove_contacts: Default::default(),
 356                                                    incoming_requests: Default::default(),
 357                                                    remove_incoming_requests: Default::default(),
 358                                                    outgoing_requests: Default::default(),
 359                                                    remove_outgoing_requests: Default::default(),
 360                                                },
 361                                            )
 362                                            .trace_err();
 363                                        }
 364                                    }
 365                                }
 366                            }
 367                        }
 368
 369                        if let Some(live_kit) = live_kit_client.as_ref() {
 370                            if delete_live_kit_room {
 371                                live_kit.delete_room(live_kit_room).await.trace_err();
 372                            }
 373                        }
 374                    }
 375                }
 376
 377                app_state
 378                    .db
 379                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 380                    .await
 381                    .trace_err();
 382            }
 383            .instrument(span),
 384        );
 385        Ok(())
 386    }
 387
 388    pub fn teardown(&self) {
 389        self.peer.teardown();
 390        self.connection_pool.lock().reset();
 391        let _ = self.teardown.send(());
 392    }
 393
 394    #[cfg(test)]
 395    pub fn reset(&self, id: ServerId) {
 396        self.teardown();
 397        *self.id.lock() = id;
 398        self.peer.reset(id.0 as u32);
 399    }
 400
 401    #[cfg(test)]
 402    pub fn id(&self) -> ServerId {
 403        *self.id.lock()
 404    }
 405
 406    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 407    where
 408        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 409        Fut: 'static + Send + Future<Output = Result<()>>,
 410        M: EnvelopedMessage,
 411    {
 412        let prev_handler = self.handlers.insert(
 413            TypeId::of::<M>(),
 414            Box::new(move |envelope, session| {
 415                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 416                let span = info_span!(
 417                    "handle message",
 418                    payload_type = envelope.payload_type_name()
 419                );
 420                span.in_scope(|| {
 421                    tracing::info!(
 422                        payload_type = envelope.payload_type_name(),
 423                        "message received"
 424                    );
 425                });
 426                let start_time = Instant::now();
 427                let future = (handler)(*envelope, session);
 428                async move {
 429                    let result = future.await;
 430                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 431                    match result {
 432                        Err(error) => {
 433                            tracing::error!(%error, ?duration_ms, "error handling message")
 434                        }
 435                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 436                    }
 437                }
 438                .instrument(span)
 439                .boxed()
 440            }),
 441        );
 442        if prev_handler.is_some() {
 443            panic!("registered a handler for the same message twice");
 444        }
 445        self
 446    }
 447
 448    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 449    where
 450        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 451        Fut: 'static + Send + Future<Output = Result<()>>,
 452        M: EnvelopedMessage,
 453    {
 454        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 455        self
 456    }
 457
 458    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 459    where
 460        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 461        Fut: Send + Future<Output = Result<()>>,
 462        M: RequestMessage,
 463    {
 464        let handler = Arc::new(handler);
 465        self.add_handler(move |envelope, session| {
 466            let receipt = envelope.receipt();
 467            let handler = handler.clone();
 468            async move {
 469                let peer = session.peer.clone();
 470                let responded = Arc::new(AtomicBool::default());
 471                let response = Response {
 472                    peer: peer.clone(),
 473                    responded: responded.clone(),
 474                    receipt,
 475                };
 476                match (handler)(envelope.payload, response, session).await {
 477                    Ok(()) => {
 478                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 479                            Ok(())
 480                        } else {
 481                            Err(anyhow!("handler did not send a response"))?
 482                        }
 483                    }
 484                    Err(error) => {
 485                        peer.respond_with_error(
 486                            receipt,
 487                            proto::Error {
 488                                message: error.to_string(),
 489                            },
 490                        )?;
 491                        Err(error)
 492                    }
 493                }
 494            }
 495        })
 496    }
 497
 498    pub fn handle_connection(
 499        self: &Arc<Self>,
 500        connection: Connection,
 501        address: String,
 502        user: User,
 503        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 504        executor: Executor,
 505    ) -> impl Future<Output = Result<()>> {
 506        let this = self.clone();
 507        let user_id = user.id;
 508        let login = user.github_login;
 509        let span = info_span!("handle connection", %user_id, %login, %address);
 510        let mut teardown = self.teardown.subscribe();
 511        async move {
 512            let (connection_id, handle_io, mut incoming_rx) = this
 513                .peer
 514                .add_connection(connection, {
 515                    let executor = executor.clone();
 516                    move |duration| executor.sleep(duration)
 517                });
 518
 519            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 520            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 521            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 522
 523            if let Some(send_connection_id) = send_connection_id.take() {
 524                let _ = send_connection_id.send(connection_id);
 525            }
 526
 527            if !user.connected_once {
 528                this.peer.send(connection_id, proto::ShowContacts {})?;
 529                this.app_state.db.set_user_connected_once(user_id, true).await?;
 530            }
 531
 532            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
 533                this.app_state.db.get_contacts(user_id),
 534                this.app_state.db.get_invite_code_for_user(user_id),
 535                this.app_state.db.get_channels_for_user(user_id),
 536                this.app_state.db.get_channel_invites_for_user(user_id)
 537            ).await?;
 538
 539            {
 540                let mut pool = this.connection_pool.lock();
 541                pool.add_connection(connection_id, user_id, user.admin);
 542                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 543                this.peer.send(connection_id, build_initial_channels_update(
 544                    channels_for_user.channels,
 545                    channels_for_user.channel_participants,
 546                    channel_invites
 547                ))?;
 548
 549                if let Some((code, count)) = invite_code {
 550                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 551                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 552                        count: count as u32,
 553                    })?;
 554                }
 555            }
 556
 557            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 558                this.peer.send(connection_id, incoming_call)?;
 559            }
 560
 561            let session = Session {
 562                user_id,
 563                connection_id,
 564                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 565                peer: this.peer.clone(),
 566                connection_pool: this.connection_pool.clone(),
 567                live_kit_client: this.app_state.live_kit_client.clone(),
 568                executor: executor.clone(),
 569            };
 570            update_user_contacts(user_id, &session).await?;
 571
 572            let handle_io = handle_io.fuse();
 573            futures::pin_mut!(handle_io);
 574
 575            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 576            // This prevents deadlocks when e.g., client A performs a request to client B and
 577            // client B performs a request to client A. If both clients stop processing further
 578            // messages until their respective request completes, they won't have a chance to
 579            // respond to the other client's request and cause a deadlock.
 580            //
 581            // This arrangement ensures we will attempt to process earlier messages first, but fall
 582            // back to processing messages arrived later in the spirit of making progress.
 583            let mut foreground_message_handlers = FuturesUnordered::new();
 584            let concurrent_handlers = Arc::new(Semaphore::new(256));
 585            loop {
 586                let next_message = async {
 587                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 588                    let message = incoming_rx.next().await;
 589                    (permit, message)
 590                }.fuse();
 591                futures::pin_mut!(next_message);
 592                futures::select_biased! {
 593                    _ = teardown.changed().fuse() => return Ok(()),
 594                    result = handle_io => {
 595                        if let Err(error) = result {
 596                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 597                        }
 598                        break;
 599                    }
 600                    _ = foreground_message_handlers.next() => {}
 601                    next_message = next_message => {
 602                        let (permit, message) = next_message;
 603                        if let Some(message) = message {
 604                            let type_name = message.payload_type_name();
 605                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 606                            let span_enter = span.enter();
 607                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 608                                let is_background = message.is_background();
 609                                let handle_message = (handler)(message, session.clone());
 610                                drop(span_enter);
 611
 612                                let handle_message = async move {
 613                                    handle_message.await;
 614                                    drop(permit);
 615                                }.instrument(span);
 616                                if is_background {
 617                                    executor.spawn_detached(handle_message);
 618                                } else {
 619                                    foreground_message_handlers.push(handle_message);
 620                                }
 621                            } else {
 622                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 623                            }
 624                        } else {
 625                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 626                            break;
 627                        }
 628                    }
 629                }
 630            }
 631
 632            drop(foreground_message_handlers);
 633            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 634            if let Err(error) = connection_lost(session, teardown, executor).await {
 635                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 636            }
 637
 638            Ok(())
 639        }.instrument(span)
 640    }
 641
 642    pub async fn invite_code_redeemed(
 643        self: &Arc<Self>,
 644        inviter_id: UserId,
 645        invitee_id: UserId,
 646    ) -> Result<()> {
 647        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 648            if let Some(code) = &user.invite_code {
 649                let pool = self.connection_pool.lock();
 650                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 651                for connection_id in pool.user_connection_ids(inviter_id) {
 652                    self.peer.send(
 653                        connection_id,
 654                        proto::UpdateContacts {
 655                            contacts: vec![invitee_contact.clone()],
 656                            ..Default::default()
 657                        },
 658                    )?;
 659                    self.peer.send(
 660                        connection_id,
 661                        proto::UpdateInviteInfo {
 662                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 663                            count: user.invite_count as u32,
 664                        },
 665                    )?;
 666                }
 667            }
 668        }
 669        Ok(())
 670    }
 671
 672    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 673        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 674            if let Some(invite_code) = &user.invite_code {
 675                let pool = self.connection_pool.lock();
 676                for connection_id in pool.user_connection_ids(user_id) {
 677                    self.peer.send(
 678                        connection_id,
 679                        proto::UpdateInviteInfo {
 680                            url: format!(
 681                                "{}{}",
 682                                self.app_state.config.invite_link_prefix, invite_code
 683                            ),
 684                            count: user.invite_count as u32,
 685                        },
 686                    )?;
 687                }
 688            }
 689        }
 690        Ok(())
 691    }
 692
 693    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 694        ServerSnapshot {
 695            connection_pool: ConnectionPoolGuard {
 696                guard: self.connection_pool.lock(),
 697                _not_send: PhantomData,
 698            },
 699            peer: &self.peer,
 700        }
 701    }
 702}
 703
 704impl<'a> Deref for ConnectionPoolGuard<'a> {
 705    type Target = ConnectionPool;
 706
 707    fn deref(&self) -> &Self::Target {
 708        &*self.guard
 709    }
 710}
 711
 712impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 713    fn deref_mut(&mut self) -> &mut Self::Target {
 714        &mut *self.guard
 715    }
 716}
 717
 718impl<'a> Drop for ConnectionPoolGuard<'a> {
 719    fn drop(&mut self) {
 720        #[cfg(test)]
 721        self.check_invariants();
 722    }
 723}
 724
 725fn broadcast<F>(
 726    sender_id: Option<ConnectionId>,
 727    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 728    mut f: F,
 729) where
 730    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 731{
 732    for receiver_id in receiver_ids {
 733        if Some(receiver_id) != sender_id {
 734            if let Err(error) = f(receiver_id) {
 735                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 736            }
 737        }
 738    }
 739}
 740
lazy_static! {
    // Name of the request header in which clients report their RPC protocol
    // version. `ProtocolVersion` decodes it, and `handle_websocket_request`
    // compares the result with `rpc::PROTOCOL_VERSION`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 744
 745pub struct ProtocolVersion(u32);
 746
 747impl Header for ProtocolVersion {
 748    fn name() -> &'static HeaderName {
 749        &ZED_PROTOCOL_VERSION
 750    }
 751
 752    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 753    where
 754        Self: Sized,
 755        I: Iterator<Item = &'i axum::http::HeaderValue>,
 756    {
 757        let version = values
 758            .next()
 759            .ok_or_else(axum::headers::Error::invalid)?
 760            .to_str()
 761            .map_err(|_| axum::headers::Error::invalid())?
 762            .parse()
 763            .map_err(|_| axum::headers::Error::invalid())?;
 764        Ok(Self(version))
 765    }
 766
 767    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 768        values.extend([self.0.to_string().parse().unwrap()]);
 769    }
 770}
 771
 772pub fn routes(server: Arc<Server>) -> Router<Body> {
 773    Router::new()
 774        .route("/rpc", get(handle_websocket_request))
 775        .layer(
 776            ServiceBuilder::new()
 777                .layer(Extension(server.app_state.clone()))
 778                .layer(middleware::from_fn(auth::validate_header)),
 779        )
 780        .route("/metrics", get(handle_metrics))
 781        .layer(Extension(server))
 782}
 783
 784pub async fn handle_websocket_request(
 785    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 786    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 787    Extension(server): Extension<Arc<Server>>,
 788    Extension(user): Extension<User>,
 789    ws: WebSocketUpgrade,
 790) -> axum::response::Response {
 791    if protocol_version != rpc::PROTOCOL_VERSION {
 792        return (
 793            StatusCode::UPGRADE_REQUIRED,
 794            "client must be upgraded".to_string(),
 795        )
 796            .into_response();
 797    }
 798    let socket_address = socket_address.to_string();
 799    ws.on_upgrade(move |socket| {
 800        use util::ResultExt;
 801        let socket = socket
 802            .map_ok(to_tungstenite_message)
 803            .err_into()
 804            .with(|message| async move { Ok(to_axum_message(message)) });
 805        let connection = Connection::new(Box::pin(socket));
 806        async move {
 807            server
 808                .handle_connection(connection, socket_address, user, None, Executor::Production)
 809                .await
 810                .log_err();
 811        }
 812    })
 813}
 814
 815pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 816    let connections = server
 817        .connection_pool
 818        .lock()
 819        .connections()
 820        .filter(|connection| !connection.admin)
 821        .count();
 822
 823    METRIC_CONNECTIONS.set(connections as _);
 824
 825    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 826    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 827
 828    let encoder = prometheus::TextEncoder::new();
 829    let metric_families = prometheus::gather();
 830    let encoded_metrics = encoder
 831        .encode_to_string(&metric_families)
 832        .map_err(|err| anyhow!("{}", err))?;
 833    Ok(encoded_metrics)
 834}
 835
/// Cleans up after a connection has dropped.
///
/// It first disconnects the peer and removes the connection from the pool.
/// Removal failure is returned as an error. It then records the loss in the
/// database; a failure there is only logged.
///
/// Next it waits for whichever of two events happens first:
/// * `RECONNECT_TIMEOUT` elapses. The session leaves its room (a failure is
///   logged). If the pool no longer reports the user online, it declines any
///   call to them and broadcasts the resulting room. It then updates the
///   user's contacts.
/// * `teardown` changes. The function returns with no further cleanup.
///
/// `select_biased!` polls the branches in the order written, so the timeout
/// branch wins when both are ready at once.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Decline pending calls only when no other connection of this user
            // remains, as judged by `is_user_online`.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 876
 877async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 878    response.send(proto::Ack {})?;
 879    Ok(())
 880}
 881
 882async fn create_room(
 883    _request: proto::CreateRoom,
 884    response: Response<proto::CreateRoom>,
 885    session: Session,
 886) -> Result<()> {
 887    let live_kit_room = nanoid::nanoid!(30);
 888
 889    let live_kit_connection_info = {
 890        let live_kit_room = live_kit_room.clone();
 891        let live_kit = session.live_kit_client.as_ref();
 892
 893        util::async_iife!({
 894            let live_kit = live_kit?;
 895
 896            live_kit
 897                .create_room(live_kit_room.clone())
 898                .await
 899                .trace_err()?;
 900
 901            let token = live_kit
 902                .room_token(&live_kit_room, &session.user_id.to_string())
 903                .trace_err()?;
 904
 905            Some(proto::LiveKitConnectionInfo {
 906                server_url: live_kit.url().into(),
 907                token,
 908            })
 909        })
 910    }
 911    .await;
 912
 913    let room = session
 914        .db()
 915        .await
 916        .create_room(session.user_id, session.connection_id, &live_kit_room)
 917        .await?;
 918
 919    response.send(proto::CreateRoomResponse {
 920        room: Some(room.clone()),
 921        live_kit_connection_info,
 922    })?;
 923
 924    update_user_contacts(session.user_id, &session).await?;
 925    Ok(())
 926}
 927
 928async fn join_room(
 929    request: proto::JoinRoom,
 930    response: Response<proto::JoinRoom>,
 931    session: Session,
 932) -> Result<()> {
 933    let room_id = RoomId::from_proto(request.id);
 934    let room = {
 935        let room = session
 936            .db()
 937            .await
 938            .join_room(room_id, session.user_id, None, session.connection_id)
 939            .await?;
 940        room_updated(&room.room, &session.peer);
 941        room.room.clone()
 942    };
 943
 944    for connection_id in session
 945        .connection_pool()
 946        .await
 947        .user_connection_ids(session.user_id)
 948    {
 949        session
 950            .peer
 951            .send(
 952                connection_id,
 953                proto::CallCanceled {
 954                    room_id: room_id.to_proto(),
 955                },
 956            )
 957            .trace_err();
 958    }
 959
 960    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 961        if let Some(token) = live_kit
 962            .room_token(&room.live_kit_room, &session.user_id.to_string())
 963            .trace_err()
 964        {
 965            Some(proto::LiveKitConnectionInfo {
 966                server_url: live_kit.url().into(),
 967                token,
 968            })
 969        } else {
 970            None
 971        }
 972    } else {
 973        None
 974    };
 975
 976    response.send(proto::JoinRoomResponse {
 977        room: Some(room),
 978        live_kit_connection_info,
 979    })?;
 980
 981    update_user_contacts(session.user_id, &session).await?;
 982    Ok(())
 983}
 984
/// Handles a `RejoinRoom` request for the session's user and connection.
///
/// Inside the scoped block:
/// 1. The database rejoin runs, and the reply (room, reshared projects,
///    rejoined projects) is sent before anything else.
/// 2. The room update is broadcast.
/// 3. For each reshared project:
///    * each collaborator gets an `UpdateProjectCollaborator` mapping
///      `old_connection_id` to this connection;
///    * every collaborator except this connection is forwarded the project's
///      worktrees.
/// 4. For each rejoined project:
///    * collaborators get the same `UpdateProjectCollaborator` message;
///    * this connection is sent each worktree's entries (chunked), its
///      diagnostic summaries and settings files;
///    * this connection then gets one `UpdateLanguageServer` per language
///      server.
///
/// After the block, channel members are notified if the room belongs to a
/// channel, and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values moved out of the rejoin result so they outlive the block below.
    let room;
    let channel_id;
    let channel_members;
    {
        // NOTE(review): presumably a room guard (see the TODO after this
        // block); it is dropped at the end of this scope.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators that this project's peer id moved from the old
            // connection to the current one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Send each rejoined project's state to this connection. Worktrees are
        // moved out (`mem::take`) so their contents can be consumed by value.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in test builds so the chunking path gets exercised.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        room = mem::take(&mut rejoined_room.room);
        channel_id = rejoined_room.channel_id;
        channel_members = mem::take(&mut rejoined_room.channel_members);
    }

    //TODO: move this into the room guard
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1174
1175async fn leave_room(
1176    _: proto::LeaveRoom,
1177    response: Response<proto::LeaveRoom>,
1178    session: Session,
1179) -> Result<()> {
1180    leave_room_for_session(&session).await?;
1181    response.send(proto::Ack {})?;
1182    Ok(())
1183}
1184
1185async fn call(
1186    request: proto::Call,
1187    response: Response<proto::Call>,
1188    session: Session,
1189) -> Result<()> {
1190    let room_id = RoomId::from_proto(request.room_id);
1191    let calling_user_id = session.user_id;
1192    let calling_connection_id = session.connection_id;
1193    let called_user_id = UserId::from_proto(request.called_user_id);
1194    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1195    if !session
1196        .db()
1197        .await
1198        .has_contact(calling_user_id, called_user_id)
1199        .await?
1200    {
1201        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1202    }
1203
1204    let incoming_call = {
1205        let (room, incoming_call) = &mut *session
1206            .db()
1207            .await
1208            .call(
1209                room_id,
1210                calling_user_id,
1211                calling_connection_id,
1212                called_user_id,
1213                initial_project_id,
1214            )
1215            .await?;
1216        room_updated(&room, &session.peer);
1217        mem::take(incoming_call)
1218    };
1219    update_user_contacts(called_user_id, &session).await?;
1220
1221    let mut calls = session
1222        .connection_pool()
1223        .await
1224        .user_connection_ids(called_user_id)
1225        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1226        .collect::<FuturesUnordered<_>>();
1227
1228    while let Some(call_response) = calls.next().await {
1229        match call_response.as_ref() {
1230            Ok(_) => {
1231                response.send(proto::Ack {})?;
1232                return Ok(());
1233            }
1234            Err(_) => {
1235                call_response.trace_err();
1236            }
1237        }
1238    }
1239
1240    {
1241        let room = session
1242            .db()
1243            .await
1244            .call_failed(room_id, called_user_id)
1245            .await?;
1246        room_updated(&room, &session.peer);
1247    }
1248    update_user_contacts(called_user_id, &session).await?;
1249
1250    Err(anyhow!("failed to ring user"))?
1251}
1252
1253async fn cancel_call(
1254    request: proto::CancelCall,
1255    response: Response<proto::CancelCall>,
1256    session: Session,
1257) -> Result<()> {
1258    let called_user_id = UserId::from_proto(request.called_user_id);
1259    let room_id = RoomId::from_proto(request.room_id);
1260    {
1261        let room = session
1262            .db()
1263            .await
1264            .cancel_call(room_id, session.connection_id, called_user_id)
1265            .await?;
1266        room_updated(&room, &session.peer);
1267    }
1268
1269    for connection_id in session
1270        .connection_pool()
1271        .await
1272        .user_connection_ids(called_user_id)
1273    {
1274        session
1275            .peer
1276            .send(
1277                connection_id,
1278                proto::CallCanceled {
1279                    room_id: room_id.to_proto(),
1280                },
1281            )
1282            .trace_err();
1283    }
1284    response.send(proto::Ack {})?;
1285
1286    update_user_contacts(called_user_id, &session).await?;
1287    Ok(())
1288}
1289
1290async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1291    let room_id = RoomId::from_proto(message.room_id);
1292    {
1293        let room = session
1294            .db()
1295            .await
1296            .decline_call(Some(room_id), session.user_id)
1297            .await?
1298            .ok_or_else(|| anyhow!("failed to decline call"))?;
1299        room_updated(&room, &session.peer);
1300    }
1301
1302    for connection_id in session
1303        .connection_pool()
1304        .await
1305        .user_connection_ids(session.user_id)
1306    {
1307        session
1308            .peer
1309            .send(
1310                connection_id,
1311                proto::CallCanceled {
1312                    room_id: room_id.to_proto(),
1313                },
1314            )
1315            .trace_err();
1316    }
1317    update_user_contacts(session.user_id, &session).await?;
1318    Ok(())
1319}
1320
1321async fn update_participant_location(
1322    request: proto::UpdateParticipantLocation,
1323    response: Response<proto::UpdateParticipantLocation>,
1324    session: Session,
1325) -> Result<()> {
1326    let room_id = RoomId::from_proto(request.room_id);
1327    let location = request
1328        .location
1329        .ok_or_else(|| anyhow!("invalid location"))?;
1330
1331    let db = session.db().await;
1332    let room = db
1333        .update_room_participant_location(room_id, session.connection_id, location)
1334        .await?;
1335
1336    room_updated(&room, &session.peer);
1337    response.send(proto::Ack {})?;
1338    Ok(())
1339}
1340
1341async fn share_project(
1342    request: proto::ShareProject,
1343    response: Response<proto::ShareProject>,
1344    session: Session,
1345) -> Result<()> {
1346    let (project_id, room) = &*session
1347        .db()
1348        .await
1349        .share_project(
1350            RoomId::from_proto(request.room_id),
1351            session.connection_id,
1352            &request.worktrees,
1353        )
1354        .await?;
1355    response.send(proto::ShareProjectResponse {
1356        project_id: project_id.to_proto(),
1357    })?;
1358    room_updated(&room, &session.peer);
1359
1360    Ok(())
1361}
1362
1363async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1364    let project_id = ProjectId::from_proto(message.project_id);
1365
1366    let (room, guest_connection_ids) = &*session
1367        .db()
1368        .await
1369        .unshare_project(project_id, session.connection_id)
1370        .await?;
1371
1372    broadcast(
1373        Some(session.connection_id),
1374        guest_connection_ids.iter().copied(),
1375        |conn_id| session.peer.send(conn_id, message.clone()),
1376    );
1377    room_updated(&room, &session.peer);
1378
1379    Ok(())
1380}
1381
1382async fn join_project(
1383    request: proto::JoinProject,
1384    response: Response<proto::JoinProject>,
1385    session: Session,
1386) -> Result<()> {
1387    let project_id = ProjectId::from_proto(request.project_id);
1388    let guest_user_id = session.user_id;
1389
1390    tracing::info!(%project_id, "join project");
1391
1392    let (project, replica_id) = &mut *session
1393        .db()
1394        .await
1395        .join_project(project_id, session.connection_id)
1396        .await?;
1397
1398    let collaborators = project
1399        .collaborators
1400        .iter()
1401        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1402        .map(|collaborator| collaborator.to_proto())
1403        .collect::<Vec<_>>();
1404
1405    let worktrees = project
1406        .worktrees
1407        .iter()
1408        .map(|(id, worktree)| proto::WorktreeMetadata {
1409            id: *id,
1410            root_name: worktree.root_name.clone(),
1411            visible: worktree.visible,
1412            abs_path: worktree.abs_path.clone(),
1413        })
1414        .collect::<Vec<_>>();
1415
1416    for collaborator in &collaborators {
1417        session
1418            .peer
1419            .send(
1420                collaborator.peer_id.unwrap().into(),
1421                proto::AddProjectCollaborator {
1422                    project_id: project_id.to_proto(),
1423                    collaborator: Some(proto::Collaborator {
1424                        peer_id: Some(session.connection_id.into()),
1425                        replica_id: replica_id.0 as u32,
1426                        user_id: guest_user_id.to_proto(),
1427                    }),
1428                },
1429            )
1430            .trace_err();
1431    }
1432
1433    // First, we send the metadata associated with each worktree.
1434    response.send(proto::JoinProjectResponse {
1435        worktrees: worktrees.clone(),
1436        replica_id: replica_id.0 as u32,
1437        collaborators: collaborators.clone(),
1438        language_servers: project.language_servers.clone(),
1439    })?;
1440
1441    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1442        #[cfg(any(test, feature = "test-support"))]
1443        const MAX_CHUNK_SIZE: usize = 2;
1444        #[cfg(not(any(test, feature = "test-support")))]
1445        const MAX_CHUNK_SIZE: usize = 256;
1446
1447        // Stream this worktree's entries.
1448        let message = proto::UpdateWorktree {
1449            project_id: project_id.to_proto(),
1450            worktree_id,
1451            abs_path: worktree.abs_path.clone(),
1452            root_name: worktree.root_name,
1453            updated_entries: worktree.entries,
1454            removed_entries: Default::default(),
1455            scan_id: worktree.scan_id,
1456            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1457            updated_repositories: worktree.repository_entries.into_values().collect(),
1458            removed_repositories: Default::default(),
1459        };
1460        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1461            session.peer.send(session.connection_id, update.clone())?;
1462        }
1463
1464        // Stream this worktree's diagnostics.
1465        for summary in worktree.diagnostic_summaries {
1466            session.peer.send(
1467                session.connection_id,
1468                proto::UpdateDiagnosticSummary {
1469                    project_id: project_id.to_proto(),
1470                    worktree_id: worktree.id,
1471                    summary: Some(summary),
1472                },
1473            )?;
1474        }
1475
1476        for settings_file in worktree.settings_files {
1477            session.peer.send(
1478                session.connection_id,
1479                proto::UpdateWorktreeSettings {
1480                    project_id: project_id.to_proto(),
1481                    worktree_id: worktree.id,
1482                    path: settings_file.path,
1483                    content: Some(settings_file.content),
1484                },
1485            )?;
1486        }
1487    }
1488
1489    for language_server in &project.language_servers {
1490        session.peer.send(
1491            session.connection_id,
1492            proto::UpdateLanguageServer {
1493                project_id: project_id.to_proto(),
1494                language_server_id: language_server.id,
1495                variant: Some(
1496                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1497                        proto::LspDiskBasedDiagnosticsUpdated {},
1498                    ),
1499                ),
1500            },
1501        )?;
1502    }
1503
1504    Ok(())
1505}
1506
1507async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1508    let sender_id = session.connection_id;
1509    let project_id = ProjectId::from_proto(request.project_id);
1510
1511    let (room, project) = &*session
1512        .db()
1513        .await
1514        .leave_project(project_id, sender_id)
1515        .await?;
1516    tracing::info!(
1517        %project_id,
1518        host_user_id = %project.host_user_id,
1519        host_connection_id = %project.host_connection_id,
1520        "leave project"
1521    );
1522
1523    project_left(&project, &session);
1524    room_updated(&room, &session.peer);
1525
1526    Ok(())
1527}
1528
1529async fn update_project(
1530    request: proto::UpdateProject,
1531    response: Response<proto::UpdateProject>,
1532    session: Session,
1533) -> Result<()> {
1534    let project_id = ProjectId::from_proto(request.project_id);
1535    let (room, guest_connection_ids) = &*session
1536        .db()
1537        .await
1538        .update_project(project_id, session.connection_id, &request.worktrees)
1539        .await?;
1540    broadcast(
1541        Some(session.connection_id),
1542        guest_connection_ids.iter().copied(),
1543        |connection_id| {
1544            session
1545                .peer
1546                .forward_send(session.connection_id, connection_id, request.clone())
1547        },
1548    );
1549    room_updated(&room, &session.peer);
1550    response.send(proto::Ack {})?;
1551
1552    Ok(())
1553}
1554
1555async fn update_worktree(
1556    request: proto::UpdateWorktree,
1557    response: Response<proto::UpdateWorktree>,
1558    session: Session,
1559) -> Result<()> {
1560    let guest_connection_ids = session
1561        .db()
1562        .await
1563        .update_worktree(&request, session.connection_id)
1564        .await?;
1565
1566    broadcast(
1567        Some(session.connection_id),
1568        guest_connection_ids.iter().copied(),
1569        |connection_id| {
1570            session
1571                .peer
1572                .forward_send(session.connection_id, connection_id, request.clone())
1573        },
1574    );
1575    response.send(proto::Ack {})?;
1576    Ok(())
1577}
1578
1579async fn update_diagnostic_summary(
1580    message: proto::UpdateDiagnosticSummary,
1581    session: Session,
1582) -> Result<()> {
1583    let guest_connection_ids = session
1584        .db()
1585        .await
1586        .update_diagnostic_summary(&message, session.connection_id)
1587        .await?;
1588
1589    broadcast(
1590        Some(session.connection_id),
1591        guest_connection_ids.iter().copied(),
1592        |connection_id| {
1593            session
1594                .peer
1595                .forward_send(session.connection_id, connection_id, message.clone())
1596        },
1597    );
1598
1599    Ok(())
1600}
1601
1602async fn update_worktree_settings(
1603    message: proto::UpdateWorktreeSettings,
1604    session: Session,
1605) -> Result<()> {
1606    let guest_connection_ids = session
1607        .db()
1608        .await
1609        .update_worktree_settings(&message, session.connection_id)
1610        .await?;
1611
1612    broadcast(
1613        Some(session.connection_id),
1614        guest_connection_ids.iter().copied(),
1615        |connection_id| {
1616            session
1617                .peer
1618                .forward_send(session.connection_id, connection_id, message.clone())
1619        },
1620    );
1621
1622    Ok(())
1623}
1624
1625async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1626    broadcast_project_message(request.project_id, request, session).await
1627}
1628
1629async fn start_language_server(
1630    request: proto::StartLanguageServer,
1631    session: Session,
1632) -> Result<()> {
1633    let guest_connection_ids = session
1634        .db()
1635        .await
1636        .start_language_server(&request, session.connection_id)
1637        .await?;
1638
1639    broadcast(
1640        Some(session.connection_id),
1641        guest_connection_ids.iter().copied(),
1642        |connection_id| {
1643            session
1644                .peer
1645                .forward_send(session.connection_id, connection_id, request.clone())
1646        },
1647    );
1648    Ok(())
1649}
1650
1651async fn update_language_server(
1652    request: proto::UpdateLanguageServer,
1653    session: Session,
1654) -> Result<()> {
1655    session.executor.record_backtrace();
1656    let project_id = ProjectId::from_proto(request.project_id);
1657    let project_connection_ids = session
1658        .db()
1659        .await
1660        .project_connection_ids(project_id, session.connection_id)
1661        .await?;
1662    broadcast(
1663        Some(session.connection_id),
1664        project_connection_ids.iter().copied(),
1665        |connection_id| {
1666            session
1667                .peer
1668                .forward_send(session.connection_id, connection_id, request.clone())
1669        },
1670    );
1671    Ok(())
1672}
1673
1674async fn forward_project_request<T>(
1675    request: T,
1676    response: Response<T>,
1677    session: Session,
1678) -> Result<()>
1679where
1680    T: EntityMessage + RequestMessage,
1681{
1682    session.executor.record_backtrace();
1683    let project_id = ProjectId::from_proto(request.remote_entity_id());
1684    let host_connection_id = {
1685        let collaborators = session
1686            .db()
1687            .await
1688            .project_collaborators(project_id, session.connection_id)
1689            .await?;
1690        collaborators
1691            .iter()
1692            .find(|collaborator| collaborator.is_host)
1693            .ok_or_else(|| anyhow!("host not found"))?
1694            .connection_id
1695    };
1696
1697    let payload = session
1698        .peer
1699        .forward_request(session.connection_id, host_connection_id, request)
1700        .await?;
1701
1702    response.send(payload)?;
1703    Ok(())
1704}
1705
1706async fn create_buffer_for_peer(
1707    request: proto::CreateBufferForPeer,
1708    session: Session,
1709) -> Result<()> {
1710    session.executor.record_backtrace();
1711    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1712    session
1713        .peer
1714        .forward_send(session.connection_id, peer_id.into(), request)?;
1715    Ok(())
1716}
1717
1718async fn update_buffer(
1719    request: proto::UpdateBuffer,
1720    response: Response<proto::UpdateBuffer>,
1721    session: Session,
1722) -> Result<()> {
1723    session.executor.record_backtrace();
1724    let project_id = ProjectId::from_proto(request.project_id);
1725    let mut guest_connection_ids;
1726    let mut host_connection_id = None;
1727    {
1728        let collaborators = session
1729            .db()
1730            .await
1731            .project_collaborators(project_id, session.connection_id)
1732            .await?;
1733        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1734        for collaborator in collaborators.iter() {
1735            if collaborator.is_host {
1736                host_connection_id = Some(collaborator.connection_id);
1737            } else {
1738                guest_connection_ids.push(collaborator.connection_id);
1739            }
1740        }
1741    }
1742    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1743
1744    session.executor.record_backtrace();
1745    broadcast(
1746        Some(session.connection_id),
1747        guest_connection_ids,
1748        |connection_id| {
1749            session
1750                .peer
1751                .forward_send(session.connection_id, connection_id, request.clone())
1752        },
1753    );
1754    if host_connection_id != session.connection_id {
1755        session
1756            .peer
1757            .forward_request(session.connection_id, host_connection_id, request.clone())
1758            .await?;
1759    }
1760
1761    response.send(proto::Ack {})?;
1762    Ok(())
1763}
1764
1765async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1766    let project_id = ProjectId::from_proto(request.project_id);
1767    let project_connection_ids = session
1768        .db()
1769        .await
1770        .project_connection_ids(project_id, session.connection_id)
1771        .await?;
1772
1773    broadcast(
1774        Some(session.connection_id),
1775        project_connection_ids.iter().copied(),
1776        |connection_id| {
1777            session
1778                .peer
1779                .forward_send(session.connection_id, connection_id, request.clone())
1780        },
1781    );
1782    Ok(())
1783}
1784
1785async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1786    let project_id = ProjectId::from_proto(request.project_id);
1787    let project_connection_ids = session
1788        .db()
1789        .await
1790        .project_connection_ids(project_id, session.connection_id)
1791        .await?;
1792    broadcast(
1793        Some(session.connection_id),
1794        project_connection_ids.iter().copied(),
1795        |connection_id| {
1796            session
1797                .peer
1798                .forward_send(session.connection_id, connection_id, request.clone())
1799        },
1800    );
1801    Ok(())
1802}
1803
1804async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1805    broadcast_project_message(request.project_id, request, session).await
1806}
1807
1808async fn broadcast_project_message<T: EnvelopedMessage>(
1809    project_id: u64,
1810    request: T,
1811    session: Session,
1812) -> Result<()> {
1813    let project_id = ProjectId::from_proto(project_id);
1814    let project_connection_ids = session
1815        .db()
1816        .await
1817        .project_connection_ids(project_id, session.connection_id)
1818        .await?;
1819    broadcast(
1820        Some(session.connection_id),
1821        project_connection_ids.iter().copied(),
1822        |connection_id| {
1823            session
1824                .peer
1825                .forward_send(session.connection_id, connection_id, request.clone())
1826        },
1827    );
1828    Ok(())
1829}
1830
1831async fn follow(
1832    request: proto::Follow,
1833    response: Response<proto::Follow>,
1834    session: Session,
1835) -> Result<()> {
1836    let project_id = ProjectId::from_proto(request.project_id);
1837    let leader_id = request
1838        .leader_id
1839        .ok_or_else(|| anyhow!("invalid leader id"))?
1840        .into();
1841    let follower_id = session.connection_id;
1842
1843    {
1844        let project_connection_ids = session
1845            .db()
1846            .await
1847            .project_connection_ids(project_id, session.connection_id)
1848            .await?;
1849
1850        if !project_connection_ids.contains(&leader_id) {
1851            Err(anyhow!("no such peer"))?;
1852        }
1853    }
1854
1855    let mut response_payload = session
1856        .peer
1857        .forward_request(session.connection_id, leader_id, request)
1858        .await?;
1859    response_payload
1860        .views
1861        .retain(|view| view.leader_id != Some(follower_id.into()));
1862    response.send(response_payload)?;
1863
1864    let room = session
1865        .db()
1866        .await
1867        .follow(project_id, leader_id, follower_id)
1868        .await?;
1869    room_updated(&room, &session.peer);
1870
1871    Ok(())
1872}
1873
1874async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1875    let project_id = ProjectId::from_proto(request.project_id);
1876    let leader_id = request
1877        .leader_id
1878        .ok_or_else(|| anyhow!("invalid leader id"))?
1879        .into();
1880    let follower_id = session.connection_id;
1881
1882    if !session
1883        .db()
1884        .await
1885        .project_connection_ids(project_id, session.connection_id)
1886        .await?
1887        .contains(&leader_id)
1888    {
1889        Err(anyhow!("no such peer"))?;
1890    }
1891
1892    session
1893        .peer
1894        .forward_send(session.connection_id, leader_id, request)?;
1895
1896    let room = session
1897        .db()
1898        .await
1899        .unfollow(project_id, leader_id, follower_id)
1900        .await?;
1901    room_updated(&room, &session.peer);
1902
1903    Ok(())
1904}
1905
1906async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1907    let project_id = ProjectId::from_proto(request.project_id);
1908    let project_connection_ids = session
1909        .db
1910        .lock()
1911        .await
1912        .project_connection_ids(project_id, session.connection_id)
1913        .await?;
1914
1915    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1916        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1917        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1918        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1919    });
1920    for follower_peer_id in request.follower_ids.iter().copied() {
1921        let follower_connection_id = follower_peer_id.into();
1922        if project_connection_ids.contains(&follower_connection_id)
1923            && Some(follower_peer_id) != leader_id
1924        {
1925            session.peer.forward_send(
1926                session.connection_id,
1927                follower_connection_id,
1928                request.clone(),
1929            )?;
1930        }
1931    }
1932    Ok(())
1933}
1934
1935async fn get_users(
1936    request: proto::GetUsers,
1937    response: Response<proto::GetUsers>,
1938    session: Session,
1939) -> Result<()> {
1940    let user_ids = request
1941        .user_ids
1942        .into_iter()
1943        .map(UserId::from_proto)
1944        .collect();
1945    let users = session
1946        .db()
1947        .await
1948        .get_users_by_ids(user_ids)
1949        .await?
1950        .into_iter()
1951        .map(|user| proto::User {
1952            id: user.id.to_proto(),
1953            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1954            github_login: user.github_login,
1955        })
1956        .collect();
1957    response.send(proto::UsersResponse { users })?;
1958    Ok(())
1959}
1960
1961async fn fuzzy_search_users(
1962    request: proto::FuzzySearchUsers,
1963    response: Response<proto::FuzzySearchUsers>,
1964    session: Session,
1965) -> Result<()> {
1966    let query = request.query;
1967    let users = match query.len() {
1968        0 => vec![],
1969        1 | 2 => session
1970            .db()
1971            .await
1972            .get_user_by_github_login(&query)
1973            .await?
1974            .into_iter()
1975            .collect(),
1976        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1977    };
1978    let users = users
1979        .into_iter()
1980        .filter(|user| user.id != session.user_id)
1981        .map(|user| proto::User {
1982            id: user.id.to_proto(),
1983            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1984            github_login: user.github_login,
1985        })
1986        .collect();
1987    response.send(proto::UsersResponse { users })?;
1988    Ok(())
1989}
1990
1991async fn request_contact(
1992    request: proto::RequestContact,
1993    response: Response<proto::RequestContact>,
1994    session: Session,
1995) -> Result<()> {
1996    let requester_id = session.user_id;
1997    let responder_id = UserId::from_proto(request.responder_id);
1998    if requester_id == responder_id {
1999        return Err(anyhow!("cannot add yourself as a contact"))?;
2000    }
2001
2002    session
2003        .db()
2004        .await
2005        .send_contact_request(requester_id, responder_id)
2006        .await?;
2007
2008    // Update outgoing contact requests of requester
2009    let mut update = proto::UpdateContacts::default();
2010    update.outgoing_requests.push(responder_id.to_proto());
2011    for connection_id in session
2012        .connection_pool()
2013        .await
2014        .user_connection_ids(requester_id)
2015    {
2016        session.peer.send(connection_id, update.clone())?;
2017    }
2018
2019    // Update incoming contact requests of responder
2020    let mut update = proto::UpdateContacts::default();
2021    update
2022        .incoming_requests
2023        .push(proto::IncomingContactRequest {
2024            requester_id: requester_id.to_proto(),
2025            should_notify: true,
2026        });
2027    for connection_id in session
2028        .connection_pool()
2029        .await
2030        .user_connection_ids(responder_id)
2031    {
2032        session.peer.send(connection_id, update.clone())?;
2033    }
2034
2035    response.send(proto::Ack {})?;
2036    Ok(())
2037}
2038
2039async fn respond_to_contact_request(
2040    request: proto::RespondToContactRequest,
2041    response: Response<proto::RespondToContactRequest>,
2042    session: Session,
2043) -> Result<()> {
2044    let responder_id = session.user_id;
2045    let requester_id = UserId::from_proto(request.requester_id);
2046    let db = session.db().await;
2047    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2048        db.dismiss_contact_notification(responder_id, requester_id)
2049            .await?;
2050    } else {
2051        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2052
2053        db.respond_to_contact_request(responder_id, requester_id, accept)
2054            .await?;
2055        let requester_busy = db.is_user_busy(requester_id).await?;
2056        let responder_busy = db.is_user_busy(responder_id).await?;
2057
2058        let pool = session.connection_pool().await;
2059        // Update responder with new contact
2060        let mut update = proto::UpdateContacts::default();
2061        if accept {
2062            update
2063                .contacts
2064                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2065        }
2066        update
2067            .remove_incoming_requests
2068            .push(requester_id.to_proto());
2069        for connection_id in pool.user_connection_ids(responder_id) {
2070            session.peer.send(connection_id, update.clone())?;
2071        }
2072
2073        // Update requester with new contact
2074        let mut update = proto::UpdateContacts::default();
2075        if accept {
2076            update
2077                .contacts
2078                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2079        }
2080        update
2081            .remove_outgoing_requests
2082            .push(responder_id.to_proto());
2083        for connection_id in pool.user_connection_ids(requester_id) {
2084            session.peer.send(connection_id, update.clone())?;
2085        }
2086    }
2087
2088    response.send(proto::Ack {})?;
2089    Ok(())
2090}
2091
2092async fn remove_contact(
2093    request: proto::RemoveContact,
2094    response: Response<proto::RemoveContact>,
2095    session: Session,
2096) -> Result<()> {
2097    let requester_id = session.user_id;
2098    let responder_id = UserId::from_proto(request.user_id);
2099    let db = session.db().await;
2100    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2101
2102    let pool = session.connection_pool().await;
2103    // Update outgoing contact requests of requester
2104    let mut update = proto::UpdateContacts::default();
2105    if contact_accepted {
2106        update.remove_contacts.push(responder_id.to_proto());
2107    } else {
2108        update
2109            .remove_outgoing_requests
2110            .push(responder_id.to_proto());
2111    }
2112    for connection_id in pool.user_connection_ids(requester_id) {
2113        session.peer.send(connection_id, update.clone())?;
2114    }
2115
2116    // Update incoming contact requests of responder
2117    let mut update = proto::UpdateContacts::default();
2118    if contact_accepted {
2119        update.remove_contacts.push(requester_id.to_proto());
2120    } else {
2121        update
2122            .remove_incoming_requests
2123            .push(requester_id.to_proto());
2124    }
2125    for connection_id in pool.user_connection_ids(responder_id) {
2126        session.peer.send(connection_id, update.clone())?;
2127    }
2128
2129    response.send(proto::Ack {})?;
2130    Ok(())
2131}
2132
2133async fn create_channel(
2134    request: proto::CreateChannel,
2135    response: Response<proto::CreateChannel>,
2136    session: Session,
2137) -> Result<()> {
2138    let db = session.db().await;
2139    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2140
2141    if let Some(live_kit) = session.live_kit_client.as_ref() {
2142        live_kit.create_room(live_kit_room.clone()).await?;
2143    }
2144
2145    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2146    let id = db
2147        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2148        .await?;
2149
2150    response.send(proto::CreateChannelResponse {
2151        channel_id: id.to_proto(),
2152    })?;
2153
2154    let mut update = proto::UpdateChannels::default();
2155    update.channels.push(proto::Channel {
2156        id: id.to_proto(),
2157        name: request.name,
2158        parent_id: request.parent_id,
2159    });
2160
2161    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2162        db.get_channel_members(parent_id).await?
2163    } else {
2164        vec![session.user_id]
2165    };
2166
2167    let connection_pool = session.connection_pool().await;
2168    for user_id in user_ids_to_notify {
2169        for connection_id in connection_pool.user_connection_ids(user_id) {
2170            let mut update = update.clone();
2171            if user_id == session.user_id {
2172                update.channel_permissions.push(proto::ChannelPermission {
2173                    channel_id: id.to_proto(),
2174                    is_admin: true,
2175                });
2176            }
2177            session.peer.send(connection_id, update)?;
2178        }
2179    }
2180
2181    Ok(())
2182}
2183
2184async fn remove_channel(
2185    request: proto::RemoveChannel,
2186    response: Response<proto::RemoveChannel>,
2187    session: Session,
2188) -> Result<()> {
2189    let db = session.db().await;
2190
2191    let channel_id = request.channel_id;
2192    let (removed_channels, member_ids) = db
2193        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2194        .await?;
2195    response.send(proto::Ack {})?;
2196
2197    // Notify members of removed channels
2198    let mut update = proto::UpdateChannels::default();
2199    update
2200        .remove_channels
2201        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2202
2203    let connection_pool = session.connection_pool().await;
2204    for member_id in member_ids {
2205        for connection_id in connection_pool.user_connection_ids(member_id) {
2206            session.peer.send(connection_id, update.clone())?;
2207        }
2208    }
2209
2210    Ok(())
2211}
2212
2213async fn invite_channel_member(
2214    request: proto::InviteChannelMember,
2215    response: Response<proto::InviteChannelMember>,
2216    session: Session,
2217) -> Result<()> {
2218    let db = session.db().await;
2219    let channel_id = ChannelId::from_proto(request.channel_id);
2220    let invitee_id = UserId::from_proto(request.user_id);
2221    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2222        .await?;
2223
2224    let (channel, _) = db
2225        .get_channel(channel_id, session.user_id)
2226        .await?
2227        .ok_or_else(|| anyhow!("channel not found"))?;
2228
2229    let mut update = proto::UpdateChannels::default();
2230    update.channel_invitations.push(proto::Channel {
2231        id: channel.id.to_proto(),
2232        name: channel.name,
2233        parent_id: None,
2234    });
2235    for connection_id in session
2236        .connection_pool()
2237        .await
2238        .user_connection_ids(invitee_id)
2239    {
2240        session.peer.send(connection_id, update.clone())?;
2241    }
2242
2243    response.send(proto::Ack {})?;
2244    Ok(())
2245}
2246
2247async fn remove_channel_member(
2248    request: proto::RemoveChannelMember,
2249    response: Response<proto::RemoveChannelMember>,
2250    session: Session,
2251) -> Result<()> {
2252    let db = session.db().await;
2253    let channel_id = ChannelId::from_proto(request.channel_id);
2254    let member_id = UserId::from_proto(request.user_id);
2255
2256    db.remove_channel_member(channel_id, member_id, session.user_id)
2257        .await?;
2258
2259    let mut update = proto::UpdateChannels::default();
2260    update.remove_channels.push(channel_id.to_proto());
2261
2262    for connection_id in session
2263        .connection_pool()
2264        .await
2265        .user_connection_ids(member_id)
2266    {
2267        session.peer.send(connection_id, update.clone())?;
2268    }
2269
2270    response.send(proto::Ack {})?;
2271    Ok(())
2272}
2273
2274async fn set_channel_member_admin(
2275    request: proto::SetChannelMemberAdmin,
2276    response: Response<proto::SetChannelMemberAdmin>,
2277    session: Session,
2278) -> Result<()> {
2279    let db = session.db().await;
2280    let channel_id = ChannelId::from_proto(request.channel_id);
2281    let member_id = UserId::from_proto(request.user_id);
2282    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2283        .await?;
2284
2285    let (channel, has_accepted) = db
2286        .get_channel(channel_id, member_id)
2287        .await?
2288        .ok_or_else(|| anyhow!("channel not found"))?;
2289
2290    let mut update = proto::UpdateChannels::default();
2291    if has_accepted {
2292        update.channel_permissions.push(proto::ChannelPermission {
2293            channel_id: channel.id.to_proto(),
2294            is_admin: request.admin,
2295        });
2296    }
2297
2298    for connection_id in session
2299        .connection_pool()
2300        .await
2301        .user_connection_ids(member_id)
2302    {
2303        session.peer.send(connection_id, update.clone())?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn rename_channel(
2311    request: proto::RenameChannel,
2312    response: Response<proto::RenameChannel>,
2313    session: Session,
2314) -> Result<()> {
2315    let db = session.db().await;
2316    let channel_id = ChannelId::from_proto(request.channel_id);
2317    let new_name = db
2318        .rename_channel(channel_id, session.user_id, &request.name)
2319        .await?;
2320
2321    response.send(proto::Ack {})?;
2322
2323    let mut update = proto::UpdateChannels::default();
2324    update.channels.push(proto::Channel {
2325        id: request.channel_id,
2326        name: new_name,
2327        parent_id: None,
2328    });
2329
2330    let member_ids = db.get_channel_members(channel_id).await?;
2331
2332    let connection_pool = session.connection_pool().await;
2333    for member_id in member_ids {
2334        for connection_id in connection_pool.user_connection_ids(member_id) {
2335            session.peer.send(connection_id, update.clone())?;
2336        }
2337    }
2338
2339    Ok(())
2340}
2341
2342async fn get_channel_members(
2343    request: proto::GetChannelMembers,
2344    response: Response<proto::GetChannelMembers>,
2345    session: Session,
2346) -> Result<()> {
2347    let db = session.db().await;
2348    let channel_id = ChannelId::from_proto(request.channel_id);
2349    let members = db
2350        .get_channel_member_details(channel_id, session.user_id)
2351        .await?;
2352    response.send(proto::GetChannelMembersResponse { members })?;
2353    Ok(())
2354}
2355
2356async fn respond_to_channel_invite(
2357    request: proto::RespondToChannelInvite,
2358    response: Response<proto::RespondToChannelInvite>,
2359    session: Session,
2360) -> Result<()> {
2361    let db = session.db().await;
2362    let channel_id = ChannelId::from_proto(request.channel_id);
2363    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2364        .await?;
2365
2366    let mut update = proto::UpdateChannels::default();
2367    update
2368        .remove_channel_invitations
2369        .push(channel_id.to_proto());
2370    if request.accept {
2371        let result = db.get_channels_for_user(session.user_id).await?;
2372        update
2373            .channels
2374            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2375                id: channel.id.to_proto(),
2376                name: channel.name,
2377                parent_id: channel.parent_id.map(ChannelId::to_proto),
2378            }));
2379        update
2380            .channel_participants
2381            .extend(
2382                result
2383                    .channel_participants
2384                    .into_iter()
2385                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2386                        channel_id: channel_id.to_proto(),
2387                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2388                    }),
2389            );
2390        update
2391            .channel_permissions
2392            .extend(
2393                result
2394                    .channels_with_admin_privileges
2395                    .into_iter()
2396                    .map(|channel_id| proto::ChannelPermission {
2397                        channel_id: channel_id.to_proto(),
2398                        is_admin: true,
2399                    }),
2400            );
2401    }
2402    session.peer.send(session.connection_id, update)?;
2403    response.send(proto::Ack {})?;
2404
2405    Ok(())
2406}
2407
/// Joins the room associated with `request.channel_id` on behalf of this connection.
///
/// Replies with the joined room and, when a LiveKit client is configured and a room token
/// can be issued, LiveKit connection info; then calls `room_updated`, `channel_updated`
/// (with the channel's members), and `update_user_contacts` for the joining user.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Everything that needs the database handle (and the value returned by `join_room`)
    // happens inside this block; only a clone of the joined-room data escapes it.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // If token creation fails, the error goes through `trace_err` and no LiveKit info is
        // sent (NOTE(review): presumably `trace_err` logs and yields `None` — confirm).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // NOTE(review): `join_room` appears to return a room guard (see the TODO below);
        // cloning copies the joined-room data out so the guard can be dropped at the end
        // of this block — confirm against the db layer.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2468
2469async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2470    let project_id = ProjectId::from_proto(request.project_id);
2471    let project_connection_ids = session
2472        .db()
2473        .await
2474        .project_connection_ids(project_id, session.connection_id)
2475        .await?;
2476    broadcast(
2477        Some(session.connection_id),
2478        project_connection_ids.iter().copied(),
2479        |connection_id| {
2480            session
2481                .peer
2482                .forward_send(session.connection_id, connection_id, request.clone())
2483        },
2484    );
2485    Ok(())
2486}
2487
2488async fn get_private_user_info(
2489    _request: proto::GetPrivateUserInfo,
2490    response: Response<proto::GetPrivateUserInfo>,
2491    session: Session,
2492) -> Result<()> {
2493    let metrics_id = session
2494        .db()
2495        .await
2496        .get_user_metrics_id(session.user_id)
2497        .await?;
2498    let user = session
2499        .db()
2500        .await
2501        .get_user_by_id(session.user_id)
2502        .await?
2503        .ok_or_else(|| anyhow!("user not found"))?;
2504    response.send(proto::GetPrivateUserInfoResponse {
2505        metrics_id,
2506        staff: user.admin,
2507    })?;
2508    Ok(())
2509}
2510
2511fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2512    match message {
2513        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2514        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2515        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2516        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2517        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2518            code: frame.code.into(),
2519            reason: frame.reason,
2520        })),
2521    }
2522}
2523
2524fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2525    match message {
2526        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2527        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2528        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2529        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2530        AxumMessage::Close(frame) => {
2531            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2532                code: frame.code.into(),
2533                reason: frame.reason,
2534            }))
2535        }
2536    }
2537}
2538
2539fn build_initial_channels_update(
2540    channels: Vec<db::Channel>,
2541    channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2542    channel_invites: Vec<db::Channel>,
2543) -> proto::UpdateChannels {
2544    let mut update = proto::UpdateChannels::default();
2545
2546    for channel in channels {
2547        update.channels.push(proto::Channel {
2548            id: channel.id.to_proto(),
2549            name: channel.name,
2550            parent_id: channel.parent_id.map(|id| id.to_proto()),
2551        });
2552    }
2553
2554    for (channel_id, participants) in channel_participants {
2555        update
2556            .channel_participants
2557            .push(proto::ChannelParticipants {
2558                channel_id: channel_id.to_proto(),
2559                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2560            });
2561    }
2562
2563    for channel in channel_invites {
2564        update.channel_invitations.push(proto::Channel {
2565            id: channel.id.to_proto(),
2566            name: channel.name,
2567            parent_id: None,
2568        });
2569    }
2570
2571    update
2572}
2573
2574fn build_initial_contacts_update(
2575    contacts: Vec<db::Contact>,
2576    pool: &ConnectionPool,
2577) -> proto::UpdateContacts {
2578    let mut update = proto::UpdateContacts::default();
2579
2580    for contact in contacts {
2581        match contact {
2582            db::Contact::Accepted {
2583                user_id,
2584                should_notify,
2585                busy,
2586            } => {
2587                update
2588                    .contacts
2589                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2590            }
2591            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2592            db::Contact::Incoming {
2593                user_id,
2594                should_notify,
2595            } => update
2596                .incoming_requests
2597                .push(proto::IncomingContactRequest {
2598                    requester_id: user_id.to_proto(),
2599                    should_notify,
2600                }),
2601        }
2602    }
2603
2604    update
2605}
2606
2607fn contact_for_user(
2608    user_id: UserId,
2609    should_notify: bool,
2610    busy: bool,
2611    pool: &ConnectionPool,
2612) -> proto::Contact {
2613    proto::Contact {
2614        user_id: user_id.to_proto(),
2615        online: pool.is_user_online(user_id),
2616        busy,
2617        should_notify,
2618    }
2619}
2620
2621fn room_updated(room: &proto::Room, peer: &Peer) {
2622    broadcast(
2623        None,
2624        room.participants
2625            .iter()
2626            .filter_map(|participant| Some(participant.peer_id?.into())),
2627        |peer_id| {
2628            peer.send(
2629                peer_id.into(),
2630                proto::RoomUpdated {
2631                    room: Some(room.clone()),
2632                },
2633            )
2634        },
2635    );
2636}
2637
2638fn channel_updated(
2639    channel_id: ChannelId,
2640    room: &proto::Room,
2641    channel_members: &[UserId],
2642    peer: &Peer,
2643    pool: &ConnectionPool,
2644) {
2645    let participants = room
2646        .participants
2647        .iter()
2648        .map(|p| p.user_id)
2649        .collect::<Vec<_>>();
2650
2651    broadcast(
2652        None,
2653        channel_members
2654            .iter()
2655            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2656        |peer_id| {
2657            peer.send(
2658                peer_id.into(),
2659                proto::UpdateChannels {
2660                    channel_participants: vec![proto::ChannelParticipants {
2661                        channel_id: channel_id.to_proto(),
2662                        participant_user_ids: participants.clone(),
2663                    }],
2664                    ..Default::default()
2665                },
2666            )
2667        },
2668    );
2669}
2670
2671async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2672    let db = session.db().await;
2673
2674    let contacts = db.get_contacts(user_id).await?;
2675    let busy = db.is_user_busy(user_id).await?;
2676
2677    let pool = session.connection_pool().await;
2678    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2679    for contact in contacts {
2680        if let db::Contact::Accepted {
2681            user_id: contact_user_id,
2682            ..
2683        } = contact
2684        {
2685            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2686                session
2687                    .peer
2688                    .send(
2689                        contact_conn_id,
2690                        proto::UpdateContacts {
2691                            contacts: vec![updated_contact.clone()],
2692                            remove_contacts: Default::default(),
2693                            incoming_requests: Default::default(),
2694                            remove_incoming_requests: Default::default(),
2695                            outgoing_requests: Default::default(),
2696                            remove_outgoing_requests: Default::default(),
2697                        },
2698                    )
2699                    .trace_err();
2700            }
2701        }
2702    }
2703    Ok(())
2704}
2705
2706async fn leave_room_for_session(session: &Session) -> Result<()> {
2707    let mut contacts_to_update = HashSet::default();
2708
2709    let room_id;
2710    let canceled_calls_to_user_ids;
2711    let live_kit_room;
2712    let delete_live_kit_room;
2713    let room;
2714    let channel_members;
2715    let channel_id;
2716
2717    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2718        contacts_to_update.insert(session.user_id);
2719
2720        for project in left_room.left_projects.values() {
2721            project_left(project, session);
2722        }
2723
2724        room_id = RoomId::from_proto(left_room.room.id);
2725        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2726        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2727        delete_live_kit_room = left_room.deleted;
2728        room = mem::take(&mut left_room.room);
2729        channel_members = mem::take(&mut left_room.channel_members);
2730        channel_id = left_room.channel_id;
2731
2732        room_updated(&room, &session.peer);
2733    } else {
2734        return Ok(());
2735    }
2736
2737    // TODO - do this while holding the room guard.
2738    if let Some(channel_id) = channel_id {
2739        channel_updated(
2740            channel_id,
2741            &room,
2742            &channel_members,
2743            &session.peer,
2744            &*session.connection_pool().await,
2745        );
2746    }
2747
2748    {
2749        let pool = session.connection_pool().await;
2750        for canceled_user_id in canceled_calls_to_user_ids {
2751            for connection_id in pool.user_connection_ids(canceled_user_id) {
2752                session
2753                    .peer
2754                    .send(
2755                        connection_id,
2756                        proto::CallCanceled {
2757                            room_id: room_id.to_proto(),
2758                        },
2759                    )
2760                    .trace_err();
2761            }
2762            contacts_to_update.insert(canceled_user_id);
2763        }
2764    }
2765
2766    for contact_user_id in contacts_to_update {
2767        update_user_contacts(contact_user_id, &session).await?;
2768    }
2769
2770    if let Some(live_kit) = session.live_kit_client.as_ref() {
2771        live_kit
2772            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2773            .await
2774            .trace_err();
2775
2776        if delete_live_kit_room {
2777            live_kit.delete_room(live_kit_room).await.trace_err();
2778        }
2779    }
2780
2781    Ok(())
2782}
2783
2784fn project_left(project: &db::LeftProject, session: &Session) {
2785    for connection_id in &project.connection_ids {
2786        if project.host_user_id == session.user_id {
2787            session
2788                .peer
2789                .send(
2790                    *connection_id,
2791                    proto::UnshareProject {
2792                        project_id: project.id.to_proto(),
2793                    },
2794                )
2795                .trace_err();
2796        } else {
2797            session
2798                .peer
2799                .send(
2800                    *connection_id,
2801                    proto::RemoveProjectCollaborator {
2802                        project_id: project.id.to_proto(),
2803                        peer_id: Some(session.connection_id.into()),
2804                    },
2805                )
2806                .trace_err();
2807        }
2808    }
2809}
2810
2811pub trait ResultExt {
2812    type Ok;
2813
2814    fn trace_err(self) -> Option<Self::Ok>;
2815}
2816
2817impl<T, E> ResultExt for Result<T, E>
2818where
2819    E: std::fmt::Debug,
2820{
2821    type Ok = T;
2822
2823    fn trace_err(self) -> Option<T> {
2824        match self {
2825            Ok(value) => Some(value),
2826            Err(error) => {
2827                tracing::error!("{:?}", error);
2828                None
2829            }
2830        }
2831    }
2832}