rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  39        RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
/// Thirty-second timeout constant. It is not used in the code visible in this section;
/// presumably it is the grace period a dropped connection gets to reconnect — confirm at
/// its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the background task spawned by [`Server::start`] sleeps before it begins
/// cleaning up stale rooms and stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
lazy_static! {
    // Prometheus integer gauge registered under the name "connections".
    // Registration runs on first access and panics (`unwrap`) if it fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus integer gauge registered under the name "shared_projects".
    // Registration runs on first access and panics (`unwrap`) if it fails.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  75
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives a boxed envelope of any message type together with the [`Session`] the
/// message arrived on, and returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
/// Handle given to a request handler so it can answer the request of type `R`.
struct Response<R> {
    /// Peer through which the response is sent.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` reads it to detect handlers
    /// that returned `Ok` without responding.
    responded: Arc<AtomicBool>,
}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
/// Per-connection context handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Identifier of this connection within the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; access it through
    /// [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured in the application state.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor used for sleeping and for spawning detached work.
    executor: Executor,
}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
/// Newtype around the shared [`Database`]; it dereferences to `Database` and is what
/// `Session` keeps behind its async mutex.
struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id. Kept in a mutex because the test-only `reset` replaces it
    /// through `&self`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and performs all sends.
    peer: Arc<Peer>,
    /// Pool of connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; every `handle_connection` future subscribes to it and
    /// returns when it changes.
    teardown: watch::Sender<()>,
}
 153
/// Lock guard over the [`ConnectionPool`] that dereferences to the pool and, in test
/// builds, checks the pool's invariants when dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send` as well.
    _not_send: PhantomData<Rc<()>>,
}
 158
/// Serializable view of the server's state, produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Locked connection pool; serialized through [`serialize_deref`], i.e. as the
    /// `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
    /// Builds a server with the given id, application state and executor, registers every
    /// message and request handler, and returns it wrapped in an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; receivers are created later via `subscribe`.
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` is used for messages that expect a response,
        // `add_message_handler` for messages that do not.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Request types served by the generic `forward_project_request::<T>` handler.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }
 262
    /// Spawns a detached background task that cleans up state left behind by stale
    /// servers, then returns `Ok(())` immediately without waiting for that task.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT`, then for every stale room id reported by the
    /// database it refreshes the room, notifies room (and channel) participants, sends
    /// `CallCanceled` to users whose calls were canceled, pushes an `UpdateContacts` to
    /// the accepted contacts of every affected user, and deletes the LiveKit room when no
    /// participants remain. Finally it asks the database to delete stale servers.
    ///
    /// Every fallible step inside the task goes through `trace_err`; when a step fails,
    /// the work that depends on it is skipped and the task carries on.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Per-room results; they keep these defaults if the refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both stale participants and users whose calls were canceled
                            // need their contacts refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block so it is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Proceed only when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // `delete_live_kit_room` is only set when the refreshed room has
                        // no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 387
 388    pub fn teardown(&self) {
 389        self.peer.teardown();
 390        self.connection_pool.lock().reset();
 391        let _ = self.teardown.send(());
 392    }
 393
 394    #[cfg(test)]
 395    pub fn reset(&self, id: ServerId) {
 396        self.teardown();
 397        *self.id.lock() = id;
 398        self.peer.reset(id.0 as u32);
 399    }
 400
 401    #[cfg(test)]
 402    pub fn id(&self) -> ServerId {
 403        *self.id.lock()
 404    }
 405
 406    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 407    where
 408        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 409        Fut: 'static + Send + Future<Output = Result<()>>,
 410        M: EnvelopedMessage,
 411    {
 412        let prev_handler = self.handlers.insert(
 413            TypeId::of::<M>(),
 414            Box::new(move |envelope, session| {
 415                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 416                let span = info_span!(
 417                    "handle message",
 418                    payload_type = envelope.payload_type_name()
 419                );
 420                span.in_scope(|| {
 421                    tracing::info!(
 422                        payload_type = envelope.payload_type_name(),
 423                        "message received"
 424                    );
 425                });
 426                let start_time = Instant::now();
 427                let future = (handler)(*envelope, session);
 428                async move {
 429                    let result = future.await;
 430                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 431                    match result {
 432                        Err(error) => {
 433                            tracing::error!(%error, ?duration_ms, "error handling message")
 434                        }
 435                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 436                    }
 437                }
 438                .instrument(span)
 439                .boxed()
 440            }),
 441        );
 442        if prev_handler.is_some() {
 443            panic!("registered a handler for the same message twice");
 444        }
 445        self
 446    }
 447
 448    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 449    where
 450        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 451        Fut: 'static + Send + Future<Output = Result<()>>,
 452        M: EnvelopedMessage,
 453    {
 454        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 455        self
 456    }
 457
 458    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 459    where
 460        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 461        Fut: Send + Future<Output = Result<()>>,
 462        M: RequestMessage,
 463    {
 464        let handler = Arc::new(handler);
 465        self.add_handler(move |envelope, session| {
 466            let receipt = envelope.receipt();
 467            let handler = handler.clone();
 468            async move {
 469                let peer = session.peer.clone();
 470                let responded = Arc::new(AtomicBool::default());
 471                let response = Response {
 472                    peer: peer.clone(),
 473                    responded: responded.clone(),
 474                    receipt,
 475                };
 476                match (handler)(envelope.payload, response, session).await {
 477                    Ok(()) => {
 478                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 479                            Ok(())
 480                        } else {
 481                            Err(anyhow!("handler did not send a response"))?
 482                        }
 483                    }
 484                    Err(error) => {
 485                        peer.respond_with_error(
 486                            receipt,
 487                            proto::Error {
 488                                message: error.to_string(),
 489                            },
 490                        )?;
 491                        Err(error)
 492                    }
 493                }
 494            }
 495        })
 496    }
 497
    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// The future:
    /// 1. registers `connection` with the peer and sends `Hello` with the new connection id;
    /// 2. reports the connection id through `send_connection_id`, if one was supplied;
    /// 3. on the user's first connection, sends `ShowContacts` and records that the user
    ///    has connected once;
    /// 4. loads contacts, invite code, channels and channel invites, adds the connection
    ///    to the pool, and sends the initial contacts / channels / invite-info updates;
    /// 5. forwards a pending incoming call, if any, and refreshes the user's contacts;
    /// 6. loops, dispatching incoming messages to the registered handlers until the
    ///    connection closes, I/O finishes, or the server is torn down;
    /// 7. calls `connection_lost` for the session (skipped on server teardown, where the
    ///    future returns right away).
    ///
    /// Errors during steps 1–5 are returned to the caller. An error from
    /// `connection_lost` is only logged.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // A failed send (receiver already dropped) is ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The four queries are awaited together; the first error aborts the connection.
            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is confined to this block, which contains no `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers run at once for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before the next message is read, so no new message
                // is taken off the channel while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: teardown,
                // I/O completion, finished foreground handlers, then the next message.
                futures::select_biased! {
                    // Server teardown: return immediately, without calling `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops any foreground handlers that have not completed.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 640
 641    pub async fn invite_code_redeemed(
 642        self: &Arc<Self>,
 643        inviter_id: UserId,
 644        invitee_id: UserId,
 645    ) -> Result<()> {
 646        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 647            if let Some(code) = &user.invite_code {
 648                let pool = self.connection_pool.lock();
 649                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 650                for connection_id in pool.user_connection_ids(inviter_id) {
 651                    self.peer.send(
 652                        connection_id,
 653                        proto::UpdateContacts {
 654                            contacts: vec![invitee_contact.clone()],
 655                            ..Default::default()
 656                        },
 657                    )?;
 658                    self.peer.send(
 659                        connection_id,
 660                        proto::UpdateInviteInfo {
 661                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 662                            count: user.invite_count as u32,
 663                        },
 664                    )?;
 665                }
 666            }
 667        }
 668        Ok(())
 669    }
 670
 671    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 672        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 673            if let Some(invite_code) = &user.invite_code {
 674                let pool = self.connection_pool.lock();
 675                for connection_id in pool.user_connection_ids(user_id) {
 676                    self.peer.send(
 677                        connection_id,
 678                        proto::UpdateInviteInfo {
 679                            url: format!(
 680                                "{}{}",
 681                                self.app_state.config.invite_link_prefix, invite_code
 682                            ),
 683                            count: user.invite_count as u32,
 684                        },
 685                    )?;
 686                }
 687            }
 688        }
 689        Ok(())
 690    }
 691
 692    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 693        ServerSnapshot {
 694            connection_pool: ConnectionPoolGuard {
 695                guard: self.connection_pool.lock(),
 696                _not_send: PhantomData,
 697            },
 698            peer: &self.peer,
 699        }
 700    }
 701}
 702
 703impl<'a> Deref for ConnectionPoolGuard<'a> {
 704    type Target = ConnectionPool;
 705
 706    fn deref(&self) -> &Self::Target {
 707        &*self.guard
 708    }
 709}
 710
 711impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 712    fn deref_mut(&mut self) -> &mut Self::Target {
 713        &mut *self.guard
 714    }
 715}
 716
/// Runs `check_invariants` each time a guard is dropped, in test builds only.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // The call below is compiled out unless `cfg(test)` is set, leaving
        // this `drop` empty in non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
 723
 724fn broadcast<F>(
 725    sender_id: Option<ConnectionId>,
 726    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 727    mut f: F,
 728) where
 729    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 730{
 731    for receiver_id in receiver_ids {
 732        if Some(receiver_id) != sender_id {
 733            if let Err(error) = f(receiver_id) {
 734                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 735            }
 736        }
 737    }
 738}
 739
// Header name `x-zed-protocol-version`, built once on first access and
// returned by `ProtocolVersion`'s `Header::name`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 743
/// Protocol version number carried as a `u32`; its `Header` implementation
/// maps it to the `x-zed-protocol-version` header.
pub struct ProtocolVersion(u32);
 745
 746impl Header for ProtocolVersion {
 747    fn name() -> &'static HeaderName {
 748        &ZED_PROTOCOL_VERSION
 749    }
 750
 751    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 752    where
 753        Self: Sized,
 754        I: Iterator<Item = &'i axum::http::HeaderValue>,
 755    {
 756        let version = values
 757            .next()
 758            .ok_or_else(axum::headers::Error::invalid)?
 759            .to_str()
 760            .map_err(|_| axum::headers::Error::invalid())?
 761            .parse()
 762            .map_err(|_| axum::headers::Error::invalid())?;
 763        Ok(Self(version))
 764    }
 765
 766    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 767        values.extend([self.0.to_string().parse().unwrap()]);
 768    }
 769}
 770
 771pub fn routes(server: Arc<Server>) -> Router<Body> {
 772    Router::new()
 773        .route("/rpc", get(handle_websocket_request))
 774        .layer(
 775            ServiceBuilder::new()
 776                .layer(Extension(server.app_state.clone()))
 777                .layer(middleware::from_fn(auth::validate_header)),
 778        )
 779        .route("/metrics", get(handle_metrics))
 780        .layer(Extension(server))
 781}
 782
 783pub async fn handle_websocket_request(
 784    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 785    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 786    Extension(server): Extension<Arc<Server>>,
 787    Extension(user): Extension<User>,
 788    ws: WebSocketUpgrade,
 789) -> axum::response::Response {
 790    if protocol_version != rpc::PROTOCOL_VERSION {
 791        return (
 792            StatusCode::UPGRADE_REQUIRED,
 793            "client must be upgraded".to_string(),
 794        )
 795            .into_response();
 796    }
 797    let socket_address = socket_address.to_string();
 798    ws.on_upgrade(move |socket| {
 799        use util::ResultExt;
 800        let socket = socket
 801            .map_ok(to_tungstenite_message)
 802            .err_into()
 803            .with(|message| async move { Ok(to_axum_message(message)) });
 804        let connection = Connection::new(Box::pin(socket));
 805        async move {
 806            server
 807                .handle_connection(connection, socket_address, user, None, Executor::Production)
 808                .await
 809                .log_err();
 810        }
 811    })
 812}
 813
 814pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 815    let connections = server
 816        .connection_pool
 817        .lock()
 818        .connections()
 819        .filter(|connection| !connection.admin)
 820        .count();
 821
 822    METRIC_CONNECTIONS.set(connections as _);
 823
 824    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 825    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 826
 827    let encoder = prometheus::TextEncoder::new();
 828    let metric_families = prometheus::gather();
 829    let encoded_metrics = encoder
 830        .encode_to_string(&metric_families)
 831        .map_err(|err| anyhow!("{}", err))?;
 832    Ok(encoded_metrics)
 833}
 834
/// Cleans up after a dropped connection.
///
/// The connection is disconnected from the peer and removed from the pool
/// straight away, and the database is told the connection was lost. The
/// function then waits for whichever comes first: `RECONNECT_TIMEOUT`
/// elapsing, or a change on the `teardown` channel. Only the timeout branch
/// performs the room/call cleanup; on teardown nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here does not abort the rest of the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the
    // timeout branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Decline any call for this user only if the pool reports no
            // remaining online connection for them.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 875
/// Answers a `proto::Ping` request with a `proto::Ack`; the request payload
/// and the session are unused.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    response.send(proto::Ack {})?;
    Ok(())
}
 880
/// Handles `proto::CreateRoom`: creates a room owned by the requesting
/// session and replies with the room plus optional LiveKit connection info.
///
/// LiveKit failures are non-fatal: the response then carries
/// `live_kit_connection_info: None`. A failure of the database `create_room`
/// call is returned as an error.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    // Generated identifier used both as the LiveKit room name and as the
    // `live_kit_room` argument passed to the database below.
    let live_kit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = {
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        util::async_iife!({
            // Evaluates to `None` when no LiveKit client is configured.
            let live_kit = live_kit?;

            // `trace_err()?` turns a failed room creation into `None`
            // (presumably after logging it — see `trace_err`).
            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
 926
 927async fn join_room(
 928    request: proto::JoinRoom,
 929    response: Response<proto::JoinRoom>,
 930    session: Session,
 931) -> Result<()> {
 932    let room_id = RoomId::from_proto(request.id);
 933    let joined_room = {
 934        let room = session
 935            .db()
 936            .await
 937            .join_room(room_id, session.user_id, session.connection_id)
 938            .await?;
 939        room_updated(&room.room, &session.peer);
 940        room.into_inner()
 941    };
 942
 943    if let Some(channel_id) = joined_room.channel_id {
 944        channel_updated(
 945            channel_id,
 946            &joined_room.room,
 947            &joined_room.channel_members,
 948            &session.peer,
 949            &*session.connection_pool().await,
 950        )
 951    }
 952
 953    for connection_id in session
 954        .connection_pool()
 955        .await
 956        .user_connection_ids(session.user_id)
 957    {
 958        session
 959            .peer
 960            .send(
 961                connection_id,
 962                proto::CallCanceled {
 963                    room_id: room_id.to_proto(),
 964                },
 965            )
 966            .trace_err();
 967    }
 968
 969    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 970        if let Some(token) = live_kit
 971            .room_token(
 972                &joined_room.room.live_kit_room,
 973                &session.user_id.to_string(),
 974            )
 975            .trace_err()
 976        {
 977            Some(proto::LiveKitConnectionInfo {
 978                server_url: live_kit.url().into(),
 979                token,
 980            })
 981        } else {
 982            None
 983        }
 984    } else {
 985        None
 986    };
 987
 988    response.send(proto::JoinRoomResponse {
 989        room: Some(joined_room.room),
 990        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
 991        live_kit_connection_info,
 992    })?;
 993
 994    update_user_contacts(session.user_id, &session).await?;
 995    Ok(())
 996}
 997
 998async fn rejoin_room(
 999    request: proto::RejoinRoom,
1000    response: Response<proto::RejoinRoom>,
1001    session: Session,
1002) -> Result<()> {
1003    let room;
1004    let channel_id;
1005    let channel_members;
1006    {
1007        let mut rejoined_room = session
1008            .db()
1009            .await
1010            .rejoin_room(request, session.user_id, session.connection_id)
1011            .await?;
1012
1013        response.send(proto::RejoinRoomResponse {
1014            room: Some(rejoined_room.room.clone()),
1015            reshared_projects: rejoined_room
1016                .reshared_projects
1017                .iter()
1018                .map(|project| proto::ResharedProject {
1019                    id: project.id.to_proto(),
1020                    collaborators: project
1021                        .collaborators
1022                        .iter()
1023                        .map(|collaborator| collaborator.to_proto())
1024                        .collect(),
1025                })
1026                .collect(),
1027            rejoined_projects: rejoined_room
1028                .rejoined_projects
1029                .iter()
1030                .map(|rejoined_project| proto::RejoinedProject {
1031                    id: rejoined_project.id.to_proto(),
1032                    worktrees: rejoined_project
1033                        .worktrees
1034                        .iter()
1035                        .map(|worktree| proto::WorktreeMetadata {
1036                            id: worktree.id,
1037                            root_name: worktree.root_name.clone(),
1038                            visible: worktree.visible,
1039                            abs_path: worktree.abs_path.clone(),
1040                        })
1041                        .collect(),
1042                    collaborators: rejoined_project
1043                        .collaborators
1044                        .iter()
1045                        .map(|collaborator| collaborator.to_proto())
1046                        .collect(),
1047                    language_servers: rejoined_project.language_servers.clone(),
1048                })
1049                .collect(),
1050        })?;
1051        room_updated(&rejoined_room.room, &session.peer);
1052
1053        for project in &rejoined_room.reshared_projects {
1054            for collaborator in &project.collaborators {
1055                session
1056                    .peer
1057                    .send(
1058                        collaborator.connection_id,
1059                        proto::UpdateProjectCollaborator {
1060                            project_id: project.id.to_proto(),
1061                            old_peer_id: Some(project.old_connection_id.into()),
1062                            new_peer_id: Some(session.connection_id.into()),
1063                        },
1064                    )
1065                    .trace_err();
1066            }
1067
1068            broadcast(
1069                Some(session.connection_id),
1070                project
1071                    .collaborators
1072                    .iter()
1073                    .map(|collaborator| collaborator.connection_id),
1074                |connection_id| {
1075                    session.peer.forward_send(
1076                        session.connection_id,
1077                        connection_id,
1078                        proto::UpdateProject {
1079                            project_id: project.id.to_proto(),
1080                            worktrees: project.worktrees.clone(),
1081                        },
1082                    )
1083                },
1084            );
1085        }
1086
1087        for project in &rejoined_room.rejoined_projects {
1088            for collaborator in &project.collaborators {
1089                session
1090                    .peer
1091                    .send(
1092                        collaborator.connection_id,
1093                        proto::UpdateProjectCollaborator {
1094                            project_id: project.id.to_proto(),
1095                            old_peer_id: Some(project.old_connection_id.into()),
1096                            new_peer_id: Some(session.connection_id.into()),
1097                        },
1098                    )
1099                    .trace_err();
1100            }
1101        }
1102
1103        for project in &mut rejoined_room.rejoined_projects {
1104            for worktree in mem::take(&mut project.worktrees) {
1105                #[cfg(any(test, feature = "test-support"))]
1106                const MAX_CHUNK_SIZE: usize = 2;
1107                #[cfg(not(any(test, feature = "test-support")))]
1108                const MAX_CHUNK_SIZE: usize = 256;
1109
1110                // Stream this worktree's entries.
1111                let message = proto::UpdateWorktree {
1112                    project_id: project.id.to_proto(),
1113                    worktree_id: worktree.id,
1114                    abs_path: worktree.abs_path.clone(),
1115                    root_name: worktree.root_name,
1116                    updated_entries: worktree.updated_entries,
1117                    removed_entries: worktree.removed_entries,
1118                    scan_id: worktree.scan_id,
1119                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1120                    updated_repositories: worktree.updated_repositories,
1121                    removed_repositories: worktree.removed_repositories,
1122                };
1123                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1124                    session.peer.send(session.connection_id, update.clone())?;
1125                }
1126
1127                // Stream this worktree's diagnostics.
1128                for summary in worktree.diagnostic_summaries {
1129                    session.peer.send(
1130                        session.connection_id,
1131                        proto::UpdateDiagnosticSummary {
1132                            project_id: project.id.to_proto(),
1133                            worktree_id: worktree.id,
1134                            summary: Some(summary),
1135                        },
1136                    )?;
1137                }
1138
1139                for settings_file in worktree.settings_files {
1140                    session.peer.send(
1141                        session.connection_id,
1142                        proto::UpdateWorktreeSettings {
1143                            project_id: project.id.to_proto(),
1144                            worktree_id: worktree.id,
1145                            path: settings_file.path,
1146                            content: Some(settings_file.content),
1147                        },
1148                    )?;
1149                }
1150            }
1151
1152            for language_server in &project.language_servers {
1153                session.peer.send(
1154                    session.connection_id,
1155                    proto::UpdateLanguageServer {
1156                        project_id: project.id.to_proto(),
1157                        language_server_id: language_server.id,
1158                        variant: Some(
1159                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1160                                proto::LspDiskBasedDiagnosticsUpdated {},
1161                            ),
1162                        ),
1163                    },
1164                )?;
1165            }
1166        }
1167
1168        let rejoined_room = rejoined_room.into_inner();
1169
1170        room = rejoined_room.room;
1171        channel_id = rejoined_room.channel_id;
1172        channel_members = rejoined_room.channel_members;
1173    }
1174
1175    if let Some(channel_id) = channel_id {
1176        channel_updated(
1177            channel_id,
1178            &room,
1179            &channel_members,
1180            &session.peer,
1181            &*session.connection_pool().await,
1182        );
1183    }
1184
1185    update_user_contacts(session.user_id, &session).await?;
1186    Ok(())
1187}
1188
/// Handles `proto::LeaveRoom` by delegating to `leave_room_for_session` and
/// then acknowledging. If leaving fails, the error is returned and no `Ack`
/// is sent.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: Session,
) -> Result<()> {
    leave_room_for_session(&session).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1198
1199async fn call(
1200    request: proto::Call,
1201    response: Response<proto::Call>,
1202    session: Session,
1203) -> Result<()> {
1204    let room_id = RoomId::from_proto(request.room_id);
1205    let calling_user_id = session.user_id;
1206    let calling_connection_id = session.connection_id;
1207    let called_user_id = UserId::from_proto(request.called_user_id);
1208    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1209    if !session
1210        .db()
1211        .await
1212        .has_contact(calling_user_id, called_user_id)
1213        .await?
1214    {
1215        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1216    }
1217
1218    let incoming_call = {
1219        let (room, incoming_call) = &mut *session
1220            .db()
1221            .await
1222            .call(
1223                room_id,
1224                calling_user_id,
1225                calling_connection_id,
1226                called_user_id,
1227                initial_project_id,
1228            )
1229            .await?;
1230        room_updated(&room, &session.peer);
1231        mem::take(incoming_call)
1232    };
1233    update_user_contacts(called_user_id, &session).await?;
1234
1235    let mut calls = session
1236        .connection_pool()
1237        .await
1238        .user_connection_ids(called_user_id)
1239        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1240        .collect::<FuturesUnordered<_>>();
1241
1242    while let Some(call_response) = calls.next().await {
1243        match call_response.as_ref() {
1244            Ok(_) => {
1245                response.send(proto::Ack {})?;
1246                return Ok(());
1247            }
1248            Err(_) => {
1249                call_response.trace_err();
1250            }
1251        }
1252    }
1253
1254    {
1255        let room = session
1256            .db()
1257            .await
1258            .call_failed(room_id, called_user_id)
1259            .await?;
1260        room_updated(&room, &session.peer);
1261    }
1262    update_user_contacts(called_user_id, &session).await?;
1263
1264    Err(anyhow!("failed to ring user"))?
1265}
1266
1267async fn cancel_call(
1268    request: proto::CancelCall,
1269    response: Response<proto::CancelCall>,
1270    session: Session,
1271) -> Result<()> {
1272    let called_user_id = UserId::from_proto(request.called_user_id);
1273    let room_id = RoomId::from_proto(request.room_id);
1274    {
1275        let room = session
1276            .db()
1277            .await
1278            .cancel_call(room_id, session.connection_id, called_user_id)
1279            .await?;
1280        room_updated(&room, &session.peer);
1281    }
1282
1283    for connection_id in session
1284        .connection_pool()
1285        .await
1286        .user_connection_ids(called_user_id)
1287    {
1288        session
1289            .peer
1290            .send(
1291                connection_id,
1292                proto::CallCanceled {
1293                    room_id: room_id.to_proto(),
1294                },
1295            )
1296            .trace_err();
1297    }
1298    response.send(proto::Ack {})?;
1299
1300    update_user_contacts(called_user_id, &session).await?;
1301    Ok(())
1302}
1303
1304async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1305    let room_id = RoomId::from_proto(message.room_id);
1306    {
1307        let room = session
1308            .db()
1309            .await
1310            .decline_call(Some(room_id), session.user_id)
1311            .await?
1312            .ok_or_else(|| anyhow!("failed to decline call"))?;
1313        room_updated(&room, &session.peer);
1314    }
1315
1316    for connection_id in session
1317        .connection_pool()
1318        .await
1319        .user_connection_ids(session.user_id)
1320    {
1321        session
1322            .peer
1323            .send(
1324                connection_id,
1325                proto::CallCanceled {
1326                    room_id: room_id.to_proto(),
1327                },
1328            )
1329            .trace_err();
1330    }
1331    update_user_contacts(session.user_id, &session).await?;
1332    Ok(())
1333}
1334
1335async fn update_participant_location(
1336    request: proto::UpdateParticipantLocation,
1337    response: Response<proto::UpdateParticipantLocation>,
1338    session: Session,
1339) -> Result<()> {
1340    let room_id = RoomId::from_proto(request.room_id);
1341    let location = request
1342        .location
1343        .ok_or_else(|| anyhow!("invalid location"))?;
1344
1345    let db = session.db().await;
1346    let room = db
1347        .update_room_participant_location(room_id, session.connection_id, location)
1348        .await?;
1349
1350    room_updated(&room, &session.peer);
1351    response.send(proto::Ack {})?;
1352    Ok(())
1353}
1354
1355async fn share_project(
1356    request: proto::ShareProject,
1357    response: Response<proto::ShareProject>,
1358    session: Session,
1359) -> Result<()> {
1360    let (project_id, room) = &*session
1361        .db()
1362        .await
1363        .share_project(
1364            RoomId::from_proto(request.room_id),
1365            session.connection_id,
1366            &request.worktrees,
1367        )
1368        .await?;
1369    response.send(proto::ShareProjectResponse {
1370        project_id: project_id.to_proto(),
1371    })?;
1372    room_updated(&room, &session.peer);
1373
1374    Ok(())
1375}
1376
1377async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1378    let project_id = ProjectId::from_proto(message.project_id);
1379
1380    let (room, guest_connection_ids) = &*session
1381        .db()
1382        .await
1383        .unshare_project(project_id, session.connection_id)
1384        .await?;
1385
1386    broadcast(
1387        Some(session.connection_id),
1388        guest_connection_ids.iter().copied(),
1389        |conn_id| session.peer.send(conn_id, message.clone()),
1390    );
1391    room_updated(&room, &session.peer);
1392
1393    Ok(())
1394}
1395
1396async fn join_project(
1397    request: proto::JoinProject,
1398    response: Response<proto::JoinProject>,
1399    session: Session,
1400) -> Result<()> {
1401    let project_id = ProjectId::from_proto(request.project_id);
1402    let guest_user_id = session.user_id;
1403
1404    tracing::info!(%project_id, "join project");
1405
1406    let (project, replica_id) = &mut *session
1407        .db()
1408        .await
1409        .join_project(project_id, session.connection_id)
1410        .await?;
1411
1412    let collaborators = project
1413        .collaborators
1414        .iter()
1415        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1416        .map(|collaborator| collaborator.to_proto())
1417        .collect::<Vec<_>>();
1418
1419    let worktrees = project
1420        .worktrees
1421        .iter()
1422        .map(|(id, worktree)| proto::WorktreeMetadata {
1423            id: *id,
1424            root_name: worktree.root_name.clone(),
1425            visible: worktree.visible,
1426            abs_path: worktree.abs_path.clone(),
1427        })
1428        .collect::<Vec<_>>();
1429
1430    for collaborator in &collaborators {
1431        session
1432            .peer
1433            .send(
1434                collaborator.peer_id.unwrap().into(),
1435                proto::AddProjectCollaborator {
1436                    project_id: project_id.to_proto(),
1437                    collaborator: Some(proto::Collaborator {
1438                        peer_id: Some(session.connection_id.into()),
1439                        replica_id: replica_id.0 as u32,
1440                        user_id: guest_user_id.to_proto(),
1441                    }),
1442                },
1443            )
1444            .trace_err();
1445    }
1446
1447    // First, we send the metadata associated with each worktree.
1448    response.send(proto::JoinProjectResponse {
1449        worktrees: worktrees.clone(),
1450        replica_id: replica_id.0 as u32,
1451        collaborators: collaborators.clone(),
1452        language_servers: project.language_servers.clone(),
1453    })?;
1454
1455    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1456        #[cfg(any(test, feature = "test-support"))]
1457        const MAX_CHUNK_SIZE: usize = 2;
1458        #[cfg(not(any(test, feature = "test-support")))]
1459        const MAX_CHUNK_SIZE: usize = 256;
1460
1461        // Stream this worktree's entries.
1462        let message = proto::UpdateWorktree {
1463            project_id: project_id.to_proto(),
1464            worktree_id,
1465            abs_path: worktree.abs_path.clone(),
1466            root_name: worktree.root_name,
1467            updated_entries: worktree.entries,
1468            removed_entries: Default::default(),
1469            scan_id: worktree.scan_id,
1470            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1471            updated_repositories: worktree.repository_entries.into_values().collect(),
1472            removed_repositories: Default::default(),
1473        };
1474        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1475            session.peer.send(session.connection_id, update.clone())?;
1476        }
1477
1478        // Stream this worktree's diagnostics.
1479        for summary in worktree.diagnostic_summaries {
1480            session.peer.send(
1481                session.connection_id,
1482                proto::UpdateDiagnosticSummary {
1483                    project_id: project_id.to_proto(),
1484                    worktree_id: worktree.id,
1485                    summary: Some(summary),
1486                },
1487            )?;
1488        }
1489
1490        for settings_file in worktree.settings_files {
1491            session.peer.send(
1492                session.connection_id,
1493                proto::UpdateWorktreeSettings {
1494                    project_id: project_id.to_proto(),
1495                    worktree_id: worktree.id,
1496                    path: settings_file.path,
1497                    content: Some(settings_file.content),
1498                },
1499            )?;
1500        }
1501    }
1502
1503    for language_server in &project.language_servers {
1504        session.peer.send(
1505            session.connection_id,
1506            proto::UpdateLanguageServer {
1507                project_id: project_id.to_proto(),
1508                language_server_id: language_server.id,
1509                variant: Some(
1510                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1511                        proto::LspDiskBasedDiagnosticsUpdated {},
1512                    ),
1513                ),
1514            },
1515        )?;
1516    }
1517
1518    Ok(())
1519}
1520
1521async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1522    let sender_id = session.connection_id;
1523    let project_id = ProjectId::from_proto(request.project_id);
1524
1525    let (room, project) = &*session
1526        .db()
1527        .await
1528        .leave_project(project_id, sender_id)
1529        .await?;
1530    tracing::info!(
1531        %project_id,
1532        host_user_id = %project.host_user_id,
1533        host_connection_id = %project.host_connection_id,
1534        "leave project"
1535    );
1536
1537    project_left(&project, &session);
1538    room_updated(&room, &session.peer);
1539
1540    Ok(())
1541}
1542
1543async fn update_project(
1544    request: proto::UpdateProject,
1545    response: Response<proto::UpdateProject>,
1546    session: Session,
1547) -> Result<()> {
1548    let project_id = ProjectId::from_proto(request.project_id);
1549    let (room, guest_connection_ids) = &*session
1550        .db()
1551        .await
1552        .update_project(project_id, session.connection_id, &request.worktrees)
1553        .await?;
1554    broadcast(
1555        Some(session.connection_id),
1556        guest_connection_ids.iter().copied(),
1557        |connection_id| {
1558            session
1559                .peer
1560                .forward_send(session.connection_id, connection_id, request.clone())
1561        },
1562    );
1563    room_updated(&room, &session.peer);
1564    response.send(proto::Ack {})?;
1565
1566    Ok(())
1567}
1568
1569async fn update_worktree(
1570    request: proto::UpdateWorktree,
1571    response: Response<proto::UpdateWorktree>,
1572    session: Session,
1573) -> Result<()> {
1574    let guest_connection_ids = session
1575        .db()
1576        .await
1577        .update_worktree(&request, session.connection_id)
1578        .await?;
1579
1580    broadcast(
1581        Some(session.connection_id),
1582        guest_connection_ids.iter().copied(),
1583        |connection_id| {
1584            session
1585                .peer
1586                .forward_send(session.connection_id, connection_id, request.clone())
1587        },
1588    );
1589    response.send(proto::Ack {})?;
1590    Ok(())
1591}
1592
1593async fn update_diagnostic_summary(
1594    message: proto::UpdateDiagnosticSummary,
1595    session: Session,
1596) -> Result<()> {
1597    let guest_connection_ids = session
1598        .db()
1599        .await
1600        .update_diagnostic_summary(&message, session.connection_id)
1601        .await?;
1602
1603    broadcast(
1604        Some(session.connection_id),
1605        guest_connection_ids.iter().copied(),
1606        |connection_id| {
1607            session
1608                .peer
1609                .forward_send(session.connection_id, connection_id, message.clone())
1610        },
1611    );
1612
1613    Ok(())
1614}
1615
1616async fn update_worktree_settings(
1617    message: proto::UpdateWorktreeSettings,
1618    session: Session,
1619) -> Result<()> {
1620    let guest_connection_ids = session
1621        .db()
1622        .await
1623        .update_worktree_settings(&message, session.connection_id)
1624        .await?;
1625
1626    broadcast(
1627        Some(session.connection_id),
1628        guest_connection_ids.iter().copied(),
1629        |connection_id| {
1630            session
1631                .peer
1632                .forward_send(session.connection_id, connection_id, message.clone())
1633        },
1634    );
1635
1636    Ok(())
1637}
1638
1639async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1640    broadcast_project_message(request.project_id, request, session).await
1641}
1642
1643async fn start_language_server(
1644    request: proto::StartLanguageServer,
1645    session: Session,
1646) -> Result<()> {
1647    let guest_connection_ids = session
1648        .db()
1649        .await
1650        .start_language_server(&request, session.connection_id)
1651        .await?;
1652
1653    broadcast(
1654        Some(session.connection_id),
1655        guest_connection_ids.iter().copied(),
1656        |connection_id| {
1657            session
1658                .peer
1659                .forward_send(session.connection_id, connection_id, request.clone())
1660        },
1661    );
1662    Ok(())
1663}
1664
1665async fn update_language_server(
1666    request: proto::UpdateLanguageServer,
1667    session: Session,
1668) -> Result<()> {
1669    session.executor.record_backtrace();
1670    let project_id = ProjectId::from_proto(request.project_id);
1671    let project_connection_ids = session
1672        .db()
1673        .await
1674        .project_connection_ids(project_id, session.connection_id)
1675        .await?;
1676    broadcast(
1677        Some(session.connection_id),
1678        project_connection_ids.iter().copied(),
1679        |connection_id| {
1680            session
1681                .peer
1682                .forward_send(session.connection_id, connection_id, request.clone())
1683        },
1684    );
1685    Ok(())
1686}
1687
1688async fn forward_project_request<T>(
1689    request: T,
1690    response: Response<T>,
1691    session: Session,
1692) -> Result<()>
1693where
1694    T: EntityMessage + RequestMessage,
1695{
1696    session.executor.record_backtrace();
1697    let project_id = ProjectId::from_proto(request.remote_entity_id());
1698    let host_connection_id = {
1699        let collaborators = session
1700            .db()
1701            .await
1702            .project_collaborators(project_id, session.connection_id)
1703            .await?;
1704        collaborators
1705            .iter()
1706            .find(|collaborator| collaborator.is_host)
1707            .ok_or_else(|| anyhow!("host not found"))?
1708            .connection_id
1709    };
1710
1711    let payload = session
1712        .peer
1713        .forward_request(session.connection_id, host_connection_id, request)
1714        .await?;
1715
1716    response.send(payload)?;
1717    Ok(())
1718}
1719
1720async fn create_buffer_for_peer(
1721    request: proto::CreateBufferForPeer,
1722    session: Session,
1723) -> Result<()> {
1724    session.executor.record_backtrace();
1725    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1726    session
1727        .peer
1728        .forward_send(session.connection_id, peer_id.into(), request)?;
1729    Ok(())
1730}
1731
1732async fn update_buffer(
1733    request: proto::UpdateBuffer,
1734    response: Response<proto::UpdateBuffer>,
1735    session: Session,
1736) -> Result<()> {
1737    session.executor.record_backtrace();
1738    let project_id = ProjectId::from_proto(request.project_id);
1739    let mut guest_connection_ids;
1740    let mut host_connection_id = None;
1741    {
1742        let collaborators = session
1743            .db()
1744            .await
1745            .project_collaborators(project_id, session.connection_id)
1746            .await?;
1747        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1748        for collaborator in collaborators.iter() {
1749            if collaborator.is_host {
1750                host_connection_id = Some(collaborator.connection_id);
1751            } else {
1752                guest_connection_ids.push(collaborator.connection_id);
1753            }
1754        }
1755    }
1756    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1757
1758    session.executor.record_backtrace();
1759    broadcast(
1760        Some(session.connection_id),
1761        guest_connection_ids,
1762        |connection_id| {
1763            session
1764                .peer
1765                .forward_send(session.connection_id, connection_id, request.clone())
1766        },
1767    );
1768    if host_connection_id != session.connection_id {
1769        session
1770            .peer
1771            .forward_request(session.connection_id, host_connection_id, request.clone())
1772            .await?;
1773    }
1774
1775    response.send(proto::Ack {})?;
1776    Ok(())
1777}
1778
1779async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1780    let project_id = ProjectId::from_proto(request.project_id);
1781    let project_connection_ids = session
1782        .db()
1783        .await
1784        .project_connection_ids(project_id, session.connection_id)
1785        .await?;
1786
1787    broadcast(
1788        Some(session.connection_id),
1789        project_connection_ids.iter().copied(),
1790        |connection_id| {
1791            session
1792                .peer
1793                .forward_send(session.connection_id, connection_id, request.clone())
1794        },
1795    );
1796    Ok(())
1797}
1798
1799async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1800    let project_id = ProjectId::from_proto(request.project_id);
1801    let project_connection_ids = session
1802        .db()
1803        .await
1804        .project_connection_ids(project_id, session.connection_id)
1805        .await?;
1806    broadcast(
1807        Some(session.connection_id),
1808        project_connection_ids.iter().copied(),
1809        |connection_id| {
1810            session
1811                .peer
1812                .forward_send(session.connection_id, connection_id, request.clone())
1813        },
1814    );
1815    Ok(())
1816}
1817
1818async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1819    broadcast_project_message(request.project_id, request, session).await
1820}
1821
1822async fn broadcast_project_message<T: EnvelopedMessage>(
1823    project_id: u64,
1824    request: T,
1825    session: Session,
1826) -> Result<()> {
1827    let project_id = ProjectId::from_proto(project_id);
1828    let project_connection_ids = session
1829        .db()
1830        .await
1831        .project_connection_ids(project_id, session.connection_id)
1832        .await?;
1833    broadcast(
1834        Some(session.connection_id),
1835        project_connection_ids.iter().copied(),
1836        |connection_id| {
1837            session
1838                .peer
1839                .forward_send(session.connection_id, connection_id, request.clone())
1840        },
1841    );
1842    Ok(())
1843}
1844
1845async fn follow(
1846    request: proto::Follow,
1847    response: Response<proto::Follow>,
1848    session: Session,
1849) -> Result<()> {
1850    let project_id = ProjectId::from_proto(request.project_id);
1851    let leader_id = request
1852        .leader_id
1853        .ok_or_else(|| anyhow!("invalid leader id"))?
1854        .into();
1855    let follower_id = session.connection_id;
1856
1857    {
1858        let project_connection_ids = session
1859            .db()
1860            .await
1861            .project_connection_ids(project_id, session.connection_id)
1862            .await?;
1863
1864        if !project_connection_ids.contains(&leader_id) {
1865            Err(anyhow!("no such peer"))?;
1866        }
1867    }
1868
1869    let mut response_payload = session
1870        .peer
1871        .forward_request(session.connection_id, leader_id, request)
1872        .await?;
1873    response_payload
1874        .views
1875        .retain(|view| view.leader_id != Some(follower_id.into()));
1876    response.send(response_payload)?;
1877
1878    let room = session
1879        .db()
1880        .await
1881        .follow(project_id, leader_id, follower_id)
1882        .await?;
1883    room_updated(&room, &session.peer);
1884
1885    Ok(())
1886}
1887
1888async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1889    let project_id = ProjectId::from_proto(request.project_id);
1890    let leader_id = request
1891        .leader_id
1892        .ok_or_else(|| anyhow!("invalid leader id"))?
1893        .into();
1894    let follower_id = session.connection_id;
1895
1896    if !session
1897        .db()
1898        .await
1899        .project_connection_ids(project_id, session.connection_id)
1900        .await?
1901        .contains(&leader_id)
1902    {
1903        Err(anyhow!("no such peer"))?;
1904    }
1905
1906    session
1907        .peer
1908        .forward_send(session.connection_id, leader_id, request)?;
1909
1910    let room = session
1911        .db()
1912        .await
1913        .unfollow(project_id, leader_id, follower_id)
1914        .await?;
1915    room_updated(&room, &session.peer);
1916
1917    Ok(())
1918}
1919
1920async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1921    let project_id = ProjectId::from_proto(request.project_id);
1922    let project_connection_ids = session
1923        .db
1924        .lock()
1925        .await
1926        .project_connection_ids(project_id, session.connection_id)
1927        .await?;
1928
1929    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1930        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1931        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1932        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1933    });
1934    for follower_peer_id in request.follower_ids.iter().copied() {
1935        let follower_connection_id = follower_peer_id.into();
1936        if project_connection_ids.contains(&follower_connection_id)
1937            && Some(follower_peer_id) != leader_id
1938        {
1939            session.peer.forward_send(
1940                session.connection_id,
1941                follower_connection_id,
1942                request.clone(),
1943            )?;
1944        }
1945    }
1946    Ok(())
1947}
1948
1949async fn get_users(
1950    request: proto::GetUsers,
1951    response: Response<proto::GetUsers>,
1952    session: Session,
1953) -> Result<()> {
1954    let user_ids = request
1955        .user_ids
1956        .into_iter()
1957        .map(UserId::from_proto)
1958        .collect();
1959    let users = session
1960        .db()
1961        .await
1962        .get_users_by_ids(user_ids)
1963        .await?
1964        .into_iter()
1965        .map(|user| proto::User {
1966            id: user.id.to_proto(),
1967            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1968            github_login: user.github_login,
1969        })
1970        .collect();
1971    response.send(proto::UsersResponse { users })?;
1972    Ok(())
1973}
1974
1975async fn fuzzy_search_users(
1976    request: proto::FuzzySearchUsers,
1977    response: Response<proto::FuzzySearchUsers>,
1978    session: Session,
1979) -> Result<()> {
1980    let query = request.query;
1981    let users = match query.len() {
1982        0 => vec![],
1983        1 | 2 => session
1984            .db()
1985            .await
1986            .get_user_by_github_login(&query)
1987            .await?
1988            .into_iter()
1989            .collect(),
1990        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1991    };
1992    let users = users
1993        .into_iter()
1994        .filter(|user| user.id != session.user_id)
1995        .map(|user| proto::User {
1996            id: user.id.to_proto(),
1997            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1998            github_login: user.github_login,
1999        })
2000        .collect();
2001    response.send(proto::UsersResponse { users })?;
2002    Ok(())
2003}
2004
2005async fn request_contact(
2006    request: proto::RequestContact,
2007    response: Response<proto::RequestContact>,
2008    session: Session,
2009) -> Result<()> {
2010    let requester_id = session.user_id;
2011    let responder_id = UserId::from_proto(request.responder_id);
2012    if requester_id == responder_id {
2013        return Err(anyhow!("cannot add yourself as a contact"))?;
2014    }
2015
2016    session
2017        .db()
2018        .await
2019        .send_contact_request(requester_id, responder_id)
2020        .await?;
2021
2022    // Update outgoing contact requests of requester
2023    let mut update = proto::UpdateContacts::default();
2024    update.outgoing_requests.push(responder_id.to_proto());
2025    for connection_id in session
2026        .connection_pool()
2027        .await
2028        .user_connection_ids(requester_id)
2029    {
2030        session.peer.send(connection_id, update.clone())?;
2031    }
2032
2033    // Update incoming contact requests of responder
2034    let mut update = proto::UpdateContacts::default();
2035    update
2036        .incoming_requests
2037        .push(proto::IncomingContactRequest {
2038            requester_id: requester_id.to_proto(),
2039            should_notify: true,
2040        });
2041    for connection_id in session
2042        .connection_pool()
2043        .await
2044        .user_connection_ids(responder_id)
2045    {
2046        session.peer.send(connection_id, update.clone())?;
2047    }
2048
2049    response.send(proto::Ack {})?;
2050    Ok(())
2051}
2052
2053async fn respond_to_contact_request(
2054    request: proto::RespondToContactRequest,
2055    response: Response<proto::RespondToContactRequest>,
2056    session: Session,
2057) -> Result<()> {
2058    let responder_id = session.user_id;
2059    let requester_id = UserId::from_proto(request.requester_id);
2060    let db = session.db().await;
2061    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2062        db.dismiss_contact_notification(responder_id, requester_id)
2063            .await?;
2064    } else {
2065        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2066
2067        db.respond_to_contact_request(responder_id, requester_id, accept)
2068            .await?;
2069        let requester_busy = db.is_user_busy(requester_id).await?;
2070        let responder_busy = db.is_user_busy(responder_id).await?;
2071
2072        let pool = session.connection_pool().await;
2073        // Update responder with new contact
2074        let mut update = proto::UpdateContacts::default();
2075        if accept {
2076            update
2077                .contacts
2078                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2079        }
2080        update
2081            .remove_incoming_requests
2082            .push(requester_id.to_proto());
2083        for connection_id in pool.user_connection_ids(responder_id) {
2084            session.peer.send(connection_id, update.clone())?;
2085        }
2086
2087        // Update requester with new contact
2088        let mut update = proto::UpdateContacts::default();
2089        if accept {
2090            update
2091                .contacts
2092                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2093        }
2094        update
2095            .remove_outgoing_requests
2096            .push(responder_id.to_proto());
2097        for connection_id in pool.user_connection_ids(requester_id) {
2098            session.peer.send(connection_id, update.clone())?;
2099        }
2100    }
2101
2102    response.send(proto::Ack {})?;
2103    Ok(())
2104}
2105
2106async fn remove_contact(
2107    request: proto::RemoveContact,
2108    response: Response<proto::RemoveContact>,
2109    session: Session,
2110) -> Result<()> {
2111    let requester_id = session.user_id;
2112    let responder_id = UserId::from_proto(request.user_id);
2113    let db = session.db().await;
2114    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2115
2116    let pool = session.connection_pool().await;
2117    // Update outgoing contact requests of requester
2118    let mut update = proto::UpdateContacts::default();
2119    if contact_accepted {
2120        update.remove_contacts.push(responder_id.to_proto());
2121    } else {
2122        update
2123            .remove_outgoing_requests
2124            .push(responder_id.to_proto());
2125    }
2126    for connection_id in pool.user_connection_ids(requester_id) {
2127        session.peer.send(connection_id, update.clone())?;
2128    }
2129
2130    // Update incoming contact requests of responder
2131    let mut update = proto::UpdateContacts::default();
2132    if contact_accepted {
2133        update.remove_contacts.push(requester_id.to_proto());
2134    } else {
2135        update
2136            .remove_incoming_requests
2137            .push(requester_id.to_proto());
2138    }
2139    for connection_id in pool.user_connection_ids(responder_id) {
2140        session.peer.send(connection_id, update.clone())?;
2141    }
2142
2143    response.send(proto::Ack {})?;
2144    Ok(())
2145}
2146
2147async fn create_channel(
2148    request: proto::CreateChannel,
2149    response: Response<proto::CreateChannel>,
2150    session: Session,
2151) -> Result<()> {
2152    let db = session.db().await;
2153    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2154
2155    if let Some(live_kit) = session.live_kit_client.as_ref() {
2156        live_kit.create_room(live_kit_room.clone()).await?;
2157    }
2158
2159    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2160    let id = db
2161        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2162        .await?;
2163
2164    let channel = proto::Channel {
2165        id: id.to_proto(),
2166        name: request.name,
2167        parent_id: request.parent_id,
2168    };
2169
2170    response.send(proto::ChannelResponse {
2171        channel: Some(channel.clone()),
2172    })?;
2173
2174    let mut update = proto::UpdateChannels::default();
2175    update.channels.push(channel);
2176
2177    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2178        db.get_channel_members(parent_id).await?
2179    } else {
2180        vec![session.user_id]
2181    };
2182
2183    let connection_pool = session.connection_pool().await;
2184    for user_id in user_ids_to_notify {
2185        for connection_id in connection_pool.user_connection_ids(user_id) {
2186            let mut update = update.clone();
2187            if user_id == session.user_id {
2188                update.channel_permissions.push(proto::ChannelPermission {
2189                    channel_id: id.to_proto(),
2190                    is_admin: true,
2191                });
2192            }
2193            session.peer.send(connection_id, update)?;
2194        }
2195    }
2196
2197    Ok(())
2198}
2199
2200async fn remove_channel(
2201    request: proto::RemoveChannel,
2202    response: Response<proto::RemoveChannel>,
2203    session: Session,
2204) -> Result<()> {
2205    let db = session.db().await;
2206
2207    let channel_id = request.channel_id;
2208    let (removed_channels, member_ids) = db
2209        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2210        .await?;
2211    response.send(proto::Ack {})?;
2212
2213    // Notify members of removed channels
2214    let mut update = proto::UpdateChannels::default();
2215    update
2216        .remove_channels
2217        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2218
2219    let connection_pool = session.connection_pool().await;
2220    for member_id in member_ids {
2221        for connection_id in connection_pool.user_connection_ids(member_id) {
2222            session.peer.send(connection_id, update.clone())?;
2223        }
2224    }
2225
2226    Ok(())
2227}
2228
2229async fn invite_channel_member(
2230    request: proto::InviteChannelMember,
2231    response: Response<proto::InviteChannelMember>,
2232    session: Session,
2233) -> Result<()> {
2234    let db = session.db().await;
2235    let channel_id = ChannelId::from_proto(request.channel_id);
2236    let invitee_id = UserId::from_proto(request.user_id);
2237    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2238        .await?;
2239
2240    let (channel, _) = db
2241        .get_channel(channel_id, session.user_id)
2242        .await?
2243        .ok_or_else(|| anyhow!("channel not found"))?;
2244
2245    let mut update = proto::UpdateChannels::default();
2246    update.channel_invitations.push(proto::Channel {
2247        id: channel.id.to_proto(),
2248        name: channel.name,
2249        parent_id: None,
2250    });
2251    for connection_id in session
2252        .connection_pool()
2253        .await
2254        .user_connection_ids(invitee_id)
2255    {
2256        session.peer.send(connection_id, update.clone())?;
2257    }
2258
2259    response.send(proto::Ack {})?;
2260    Ok(())
2261}
2262
2263async fn remove_channel_member(
2264    request: proto::RemoveChannelMember,
2265    response: Response<proto::RemoveChannelMember>,
2266    session: Session,
2267) -> Result<()> {
2268    let db = session.db().await;
2269    let channel_id = ChannelId::from_proto(request.channel_id);
2270    let member_id = UserId::from_proto(request.user_id);
2271
2272    db.remove_channel_member(channel_id, member_id, session.user_id)
2273        .await?;
2274
2275    let mut update = proto::UpdateChannels::default();
2276    update.remove_channels.push(channel_id.to_proto());
2277
2278    for connection_id in session
2279        .connection_pool()
2280        .await
2281        .user_connection_ids(member_id)
2282    {
2283        session.peer.send(connection_id, update.clone())?;
2284    }
2285
2286    response.send(proto::Ack {})?;
2287    Ok(())
2288}
2289
2290async fn set_channel_member_admin(
2291    request: proto::SetChannelMemberAdmin,
2292    response: Response<proto::SetChannelMemberAdmin>,
2293    session: Session,
2294) -> Result<()> {
2295    let db = session.db().await;
2296    let channel_id = ChannelId::from_proto(request.channel_id);
2297    let member_id = UserId::from_proto(request.user_id);
2298    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2299        .await?;
2300
2301    let (channel, has_accepted) = db
2302        .get_channel(channel_id, member_id)
2303        .await?
2304        .ok_or_else(|| anyhow!("channel not found"))?;
2305
2306    let mut update = proto::UpdateChannels::default();
2307    if has_accepted {
2308        update.channel_permissions.push(proto::ChannelPermission {
2309            channel_id: channel.id.to_proto(),
2310            is_admin: request.admin,
2311        });
2312    }
2313
2314    for connection_id in session
2315        .connection_pool()
2316        .await
2317        .user_connection_ids(member_id)
2318    {
2319        session.peer.send(connection_id, update.clone())?;
2320    }
2321
2322    response.send(proto::Ack {})?;
2323    Ok(())
2324}
2325
2326async fn rename_channel(
2327    request: proto::RenameChannel,
2328    response: Response<proto::RenameChannel>,
2329    session: Session,
2330) -> Result<()> {
2331    let db = session.db().await;
2332    let channel_id = ChannelId::from_proto(request.channel_id);
2333    let new_name = db
2334        .rename_channel(channel_id, session.user_id, &request.name)
2335        .await?;
2336
2337    let channel = proto::Channel {
2338        id: request.channel_id,
2339        name: new_name,
2340        parent_id: None,
2341    };
2342    response.send(proto::ChannelResponse {
2343        channel: Some(channel.clone()),
2344    })?;
2345    let mut update = proto::UpdateChannels::default();
2346    update.channels.push(channel);
2347
2348    let member_ids = db.get_channel_members(channel_id).await?;
2349
2350    let connection_pool = session.connection_pool().await;
2351    for member_id in member_ids {
2352        for connection_id in connection_pool.user_connection_ids(member_id) {
2353            session.peer.send(connection_id, update.clone())?;
2354        }
2355    }
2356
2357    Ok(())
2358}
2359
2360async fn get_channel_members(
2361    request: proto::GetChannelMembers,
2362    response: Response<proto::GetChannelMembers>,
2363    session: Session,
2364) -> Result<()> {
2365    let db = session.db().await;
2366    let channel_id = ChannelId::from_proto(request.channel_id);
2367    let members = db
2368        .get_channel_member_details(channel_id, session.user_id)
2369        .await?;
2370    response.send(proto::GetChannelMembersResponse { members })?;
2371    Ok(())
2372}
2373
2374async fn respond_to_channel_invite(
2375    request: proto::RespondToChannelInvite,
2376    response: Response<proto::RespondToChannelInvite>,
2377    session: Session,
2378) -> Result<()> {
2379    let db = session.db().await;
2380    let channel_id = ChannelId::from_proto(request.channel_id);
2381    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2382        .await?;
2383
2384    let mut update = proto::UpdateChannels::default();
2385    update
2386        .remove_channel_invitations
2387        .push(channel_id.to_proto());
2388    if request.accept {
2389        let result = db.get_channels_for_user(session.user_id).await?;
2390        update
2391            .channels
2392            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2393                id: channel.id.to_proto(),
2394                name: channel.name,
2395                parent_id: channel.parent_id.map(ChannelId::to_proto),
2396            }));
2397        update
2398            .channel_participants
2399            .extend(
2400                result
2401                    .channel_participants
2402                    .into_iter()
2403                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2404                        channel_id: channel_id.to_proto(),
2405                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2406                    }),
2407            );
2408        update
2409            .channel_permissions
2410            .extend(
2411                result
2412                    .channels_with_admin_privileges
2413                    .into_iter()
2414                    .map(|channel_id| proto::ChannelPermission {
2415                        channel_id: channel_id.to_proto(),
2416                        is_admin: true,
2417                    }),
2418            );
2419    }
2420    session.peer.send(session.connection_id, update)?;
2421    response.send(proto::Ack {})?;
2422
2423    Ok(())
2424}
2425
2426async fn join_channel(
2427    request: proto::JoinChannel,
2428    response: Response<proto::JoinChannel>,
2429    session: Session,
2430) -> Result<()> {
2431    let channel_id = ChannelId::from_proto(request.channel_id);
2432
2433    let joined_room = {
2434        leave_room_for_session(&session).await?;
2435        let db = session.db().await;
2436
2437        let room_id = db.room_id_for_channel(channel_id).await?;
2438
2439        let joined_room = db
2440            .join_room(room_id, session.user_id, session.connection_id)
2441            .await?;
2442
2443        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2444            let token = live_kit
2445                .room_token(
2446                    &joined_room.room.live_kit_room,
2447                    &session.user_id.to_string(),
2448                )
2449                .trace_err()?;
2450
2451            Some(LiveKitConnectionInfo {
2452                server_url: live_kit.url().into(),
2453                token,
2454            })
2455        });
2456
2457        response.send(proto::JoinRoomResponse {
2458            room: Some(joined_room.room.clone()),
2459            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2460            live_kit_connection_info,
2461        })?;
2462
2463        room_updated(&joined_room.room, &session.peer);
2464
2465        joined_room.into_inner()
2466    };
2467
2468    channel_updated(
2469        channel_id,
2470        &joined_room.room,
2471        &joined_room.channel_members,
2472        &session.peer,
2473        &*session.connection_pool().await,
2474    );
2475
2476    update_user_contacts(session.user_id, &session).await?;
2477
2478    Ok(())
2479}
2480
2481async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2482    let project_id = ProjectId::from_proto(request.project_id);
2483    let project_connection_ids = session
2484        .db()
2485        .await
2486        .project_connection_ids(project_id, session.connection_id)
2487        .await?;
2488    broadcast(
2489        Some(session.connection_id),
2490        project_connection_ids.iter().copied(),
2491        |connection_id| {
2492            session
2493                .peer
2494                .forward_send(session.connection_id, connection_id, request.clone())
2495        },
2496    );
2497    Ok(())
2498}
2499
2500async fn get_private_user_info(
2501    _request: proto::GetPrivateUserInfo,
2502    response: Response<proto::GetPrivateUserInfo>,
2503    session: Session,
2504) -> Result<()> {
2505    let metrics_id = session
2506        .db()
2507        .await
2508        .get_user_metrics_id(session.user_id)
2509        .await?;
2510    let user = session
2511        .db()
2512        .await
2513        .get_user_by_id(session.user_id)
2514        .await?
2515        .ok_or_else(|| anyhow!("user not found"))?;
2516    response.send(proto::GetPrivateUserInfoResponse {
2517        metrics_id,
2518        staff: user.admin,
2519    })?;
2520    Ok(())
2521}
2522
2523fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2524    match message {
2525        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2526        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2527        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2528        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2529        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2530            code: frame.code.into(),
2531            reason: frame.reason,
2532        })),
2533    }
2534}
2535
2536fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2537    match message {
2538        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2539        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2540        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2541        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2542        AxumMessage::Close(frame) => {
2543            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2544                code: frame.code.into(),
2545                reason: frame.reason,
2546            }))
2547        }
2548    }
2549}
2550
2551fn build_initial_channels_update(
2552    channels: ChannelsForUser,
2553    channel_invites: Vec<db::Channel>,
2554) -> proto::UpdateChannels {
2555    let mut update = proto::UpdateChannels::default();
2556
2557    for channel in channels.channels {
2558        update.channels.push(proto::Channel {
2559            id: channel.id.to_proto(),
2560            name: channel.name,
2561            parent_id: channel.parent_id.map(|id| id.to_proto()),
2562        });
2563    }
2564
2565    for (channel_id, participants) in channels.channel_participants {
2566        update
2567            .channel_participants
2568            .push(proto::ChannelParticipants {
2569                channel_id: channel_id.to_proto(),
2570                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2571            });
2572    }
2573
2574    update
2575        .channel_permissions
2576        .extend(
2577            channels
2578                .channels_with_admin_privileges
2579                .into_iter()
2580                .map(|id| proto::ChannelPermission {
2581                    channel_id: id.to_proto(),
2582                    is_admin: true,
2583                }),
2584        );
2585
2586    for channel in channel_invites {
2587        update.channel_invitations.push(proto::Channel {
2588            id: channel.id.to_proto(),
2589            name: channel.name,
2590            parent_id: None,
2591        });
2592    }
2593
2594    update
2595}
2596
2597fn build_initial_contacts_update(
2598    contacts: Vec<db::Contact>,
2599    pool: &ConnectionPool,
2600) -> proto::UpdateContacts {
2601    let mut update = proto::UpdateContacts::default();
2602
2603    for contact in contacts {
2604        match contact {
2605            db::Contact::Accepted {
2606                user_id,
2607                should_notify,
2608                busy,
2609            } => {
2610                update
2611                    .contacts
2612                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2613            }
2614            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2615            db::Contact::Incoming {
2616                user_id,
2617                should_notify,
2618            } => update
2619                .incoming_requests
2620                .push(proto::IncomingContactRequest {
2621                    requester_id: user_id.to_proto(),
2622                    should_notify,
2623                }),
2624        }
2625    }
2626
2627    update
2628}
2629
2630fn contact_for_user(
2631    user_id: UserId,
2632    should_notify: bool,
2633    busy: bool,
2634    pool: &ConnectionPool,
2635) -> proto::Contact {
2636    proto::Contact {
2637        user_id: user_id.to_proto(),
2638        online: pool.is_user_online(user_id),
2639        busy,
2640        should_notify,
2641    }
2642}
2643
2644fn room_updated(room: &proto::Room, peer: &Peer) {
2645    broadcast(
2646        None,
2647        room.participants
2648            .iter()
2649            .filter_map(|participant| Some(participant.peer_id?.into())),
2650        |peer_id| {
2651            peer.send(
2652                peer_id.into(),
2653                proto::RoomUpdated {
2654                    room: Some(room.clone()),
2655                },
2656            )
2657        },
2658    );
2659}
2660
2661fn channel_updated(
2662    channel_id: ChannelId,
2663    room: &proto::Room,
2664    channel_members: &[UserId],
2665    peer: &Peer,
2666    pool: &ConnectionPool,
2667) {
2668    let participants = room
2669        .participants
2670        .iter()
2671        .map(|p| p.user_id)
2672        .collect::<Vec<_>>();
2673
2674    broadcast(
2675        None,
2676        channel_members
2677            .iter()
2678            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2679        |peer_id| {
2680            peer.send(
2681                peer_id.into(),
2682                proto::UpdateChannels {
2683                    channel_participants: vec![proto::ChannelParticipants {
2684                        channel_id: channel_id.to_proto(),
2685                        participant_user_ids: participants.clone(),
2686                    }],
2687                    ..Default::default()
2688                },
2689            )
2690        },
2691    );
2692}
2693
2694async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2695    let db = session.db().await;
2696
2697    let contacts = db.get_contacts(user_id).await?;
2698    let busy = db.is_user_busy(user_id).await?;
2699
2700    let pool = session.connection_pool().await;
2701    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2702    for contact in contacts {
2703        if let db::Contact::Accepted {
2704            user_id: contact_user_id,
2705            ..
2706        } = contact
2707        {
2708            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2709                session
2710                    .peer
2711                    .send(
2712                        contact_conn_id,
2713                        proto::UpdateContacts {
2714                            contacts: vec![updated_contact.clone()],
2715                            remove_contacts: Default::default(),
2716                            incoming_requests: Default::default(),
2717                            remove_incoming_requests: Default::default(),
2718                            outgoing_requests: Default::default(),
2719                            remove_outgoing_requests: Default::default(),
2720                        },
2721                    )
2722                    .trace_err();
2723            }
2724        }
2725    }
2726    Ok(())
2727}
2728
2729async fn leave_room_for_session(session: &Session) -> Result<()> {
2730    let mut contacts_to_update = HashSet::default();
2731
2732    let room_id;
2733    let canceled_calls_to_user_ids;
2734    let live_kit_room;
2735    let delete_live_kit_room;
2736    let room;
2737    let channel_members;
2738    let channel_id;
2739
2740    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2741        contacts_to_update.insert(session.user_id);
2742
2743        for project in left_room.left_projects.values() {
2744            project_left(project, session);
2745        }
2746
2747        room_id = RoomId::from_proto(left_room.room.id);
2748        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2749        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2750        delete_live_kit_room = left_room.deleted;
2751        room = mem::take(&mut left_room.room);
2752        channel_members = mem::take(&mut left_room.channel_members);
2753        channel_id = left_room.channel_id;
2754
2755        room_updated(&room, &session.peer);
2756    } else {
2757        return Ok(());
2758    }
2759
2760    if let Some(channel_id) = channel_id {
2761        channel_updated(
2762            channel_id,
2763            &room,
2764            &channel_members,
2765            &session.peer,
2766            &*session.connection_pool().await,
2767        );
2768    }
2769
2770    {
2771        let pool = session.connection_pool().await;
2772        for canceled_user_id in canceled_calls_to_user_ids {
2773            for connection_id in pool.user_connection_ids(canceled_user_id) {
2774                session
2775                    .peer
2776                    .send(
2777                        connection_id,
2778                        proto::CallCanceled {
2779                            room_id: room_id.to_proto(),
2780                        },
2781                    )
2782                    .trace_err();
2783            }
2784            contacts_to_update.insert(canceled_user_id);
2785        }
2786    }
2787
2788    for contact_user_id in contacts_to_update {
2789        update_user_contacts(contact_user_id, &session).await?;
2790    }
2791
2792    if let Some(live_kit) = session.live_kit_client.as_ref() {
2793        live_kit
2794            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2795            .await
2796            .trace_err();
2797
2798        if delete_live_kit_room {
2799            live_kit.delete_room(live_kit_room).await.trace_err();
2800        }
2801    }
2802
2803    Ok(())
2804}
2805
2806fn project_left(project: &db::LeftProject, session: &Session) {
2807    for connection_id in &project.connection_ids {
2808        if project.host_user_id == session.user_id {
2809            session
2810                .peer
2811                .send(
2812                    *connection_id,
2813                    proto::UnshareProject {
2814                        project_id: project.id.to_proto(),
2815                    },
2816                )
2817                .trace_err();
2818        } else {
2819            session
2820                .peer
2821                .send(
2822                    *connection_id,
2823                    proto::RemoveProjectCollaborator {
2824                        project_id: project.id.to_proto(),
2825                        peer_id: Some(session.connection_id.into()),
2826                    },
2827                )
2828                .trace_err();
2829        }
2830    }
2831}
2832
2833pub trait ResultExt {
2834    type Ok;
2835
2836    fn trace_err(self) -> Option<Self::Ok>;
2837}
2838
2839impl<T, E> ResultExt for Result<T, E>
2840where
2841    E: std::fmt::Debug,
2842{
2843    type Ok = T;
2844
2845    fn trace_err(self) -> Option<T> {
2846        match self {
2847            Ok(value) => Some(value),
2848            Err(error) => {
2849                tracing::error!("{:?}", error);
2850                None
2851            }
2852        }
2853    }
2854}