rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
  39        LiveKitConnectionInfo, RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
/// Reply handle given to a request handler; `send` consumes it, so a request can be
/// answered at most once through it.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set by `send`; the dispatcher in `add_request_handler` reads it to detect
    /// handlers that returned `Ok` without replying.
    responded: Arc<AtomicBool>,
}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
/// Per-connection context passed to every message handler.
///
/// Clones share the same database mutex, peer, and connection pool, because those are
/// held in `Arc`s.
#[derive(Clone)]
struct Session {
    /// User this connection belongs to.
    user_id: UserId,
    /// Id assigned to this connection by `Peer::add_connection`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex created once per connection; access it
    /// through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to clients.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional LiveKit API client copied from `AppState`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor cloned from the one passed to `handle_connection`.
    executor: Executor,
}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// RPC server: owns the peer, the shared connection pool, and the table of message
/// handlers.
pub struct Server {
    /// Server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the client connections and sends/receives messages.
    peer: Arc<Peer>,
    /// Connected users and their connection ids; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Registered handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown()`; each connection loop holds a subscription to it.
    teardown: watch::Sender<()>,
}
 153
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (and `!Sync`), so it cannot
/// be held across an `.await` inside a future that must be `Send`. In test builds the
/// pool's invariants are checked when the guard is dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 158
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive. Fields are
/// serialized in declaration order, and the pool is serialized through its `Deref`
/// target via `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
impl Server {
    /// Creates a server with the given id, registers every RPC handler, and returns it
    /// wrapped in an `Arc`.
    ///
    /// Handlers added with `add_request_handler` must reply through their `Response`;
    /// handlers added with `add_message_handler` send no reply. Registering two handlers
    /// for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates silently if `ServerId`'s inner value can
            // exceed u32::MAX — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe()`
            // in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }

    /// Schedules post-start cleanup and returns immediately (always `Ok`).
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`. Then, for every room the database
    /// reports as stale for this server id, it:
    /// - refreshes the room and broadcasts the room (and channel) update;
    /// - sends `CallCanceled` to users whose calls were canceled;
    /// - pushes each affected user's refreshed contact entry to that user's accepted
    ///   contacts;
    /// - deletes the LiveKit room if no participants remain.
    ///
    /// Finally it deletes stale servers. Errors inside the task go through
    /// `trace_err()` and are otherwise ignored.
    pub async fn start(&self) -> Result<()> {
        // Copy the id now so the task is unaffected by a later `reset`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, send their updated contact entry to every
                        // connection of each accepted contact. The pool lock is taken only
                        // after both DB lookups have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room if the refresh succeeded and left no
                        // participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Shuts the server down. It tears down the peer, resets the connection pool, and
    /// notifies every `teardown` subscriber. The per-connection loops in
    /// `handle_connection` return when notified.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // `send` fails only when no receiver exists (no live connections); ignore it.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and reinitializes it under a new `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers a type-erased handler for message type `M`.
    ///
    /// The stored closure downcasts the envelope to `TypedEnvelope<M>` and logs receipt
    /// inside a "handle message" span. It then runs `handler` and logs either success or
    /// the error, along with the elapsed time in milliseconds. Handler errors are only
    /// logged here, not propagated.
    ///
    /// # Panics
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Stored under `TypeId::of::<M>()` and looked up by the message's
                // `payload_type_id()`, so this downcast is presumably infallible.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Microseconds / 1000 keeps sub-millisecond precision.
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a one-way message `M`. The handler receives only the
    /// payload, and no reply is sent.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for request `M`; the handler replies through its
    /// `Response<M>`.
    ///
    /// - `Ok` after `Response::send`: done.
    /// - `Ok` without responding: an error is returned (and logged by `add_handler`);
    ///   the requester receives no reply.
    /// - `Err`: the error's message is sent to the requester via `respond_with_error`,
    ///   and the error is returned.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                // `receipt` is used again in the error branch after being moved here,
                // which relies on `Receipt` being `Copy`.
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // NOTE(review): if this send fails, `?` returns that failure and
                        // the handler's original `error` is not logged.
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Returns a future that serves a single client connection for `user`.
    ///
    /// Setup steps:
    /// - register the connection with the peer and send `Hello`;
    /// - report the new connection id through `send_connection_id`, if given;
    /// - handle first-connection bookkeeping;
    /// - add the connection to the pool and send the initial contacts, channels,
    ///   invite info, and any pending incoming call;
    /// - call `update_user_contacts`.
    ///
    /// Any setup error ends the future with that error.
    ///
    /// After setup, it dispatches incoming messages to the registered handlers until
    /// the I/O task finishes, the client closes the connection, or the server is torn
    /// down. The first two cases run `connection_lost` (its errors are logged). On
    /// teardown it returns `Ok(())` immediately without running `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future runs so a teardown issued afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Report the id to the caller; a dropped receiver is ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch the initial state concurrently; the first failure aborts setup.
            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is held only inside this block; nothing here awaits.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // A fresh DB mutex per connection: it serializes DB access among this
            // connection's handlers only.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection:
            // a permit is acquired before reading each message and released when that
            // message's handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Recreated every iteration. If another branch wins the select, this
                // future is dropped, releasing any permit it had acquired. The
                // `unwrap` only fails if the semaphore is closed, which never happens
                // here.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown beats I/O completion, which beats draining finished
                // handlers, which beats reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before wrapping; `.instrument(span)`
                                // re-attaches it whenever the handler is polled.
                                drop(span_enter);

                                // The permit travels with the handler and is released
                                // only after the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Sends every connection of `inviter_id` two messages: a contacts update
    /// containing `invitee_id`, and the inviter's current invite URL and count.
    /// Does nothing if the inviter does not exist or has no invite code. Returns the
    /// first send error, if any.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends the user's current invite URL and count to all of their connections.
    /// Does nothing if the user does not exist or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The snapshot holds the connection-pool lock until it is dropped. The method
    /// never awaits, despite being `async`.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
 705
 706impl<'a> Deref for ConnectionPoolGuard<'a> {
 707    type Target = ConnectionPool;
 708
 709    fn deref(&self) -> &Self::Target {
 710        &*self.guard
 711    }
 712}
 713
 714impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 715    fn deref_mut(&mut self) -> &mut Self::Target {
 716        &mut *self.guard
 717    }
 718}
 719
 720impl<'a> Drop for ConnectionPoolGuard<'a> {
 721    fn drop(&mut self) {
 722        #[cfg(test)]
 723        self.check_invariants();
 724    }
 725}
 726
 727fn broadcast<F>(
 728    sender_id: Option<ConnectionId>,
 729    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 730    mut f: F,
 731) where
 732    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 733{
 734    for receiver_id in receiver_ids {
 735        if Some(receiver_id) != sender_id {
 736            if let Err(error) = f(receiver_id) {
 737                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 738            }
 739        }
 740    }
 741}
 742
 743lazy_static! {
 744    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 745}
 746
 747pub struct ProtocolVersion(u32);
 748
 749impl Header for ProtocolVersion {
 750    fn name() -> &'static HeaderName {
 751        &ZED_PROTOCOL_VERSION
 752    }
 753
 754    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 755    where
 756        Self: Sized,
 757        I: Iterator<Item = &'i axum::http::HeaderValue>,
 758    {
 759        let version = values
 760            .next()
 761            .ok_or_else(axum::headers::Error::invalid)?
 762            .to_str()
 763            .map_err(|_| axum::headers::Error::invalid())?
 764            .parse()
 765            .map_err(|_| axum::headers::Error::invalid())?;
 766        Ok(Self(version))
 767    }
 768
 769    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 770        values.extend([self.0.to_string().parse().unwrap()]);
 771    }
 772}
 773
 774pub fn routes(server: Arc<Server>) -> Router<Body> {
 775    Router::new()
 776        .route("/rpc", get(handle_websocket_request))
 777        .layer(
 778            ServiceBuilder::new()
 779                .layer(Extension(server.app_state.clone()))
 780                .layer(middleware::from_fn(auth::validate_header)),
 781        )
 782        .route("/metrics", get(handle_metrics))
 783        .layer(Extension(server))
 784}
 785
 786pub async fn handle_websocket_request(
 787    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 788    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 789    Extension(server): Extension<Arc<Server>>,
 790    Extension(user): Extension<User>,
 791    ws: WebSocketUpgrade,
 792) -> axum::response::Response {
 793    if protocol_version != rpc::PROTOCOL_VERSION {
 794        return (
 795            StatusCode::UPGRADE_REQUIRED,
 796            "client must be upgraded".to_string(),
 797        )
 798            .into_response();
 799    }
 800    let socket_address = socket_address.to_string();
 801    ws.on_upgrade(move |socket| {
 802        use util::ResultExt;
 803        let socket = socket
 804            .map_ok(to_tungstenite_message)
 805            .err_into()
 806            .with(|message| async move { Ok(to_axum_message(message)) });
 807        let connection = Connection::new(Box::pin(socket));
 808        async move {
 809            server
 810                .handle_connection(connection, socket_address, user, None, Executor::Production)
 811                .await
 812                .log_err();
 813        }
 814    })
 815}
 816
 817pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 818    let connections = server
 819        .connection_pool
 820        .lock()
 821        .connections()
 822        .filter(|connection| !connection.admin)
 823        .count();
 824
 825    METRIC_CONNECTIONS.set(connections as _);
 826
 827    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 828    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 829
 830    let encoder = prometheus::TextEncoder::new();
 831    let metric_families = prometheus::gather();
 832    let encoded_metrics = encoder
 833        .encode_to_string(&metric_families)
 834        .map_err(|err| anyhow!("{}", err))?;
 835    Ok(encoded_metrics)
 836}
 837
 838#[instrument(err, skip(executor))]
 839async fn connection_lost(
 840    session: Session,
 841    mut teardown: watch::Receiver<()>,
 842    executor: Executor,
 843) -> Result<()> {
 844    session.peer.disconnect(session.connection_id);
 845    session
 846        .connection_pool()
 847        .await
 848        .remove_connection(session.connection_id)?;
 849
 850    session
 851        .db()
 852        .await
 853        .connection_lost(session.connection_id)
 854        .await
 855        .trace_err();
 856
 857    futures::select_biased! {
 858        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 859            leave_room_for_session(&session).await.trace_err();
 860            leave_channel_buffers_for_session(&session).await.trace_err();
 861
 862            if !session
 863                .connection_pool()
 864                .await
 865                .is_user_online(session.user_id)
 866            {
 867                let db = session.db().await;
 868                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 869                    room_updated(&room, &session.peer);
 870                }
 871            }
 872            update_user_contacts(session.user_id, &session).await?;
 873
 874
 875        }
 876        _ = teardown.changed().fuse() => {}
 877    }
 878
 879    Ok(())
 880}
 881
 882async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 883    response.send(proto::Ack {})?;
 884    Ok(())
 885}
 886
 887async fn create_room(
 888    _request: proto::CreateRoom,
 889    response: Response<proto::CreateRoom>,
 890    session: Session,
 891) -> Result<()> {
 892    let live_kit_room = nanoid::nanoid!(30);
 893
 894    let live_kit_connection_info = {
 895        let live_kit_room = live_kit_room.clone();
 896        let live_kit = session.live_kit_client.as_ref();
 897
 898        util::async_iife!({
 899            let live_kit = live_kit?;
 900
 901            live_kit
 902                .create_room(live_kit_room.clone())
 903                .await
 904                .trace_err()?;
 905
 906            let token = live_kit
 907                .room_token(&live_kit_room, &session.user_id.to_string())
 908                .trace_err()?;
 909
 910            Some(proto::LiveKitConnectionInfo {
 911                server_url: live_kit.url().into(),
 912                token,
 913            })
 914        })
 915    }
 916    .await;
 917
 918    let room = session
 919        .db()
 920        .await
 921        .create_room(session.user_id, session.connection_id, &live_kit_room)
 922        .await?;
 923
 924    response.send(proto::CreateRoomResponse {
 925        room: Some(room.clone()),
 926        live_kit_connection_info,
 927    })?;
 928
 929    update_user_contacts(session.user_id, &session).await?;
 930    Ok(())
 931}
 932
 933async fn join_room(
 934    request: proto::JoinRoom,
 935    response: Response<proto::JoinRoom>,
 936    session: Session,
 937) -> Result<()> {
 938    let room_id = RoomId::from_proto(request.id);
 939    let joined_room = {
 940        let room = session
 941            .db()
 942            .await
 943            .join_room(room_id, session.user_id, session.connection_id)
 944            .await?;
 945        room_updated(&room.room, &session.peer);
 946        room.into_inner()
 947    };
 948
 949    if let Some(channel_id) = joined_room.channel_id {
 950        channel_updated(
 951            channel_id,
 952            &joined_room.room,
 953            &joined_room.channel_members,
 954            &session.peer,
 955            &*session.connection_pool().await,
 956        )
 957    }
 958
 959    for connection_id in session
 960        .connection_pool()
 961        .await
 962        .user_connection_ids(session.user_id)
 963    {
 964        session
 965            .peer
 966            .send(
 967                connection_id,
 968                proto::CallCanceled {
 969                    room_id: room_id.to_proto(),
 970                },
 971            )
 972            .trace_err();
 973    }
 974
 975    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 976        if let Some(token) = live_kit
 977            .room_token(
 978                &joined_room.room.live_kit_room,
 979                &session.user_id.to_string(),
 980            )
 981            .trace_err()
 982        {
 983            Some(proto::LiveKitConnectionInfo {
 984                server_url: live_kit.url().into(),
 985                token,
 986            })
 987        } else {
 988            None
 989        }
 990    } else {
 991        None
 992    };
 993
 994    response.send(proto::JoinRoomResponse {
 995        room: Some(joined_room.room),
 996        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
 997        live_kit_connection_info,
 998    })?;
 999
1000    update_user_contacts(session.user_id, &session).await?;
1001    Ok(())
1002}
1003
1004async fn rejoin_room(
1005    request: proto::RejoinRoom,
1006    response: Response<proto::RejoinRoom>,
1007    session: Session,
1008) -> Result<()> {
1009    let room;
1010    let channel_id;
1011    let channel_members;
1012    {
1013        let mut rejoined_room = session
1014            .db()
1015            .await
1016            .rejoin_room(request, session.user_id, session.connection_id)
1017            .await?;
1018
1019        response.send(proto::RejoinRoomResponse {
1020            room: Some(rejoined_room.room.clone()),
1021            reshared_projects: rejoined_room
1022                .reshared_projects
1023                .iter()
1024                .map(|project| proto::ResharedProject {
1025                    id: project.id.to_proto(),
1026                    collaborators: project
1027                        .collaborators
1028                        .iter()
1029                        .map(|collaborator| collaborator.to_proto())
1030                        .collect(),
1031                })
1032                .collect(),
1033            rejoined_projects: rejoined_room
1034                .rejoined_projects
1035                .iter()
1036                .map(|rejoined_project| proto::RejoinedProject {
1037                    id: rejoined_project.id.to_proto(),
1038                    worktrees: rejoined_project
1039                        .worktrees
1040                        .iter()
1041                        .map(|worktree| proto::WorktreeMetadata {
1042                            id: worktree.id,
1043                            root_name: worktree.root_name.clone(),
1044                            visible: worktree.visible,
1045                            abs_path: worktree.abs_path.clone(),
1046                        })
1047                        .collect(),
1048                    collaborators: rejoined_project
1049                        .collaborators
1050                        .iter()
1051                        .map(|collaborator| collaborator.to_proto())
1052                        .collect(),
1053                    language_servers: rejoined_project.language_servers.clone(),
1054                })
1055                .collect(),
1056        })?;
1057        room_updated(&rejoined_room.room, &session.peer);
1058
1059        for project in &rejoined_room.reshared_projects {
1060            for collaborator in &project.collaborators {
1061                session
1062                    .peer
1063                    .send(
1064                        collaborator.connection_id,
1065                        proto::UpdateProjectCollaborator {
1066                            project_id: project.id.to_proto(),
1067                            old_peer_id: Some(project.old_connection_id.into()),
1068                            new_peer_id: Some(session.connection_id.into()),
1069                        },
1070                    )
1071                    .trace_err();
1072            }
1073
1074            broadcast(
1075                Some(session.connection_id),
1076                project
1077                    .collaborators
1078                    .iter()
1079                    .map(|collaborator| collaborator.connection_id),
1080                |connection_id| {
1081                    session.peer.forward_send(
1082                        session.connection_id,
1083                        connection_id,
1084                        proto::UpdateProject {
1085                            project_id: project.id.to_proto(),
1086                            worktrees: project.worktrees.clone(),
1087                        },
1088                    )
1089                },
1090            );
1091        }
1092
1093        for project in &rejoined_room.rejoined_projects {
1094            for collaborator in &project.collaborators {
1095                session
1096                    .peer
1097                    .send(
1098                        collaborator.connection_id,
1099                        proto::UpdateProjectCollaborator {
1100                            project_id: project.id.to_proto(),
1101                            old_peer_id: Some(project.old_connection_id.into()),
1102                            new_peer_id: Some(session.connection_id.into()),
1103                        },
1104                    )
1105                    .trace_err();
1106            }
1107        }
1108
1109        for project in &mut rejoined_room.rejoined_projects {
1110            for worktree in mem::take(&mut project.worktrees) {
1111                #[cfg(any(test, feature = "test-support"))]
1112                const MAX_CHUNK_SIZE: usize = 2;
1113                #[cfg(not(any(test, feature = "test-support")))]
1114                const MAX_CHUNK_SIZE: usize = 256;
1115
1116                // Stream this worktree's entries.
1117                let message = proto::UpdateWorktree {
1118                    project_id: project.id.to_proto(),
1119                    worktree_id: worktree.id,
1120                    abs_path: worktree.abs_path.clone(),
1121                    root_name: worktree.root_name,
1122                    updated_entries: worktree.updated_entries,
1123                    removed_entries: worktree.removed_entries,
1124                    scan_id: worktree.scan_id,
1125                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1126                    updated_repositories: worktree.updated_repositories,
1127                    removed_repositories: worktree.removed_repositories,
1128                };
1129                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1130                    session.peer.send(session.connection_id, update.clone())?;
1131                }
1132
1133                // Stream this worktree's diagnostics.
1134                for summary in worktree.diagnostic_summaries {
1135                    session.peer.send(
1136                        session.connection_id,
1137                        proto::UpdateDiagnosticSummary {
1138                            project_id: project.id.to_proto(),
1139                            worktree_id: worktree.id,
1140                            summary: Some(summary),
1141                        },
1142                    )?;
1143                }
1144
1145                for settings_file in worktree.settings_files {
1146                    session.peer.send(
1147                        session.connection_id,
1148                        proto::UpdateWorktreeSettings {
1149                            project_id: project.id.to_proto(),
1150                            worktree_id: worktree.id,
1151                            path: settings_file.path,
1152                            content: Some(settings_file.content),
1153                        },
1154                    )?;
1155                }
1156            }
1157
1158            for language_server in &project.language_servers {
1159                session.peer.send(
1160                    session.connection_id,
1161                    proto::UpdateLanguageServer {
1162                        project_id: project.id.to_proto(),
1163                        language_server_id: language_server.id,
1164                        variant: Some(
1165                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1166                                proto::LspDiskBasedDiagnosticsUpdated {},
1167                            ),
1168                        ),
1169                    },
1170                )?;
1171            }
1172        }
1173
1174        let rejoined_room = rejoined_room.into_inner();
1175
1176        room = rejoined_room.room;
1177        channel_id = rejoined_room.channel_id;
1178        channel_members = rejoined_room.channel_members;
1179    }
1180
1181    if let Some(channel_id) = channel_id {
1182        channel_updated(
1183            channel_id,
1184            &room,
1185            &channel_members,
1186            &session.peer,
1187            &*session.connection_pool().await,
1188        );
1189    }
1190
1191    update_user_contacts(session.user_id, &session).await?;
1192    Ok(())
1193}
1194
1195async fn leave_room(
1196    _: proto::LeaveRoom,
1197    response: Response<proto::LeaveRoom>,
1198    session: Session,
1199) -> Result<()> {
1200    leave_room_for_session(&session).await?;
1201    response.send(proto::Ack {})?;
1202    Ok(())
1203}
1204
1205async fn call(
1206    request: proto::Call,
1207    response: Response<proto::Call>,
1208    session: Session,
1209) -> Result<()> {
1210    let room_id = RoomId::from_proto(request.room_id);
1211    let calling_user_id = session.user_id;
1212    let calling_connection_id = session.connection_id;
1213    let called_user_id = UserId::from_proto(request.called_user_id);
1214    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1215    if !session
1216        .db()
1217        .await
1218        .has_contact(calling_user_id, called_user_id)
1219        .await?
1220    {
1221        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1222    }
1223
1224    let incoming_call = {
1225        let (room, incoming_call) = &mut *session
1226            .db()
1227            .await
1228            .call(
1229                room_id,
1230                calling_user_id,
1231                calling_connection_id,
1232                called_user_id,
1233                initial_project_id,
1234            )
1235            .await?;
1236        room_updated(&room, &session.peer);
1237        mem::take(incoming_call)
1238    };
1239    update_user_contacts(called_user_id, &session).await?;
1240
1241    let mut calls = session
1242        .connection_pool()
1243        .await
1244        .user_connection_ids(called_user_id)
1245        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1246        .collect::<FuturesUnordered<_>>();
1247
1248    while let Some(call_response) = calls.next().await {
1249        match call_response.as_ref() {
1250            Ok(_) => {
1251                response.send(proto::Ack {})?;
1252                return Ok(());
1253            }
1254            Err(_) => {
1255                call_response.trace_err();
1256            }
1257        }
1258    }
1259
1260    {
1261        let room = session
1262            .db()
1263            .await
1264            .call_failed(room_id, called_user_id)
1265            .await?;
1266        room_updated(&room, &session.peer);
1267    }
1268    update_user_contacts(called_user_id, &session).await?;
1269
1270    Err(anyhow!("failed to ring user"))?
1271}
1272
1273async fn cancel_call(
1274    request: proto::CancelCall,
1275    response: Response<proto::CancelCall>,
1276    session: Session,
1277) -> Result<()> {
1278    let called_user_id = UserId::from_proto(request.called_user_id);
1279    let room_id = RoomId::from_proto(request.room_id);
1280    {
1281        let room = session
1282            .db()
1283            .await
1284            .cancel_call(room_id, session.connection_id, called_user_id)
1285            .await?;
1286        room_updated(&room, &session.peer);
1287    }
1288
1289    for connection_id in session
1290        .connection_pool()
1291        .await
1292        .user_connection_ids(called_user_id)
1293    {
1294        session
1295            .peer
1296            .send(
1297                connection_id,
1298                proto::CallCanceled {
1299                    room_id: room_id.to_proto(),
1300                },
1301            )
1302            .trace_err();
1303    }
1304    response.send(proto::Ack {})?;
1305
1306    update_user_contacts(called_user_id, &session).await?;
1307    Ok(())
1308}
1309
1310async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1311    let room_id = RoomId::from_proto(message.room_id);
1312    {
1313        let room = session
1314            .db()
1315            .await
1316            .decline_call(Some(room_id), session.user_id)
1317            .await?
1318            .ok_or_else(|| anyhow!("failed to decline call"))?;
1319        room_updated(&room, &session.peer);
1320    }
1321
1322    for connection_id in session
1323        .connection_pool()
1324        .await
1325        .user_connection_ids(session.user_id)
1326    {
1327        session
1328            .peer
1329            .send(
1330                connection_id,
1331                proto::CallCanceled {
1332                    room_id: room_id.to_proto(),
1333                },
1334            )
1335            .trace_err();
1336    }
1337    update_user_contacts(session.user_id, &session).await?;
1338    Ok(())
1339}
1340
1341async fn update_participant_location(
1342    request: proto::UpdateParticipantLocation,
1343    response: Response<proto::UpdateParticipantLocation>,
1344    session: Session,
1345) -> Result<()> {
1346    let room_id = RoomId::from_proto(request.room_id);
1347    let location = request
1348        .location
1349        .ok_or_else(|| anyhow!("invalid location"))?;
1350
1351    let db = session.db().await;
1352    let room = db
1353        .update_room_participant_location(room_id, session.connection_id, location)
1354        .await?;
1355
1356    room_updated(&room, &session.peer);
1357    response.send(proto::Ack {})?;
1358    Ok(())
1359}
1360
1361async fn share_project(
1362    request: proto::ShareProject,
1363    response: Response<proto::ShareProject>,
1364    session: Session,
1365) -> Result<()> {
1366    let (project_id, room) = &*session
1367        .db()
1368        .await
1369        .share_project(
1370            RoomId::from_proto(request.room_id),
1371            session.connection_id,
1372            &request.worktrees,
1373        )
1374        .await?;
1375    response.send(proto::ShareProjectResponse {
1376        project_id: project_id.to_proto(),
1377    })?;
1378    room_updated(&room, &session.peer);
1379
1380    Ok(())
1381}
1382
1383async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1384    let project_id = ProjectId::from_proto(message.project_id);
1385
1386    let (room, guest_connection_ids) = &*session
1387        .db()
1388        .await
1389        .unshare_project(project_id, session.connection_id)
1390        .await?;
1391
1392    broadcast(
1393        Some(session.connection_id),
1394        guest_connection_ids.iter().copied(),
1395        |conn_id| session.peer.send(conn_id, message.clone()),
1396    );
1397    room_updated(&room, &session.peer);
1398
1399    Ok(())
1400}
1401
1402async fn join_project(
1403    request: proto::JoinProject,
1404    response: Response<proto::JoinProject>,
1405    session: Session,
1406) -> Result<()> {
1407    let project_id = ProjectId::from_proto(request.project_id);
1408    let guest_user_id = session.user_id;
1409
1410    tracing::info!(%project_id, "join project");
1411
1412    let (project, replica_id) = &mut *session
1413        .db()
1414        .await
1415        .join_project(project_id, session.connection_id)
1416        .await?;
1417
1418    let collaborators = project
1419        .collaborators
1420        .iter()
1421        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1422        .map(|collaborator| collaborator.to_proto())
1423        .collect::<Vec<_>>();
1424
1425    let worktrees = project
1426        .worktrees
1427        .iter()
1428        .map(|(id, worktree)| proto::WorktreeMetadata {
1429            id: *id,
1430            root_name: worktree.root_name.clone(),
1431            visible: worktree.visible,
1432            abs_path: worktree.abs_path.clone(),
1433        })
1434        .collect::<Vec<_>>();
1435
1436    for collaborator in &collaborators {
1437        session
1438            .peer
1439            .send(
1440                collaborator.peer_id.unwrap().into(),
1441                proto::AddProjectCollaborator {
1442                    project_id: project_id.to_proto(),
1443                    collaborator: Some(proto::Collaborator {
1444                        peer_id: Some(session.connection_id.into()),
1445                        replica_id: replica_id.0 as u32,
1446                        user_id: guest_user_id.to_proto(),
1447                    }),
1448                },
1449            )
1450            .trace_err();
1451    }
1452
1453    // First, we send the metadata associated with each worktree.
1454    response.send(proto::JoinProjectResponse {
1455        worktrees: worktrees.clone(),
1456        replica_id: replica_id.0 as u32,
1457        collaborators: collaborators.clone(),
1458        language_servers: project.language_servers.clone(),
1459    })?;
1460
1461    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1462        #[cfg(any(test, feature = "test-support"))]
1463        const MAX_CHUNK_SIZE: usize = 2;
1464        #[cfg(not(any(test, feature = "test-support")))]
1465        const MAX_CHUNK_SIZE: usize = 256;
1466
1467        // Stream this worktree's entries.
1468        let message = proto::UpdateWorktree {
1469            project_id: project_id.to_proto(),
1470            worktree_id,
1471            abs_path: worktree.abs_path.clone(),
1472            root_name: worktree.root_name,
1473            updated_entries: worktree.entries,
1474            removed_entries: Default::default(),
1475            scan_id: worktree.scan_id,
1476            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1477            updated_repositories: worktree.repository_entries.into_values().collect(),
1478            removed_repositories: Default::default(),
1479        };
1480        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1481            session.peer.send(session.connection_id, update.clone())?;
1482        }
1483
1484        // Stream this worktree's diagnostics.
1485        for summary in worktree.diagnostic_summaries {
1486            session.peer.send(
1487                session.connection_id,
1488                proto::UpdateDiagnosticSummary {
1489                    project_id: project_id.to_proto(),
1490                    worktree_id: worktree.id,
1491                    summary: Some(summary),
1492                },
1493            )?;
1494        }
1495
1496        for settings_file in worktree.settings_files {
1497            session.peer.send(
1498                session.connection_id,
1499                proto::UpdateWorktreeSettings {
1500                    project_id: project_id.to_proto(),
1501                    worktree_id: worktree.id,
1502                    path: settings_file.path,
1503                    content: Some(settings_file.content),
1504                },
1505            )?;
1506        }
1507    }
1508
1509    for language_server in &project.language_servers {
1510        session.peer.send(
1511            session.connection_id,
1512            proto::UpdateLanguageServer {
1513                project_id: project_id.to_proto(),
1514                language_server_id: language_server.id,
1515                variant: Some(
1516                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1517                        proto::LspDiskBasedDiagnosticsUpdated {},
1518                    ),
1519                ),
1520            },
1521        )?;
1522    }
1523
1524    Ok(())
1525}
1526
1527async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1528    let sender_id = session.connection_id;
1529    let project_id = ProjectId::from_proto(request.project_id);
1530
1531    let (room, project) = &*session
1532        .db()
1533        .await
1534        .leave_project(project_id, sender_id)
1535        .await?;
1536    tracing::info!(
1537        %project_id,
1538        host_user_id = %project.host_user_id,
1539        host_connection_id = %project.host_connection_id,
1540        "leave project"
1541    );
1542
1543    project_left(&project, &session);
1544    room_updated(&room, &session.peer);
1545
1546    Ok(())
1547}
1548
1549async fn update_project(
1550    request: proto::UpdateProject,
1551    response: Response<proto::UpdateProject>,
1552    session: Session,
1553) -> Result<()> {
1554    let project_id = ProjectId::from_proto(request.project_id);
1555    let (room, guest_connection_ids) = &*session
1556        .db()
1557        .await
1558        .update_project(project_id, session.connection_id, &request.worktrees)
1559        .await?;
1560    broadcast(
1561        Some(session.connection_id),
1562        guest_connection_ids.iter().copied(),
1563        |connection_id| {
1564            session
1565                .peer
1566                .forward_send(session.connection_id, connection_id, request.clone())
1567        },
1568    );
1569    room_updated(&room, &session.peer);
1570    response.send(proto::Ack {})?;
1571
1572    Ok(())
1573}
1574
1575async fn update_worktree(
1576    request: proto::UpdateWorktree,
1577    response: Response<proto::UpdateWorktree>,
1578    session: Session,
1579) -> Result<()> {
1580    let guest_connection_ids = session
1581        .db()
1582        .await
1583        .update_worktree(&request, session.connection_id)
1584        .await?;
1585
1586    broadcast(
1587        Some(session.connection_id),
1588        guest_connection_ids.iter().copied(),
1589        |connection_id| {
1590            session
1591                .peer
1592                .forward_send(session.connection_id, connection_id, request.clone())
1593        },
1594    );
1595    response.send(proto::Ack {})?;
1596    Ok(())
1597}
1598
1599async fn update_diagnostic_summary(
1600    message: proto::UpdateDiagnosticSummary,
1601    session: Session,
1602) -> Result<()> {
1603    let guest_connection_ids = session
1604        .db()
1605        .await
1606        .update_diagnostic_summary(&message, session.connection_id)
1607        .await?;
1608
1609    broadcast(
1610        Some(session.connection_id),
1611        guest_connection_ids.iter().copied(),
1612        |connection_id| {
1613            session
1614                .peer
1615                .forward_send(session.connection_id, connection_id, message.clone())
1616        },
1617    );
1618
1619    Ok(())
1620}
1621
1622async fn update_worktree_settings(
1623    message: proto::UpdateWorktreeSettings,
1624    session: Session,
1625) -> Result<()> {
1626    let guest_connection_ids = session
1627        .db()
1628        .await
1629        .update_worktree_settings(&message, session.connection_id)
1630        .await?;
1631
1632    broadcast(
1633        Some(session.connection_id),
1634        guest_connection_ids.iter().copied(),
1635        |connection_id| {
1636            session
1637                .peer
1638                .forward_send(session.connection_id, connection_id, message.clone())
1639        },
1640    );
1641
1642    Ok(())
1643}
1644
1645async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1646    broadcast_project_message(request.project_id, request, session).await
1647}
1648
1649async fn start_language_server(
1650    request: proto::StartLanguageServer,
1651    session: Session,
1652) -> Result<()> {
1653    let guest_connection_ids = session
1654        .db()
1655        .await
1656        .start_language_server(&request, session.connection_id)
1657        .await?;
1658
1659    broadcast(
1660        Some(session.connection_id),
1661        guest_connection_ids.iter().copied(),
1662        |connection_id| {
1663            session
1664                .peer
1665                .forward_send(session.connection_id, connection_id, request.clone())
1666        },
1667    );
1668    Ok(())
1669}
1670
1671async fn update_language_server(
1672    request: proto::UpdateLanguageServer,
1673    session: Session,
1674) -> Result<()> {
1675    session.executor.record_backtrace();
1676    let project_id = ProjectId::from_proto(request.project_id);
1677    let project_connection_ids = session
1678        .db()
1679        .await
1680        .project_connection_ids(project_id, session.connection_id)
1681        .await?;
1682    broadcast(
1683        Some(session.connection_id),
1684        project_connection_ids.iter().copied(),
1685        |connection_id| {
1686            session
1687                .peer
1688                .forward_send(session.connection_id, connection_id, request.clone())
1689        },
1690    );
1691    Ok(())
1692}
1693
1694async fn forward_project_request<T>(
1695    request: T,
1696    response: Response<T>,
1697    session: Session,
1698) -> Result<()>
1699where
1700    T: EntityMessage + RequestMessage,
1701{
1702    session.executor.record_backtrace();
1703    let project_id = ProjectId::from_proto(request.remote_entity_id());
1704    let host_connection_id = {
1705        let collaborators = session
1706            .db()
1707            .await
1708            .project_collaborators(project_id, session.connection_id)
1709            .await?;
1710        collaborators
1711            .iter()
1712            .find(|collaborator| collaborator.is_host)
1713            .ok_or_else(|| anyhow!("host not found"))?
1714            .connection_id
1715    };
1716
1717    let payload = session
1718        .peer
1719        .forward_request(session.connection_id, host_connection_id, request)
1720        .await?;
1721
1722    response.send(payload)?;
1723    Ok(())
1724}
1725
1726async fn create_buffer_for_peer(
1727    request: proto::CreateBufferForPeer,
1728    session: Session,
1729) -> Result<()> {
1730    session.executor.record_backtrace();
1731    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1732    session
1733        .peer
1734        .forward_send(session.connection_id, peer_id.into(), request)?;
1735    Ok(())
1736}
1737
1738async fn update_buffer(
1739    request: proto::UpdateBuffer,
1740    response: Response<proto::UpdateBuffer>,
1741    session: Session,
1742) -> Result<()> {
1743    session.executor.record_backtrace();
1744    let project_id = ProjectId::from_proto(request.project_id);
1745    let mut guest_connection_ids;
1746    let mut host_connection_id = None;
1747    {
1748        let collaborators = session
1749            .db()
1750            .await
1751            .project_collaborators(project_id, session.connection_id)
1752            .await?;
1753        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1754        for collaborator in collaborators.iter() {
1755            if collaborator.is_host {
1756                host_connection_id = Some(collaborator.connection_id);
1757            } else {
1758                guest_connection_ids.push(collaborator.connection_id);
1759            }
1760        }
1761    }
1762    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1763
1764    session.executor.record_backtrace();
1765    broadcast(
1766        Some(session.connection_id),
1767        guest_connection_ids,
1768        |connection_id| {
1769            session
1770                .peer
1771                .forward_send(session.connection_id, connection_id, request.clone())
1772        },
1773    );
1774    if host_connection_id != session.connection_id {
1775        session
1776            .peer
1777            .forward_request(session.connection_id, host_connection_id, request.clone())
1778            .await?;
1779    }
1780
1781    response.send(proto::Ack {})?;
1782    Ok(())
1783}
1784
1785async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1786    let project_id = ProjectId::from_proto(request.project_id);
1787    let project_connection_ids = session
1788        .db()
1789        .await
1790        .project_connection_ids(project_id, session.connection_id)
1791        .await?;
1792
1793    broadcast(
1794        Some(session.connection_id),
1795        project_connection_ids.iter().copied(),
1796        |connection_id| {
1797            session
1798                .peer
1799                .forward_send(session.connection_id, connection_id, request.clone())
1800        },
1801    );
1802    Ok(())
1803}
1804
1805async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1806    let project_id = ProjectId::from_proto(request.project_id);
1807    let project_connection_ids = session
1808        .db()
1809        .await
1810        .project_connection_ids(project_id, session.connection_id)
1811        .await?;
1812    broadcast(
1813        Some(session.connection_id),
1814        project_connection_ids.iter().copied(),
1815        |connection_id| {
1816            session
1817                .peer
1818                .forward_send(session.connection_id, connection_id, request.clone())
1819        },
1820    );
1821    Ok(())
1822}
1823
1824async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1825    broadcast_project_message(request.project_id, request, session).await
1826}
1827
1828async fn broadcast_project_message<T: EnvelopedMessage>(
1829    project_id: u64,
1830    request: T,
1831    session: Session,
1832) -> Result<()> {
1833    let project_id = ProjectId::from_proto(project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839    broadcast(
1840        Some(session.connection_id),
1841        project_connection_ids.iter().copied(),
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    Ok(())
1849}
1850
1851async fn follow(
1852    request: proto::Follow,
1853    response: Response<proto::Follow>,
1854    session: Session,
1855) -> Result<()> {
1856    let project_id = ProjectId::from_proto(request.project_id);
1857    let leader_id = request
1858        .leader_id
1859        .ok_or_else(|| anyhow!("invalid leader id"))?
1860        .into();
1861    let follower_id = session.connection_id;
1862
1863    {
1864        let project_connection_ids = session
1865            .db()
1866            .await
1867            .project_connection_ids(project_id, session.connection_id)
1868            .await?;
1869
1870        if !project_connection_ids.contains(&leader_id) {
1871            Err(anyhow!("no such peer"))?;
1872        }
1873    }
1874
1875    let mut response_payload = session
1876        .peer
1877        .forward_request(session.connection_id, leader_id, request)
1878        .await?;
1879    response_payload
1880        .views
1881        .retain(|view| view.leader_id != Some(follower_id.into()));
1882    response.send(response_payload)?;
1883
1884    let room = session
1885        .db()
1886        .await
1887        .follow(project_id, leader_id, follower_id)
1888        .await?;
1889    room_updated(&room, &session.peer);
1890
1891    Ok(())
1892}
1893
1894async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1895    let project_id = ProjectId::from_proto(request.project_id);
1896    let leader_id = request
1897        .leader_id
1898        .ok_or_else(|| anyhow!("invalid leader id"))?
1899        .into();
1900    let follower_id = session.connection_id;
1901
1902    if !session
1903        .db()
1904        .await
1905        .project_connection_ids(project_id, session.connection_id)
1906        .await?
1907        .contains(&leader_id)
1908    {
1909        Err(anyhow!("no such peer"))?;
1910    }
1911
1912    session
1913        .peer
1914        .forward_send(session.connection_id, leader_id, request)?;
1915
1916    let room = session
1917        .db()
1918        .await
1919        .unfollow(project_id, leader_id, follower_id)
1920        .await?;
1921    room_updated(&room, &session.peer);
1922
1923    Ok(())
1924}
1925
1926async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1927    let project_id = ProjectId::from_proto(request.project_id);
1928    let project_connection_ids = session
1929        .db
1930        .lock()
1931        .await
1932        .project_connection_ids(project_id, session.connection_id)
1933        .await?;
1934
1935    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1936        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1937        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1938        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1939    });
1940    for follower_peer_id in request.follower_ids.iter().copied() {
1941        let follower_connection_id = follower_peer_id.into();
1942        if project_connection_ids.contains(&follower_connection_id)
1943            && Some(follower_peer_id) != leader_id
1944        {
1945            session.peer.forward_send(
1946                session.connection_id,
1947                follower_connection_id,
1948                request.clone(),
1949            )?;
1950        }
1951    }
1952    Ok(())
1953}
1954
1955async fn get_users(
1956    request: proto::GetUsers,
1957    response: Response<proto::GetUsers>,
1958    session: Session,
1959) -> Result<()> {
1960    let user_ids = request
1961        .user_ids
1962        .into_iter()
1963        .map(UserId::from_proto)
1964        .collect();
1965    let users = session
1966        .db()
1967        .await
1968        .get_users_by_ids(user_ids)
1969        .await?
1970        .into_iter()
1971        .map(|user| proto::User {
1972            id: user.id.to_proto(),
1973            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1974            github_login: user.github_login,
1975        })
1976        .collect();
1977    response.send(proto::UsersResponse { users })?;
1978    Ok(())
1979}
1980
1981async fn fuzzy_search_users(
1982    request: proto::FuzzySearchUsers,
1983    response: Response<proto::FuzzySearchUsers>,
1984    session: Session,
1985) -> Result<()> {
1986    let query = request.query;
1987    let users = match query.len() {
1988        0 => vec![],
1989        1 | 2 => session
1990            .db()
1991            .await
1992            .get_user_by_github_login(&query)
1993            .await?
1994            .into_iter()
1995            .collect(),
1996        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1997    };
1998    let users = users
1999        .into_iter()
2000        .filter(|user| user.id != session.user_id)
2001        .map(|user| proto::User {
2002            id: user.id.to_proto(),
2003            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2004            github_login: user.github_login,
2005        })
2006        .collect();
2007    response.send(proto::UsersResponse { users })?;
2008    Ok(())
2009}
2010
2011async fn request_contact(
2012    request: proto::RequestContact,
2013    response: Response<proto::RequestContact>,
2014    session: Session,
2015) -> Result<()> {
2016    let requester_id = session.user_id;
2017    let responder_id = UserId::from_proto(request.responder_id);
2018    if requester_id == responder_id {
2019        return Err(anyhow!("cannot add yourself as a contact"))?;
2020    }
2021
2022    session
2023        .db()
2024        .await
2025        .send_contact_request(requester_id, responder_id)
2026        .await?;
2027
2028    // Update outgoing contact requests of requester
2029    let mut update = proto::UpdateContacts::default();
2030    update.outgoing_requests.push(responder_id.to_proto());
2031    for connection_id in session
2032        .connection_pool()
2033        .await
2034        .user_connection_ids(requester_id)
2035    {
2036        session.peer.send(connection_id, update.clone())?;
2037    }
2038
2039    // Update incoming contact requests of responder
2040    let mut update = proto::UpdateContacts::default();
2041    update
2042        .incoming_requests
2043        .push(proto::IncomingContactRequest {
2044            requester_id: requester_id.to_proto(),
2045            should_notify: true,
2046        });
2047    for connection_id in session
2048        .connection_pool()
2049        .await
2050        .user_connection_ids(responder_id)
2051    {
2052        session.peer.send(connection_id, update.clone())?;
2053    }
2054
2055    response.send(proto::Ack {})?;
2056    Ok(())
2057}
2058
2059async fn respond_to_contact_request(
2060    request: proto::RespondToContactRequest,
2061    response: Response<proto::RespondToContactRequest>,
2062    session: Session,
2063) -> Result<()> {
2064    let responder_id = session.user_id;
2065    let requester_id = UserId::from_proto(request.requester_id);
2066    let db = session.db().await;
2067    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2068        db.dismiss_contact_notification(responder_id, requester_id)
2069            .await?;
2070    } else {
2071        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2072
2073        db.respond_to_contact_request(responder_id, requester_id, accept)
2074            .await?;
2075        let requester_busy = db.is_user_busy(requester_id).await?;
2076        let responder_busy = db.is_user_busy(responder_id).await?;
2077
2078        let pool = session.connection_pool().await;
2079        // Update responder with new contact
2080        let mut update = proto::UpdateContacts::default();
2081        if accept {
2082            update
2083                .contacts
2084                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2085        }
2086        update
2087            .remove_incoming_requests
2088            .push(requester_id.to_proto());
2089        for connection_id in pool.user_connection_ids(responder_id) {
2090            session.peer.send(connection_id, update.clone())?;
2091        }
2092
2093        // Update requester with new contact
2094        let mut update = proto::UpdateContacts::default();
2095        if accept {
2096            update
2097                .contacts
2098                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2099        }
2100        update
2101            .remove_outgoing_requests
2102            .push(responder_id.to_proto());
2103        for connection_id in pool.user_connection_ids(requester_id) {
2104            session.peer.send(connection_id, update.clone())?;
2105        }
2106    }
2107
2108    response.send(proto::Ack {})?;
2109    Ok(())
2110}
2111
2112async fn remove_contact(
2113    request: proto::RemoveContact,
2114    response: Response<proto::RemoveContact>,
2115    session: Session,
2116) -> Result<()> {
2117    let requester_id = session.user_id;
2118    let responder_id = UserId::from_proto(request.user_id);
2119    let db = session.db().await;
2120    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2121
2122    let pool = session.connection_pool().await;
2123    // Update outgoing contact requests of requester
2124    let mut update = proto::UpdateContacts::default();
2125    if contact_accepted {
2126        update.remove_contacts.push(responder_id.to_proto());
2127    } else {
2128        update
2129            .remove_outgoing_requests
2130            .push(responder_id.to_proto());
2131    }
2132    for connection_id in pool.user_connection_ids(requester_id) {
2133        session.peer.send(connection_id, update.clone())?;
2134    }
2135
2136    // Update incoming contact requests of responder
2137    let mut update = proto::UpdateContacts::default();
2138    if contact_accepted {
2139        update.remove_contacts.push(requester_id.to_proto());
2140    } else {
2141        update
2142            .remove_incoming_requests
2143            .push(requester_id.to_proto());
2144    }
2145    for connection_id in pool.user_connection_ids(responder_id) {
2146        session.peer.send(connection_id, update.clone())?;
2147    }
2148
2149    response.send(proto::Ack {})?;
2150    Ok(())
2151}
2152
2153async fn create_channel(
2154    request: proto::CreateChannel,
2155    response: Response<proto::CreateChannel>,
2156    session: Session,
2157) -> Result<()> {
2158    let db = session.db().await;
2159    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2160
2161    if let Some(live_kit) = session.live_kit_client.as_ref() {
2162        live_kit.create_room(live_kit_room.clone()).await?;
2163    }
2164
2165    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2166    let id = db
2167        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2168        .await?;
2169
2170    let channel = proto::Channel {
2171        id: id.to_proto(),
2172        name: request.name,
2173        parent_id: request.parent_id,
2174    };
2175
2176    response.send(proto::ChannelResponse {
2177        channel: Some(channel.clone()),
2178    })?;
2179
2180    let mut update = proto::UpdateChannels::default();
2181    update.channels.push(channel);
2182
2183    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2184        db.get_channel_members(parent_id).await?
2185    } else {
2186        vec![session.user_id]
2187    };
2188
2189    let connection_pool = session.connection_pool().await;
2190    for user_id in user_ids_to_notify {
2191        for connection_id in connection_pool.user_connection_ids(user_id) {
2192            let mut update = update.clone();
2193            if user_id == session.user_id {
2194                update.channel_permissions.push(proto::ChannelPermission {
2195                    channel_id: id.to_proto(),
2196                    is_admin: true,
2197                });
2198            }
2199            session.peer.send(connection_id, update)?;
2200        }
2201    }
2202
2203    Ok(())
2204}
2205
2206async fn remove_channel(
2207    request: proto::RemoveChannel,
2208    response: Response<proto::RemoveChannel>,
2209    session: Session,
2210) -> Result<()> {
2211    let db = session.db().await;
2212
2213    let channel_id = request.channel_id;
2214    let (removed_channels, member_ids) = db
2215        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2216        .await?;
2217    response.send(proto::Ack {})?;
2218
2219    // Notify members of removed channels
2220    let mut update = proto::UpdateChannels::default();
2221    update
2222        .remove_channels
2223        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2224
2225    let connection_pool = session.connection_pool().await;
2226    for member_id in member_ids {
2227        for connection_id in connection_pool.user_connection_ids(member_id) {
2228            session.peer.send(connection_id, update.clone())?;
2229        }
2230    }
2231
2232    Ok(())
2233}
2234
2235async fn invite_channel_member(
2236    request: proto::InviteChannelMember,
2237    response: Response<proto::InviteChannelMember>,
2238    session: Session,
2239) -> Result<()> {
2240    let db = session.db().await;
2241    let channel_id = ChannelId::from_proto(request.channel_id);
2242    let invitee_id = UserId::from_proto(request.user_id);
2243    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2244        .await?;
2245
2246    let (channel, _) = db
2247        .get_channel(channel_id, session.user_id)
2248        .await?
2249        .ok_or_else(|| anyhow!("channel not found"))?;
2250
2251    let mut update = proto::UpdateChannels::default();
2252    update.channel_invitations.push(proto::Channel {
2253        id: channel.id.to_proto(),
2254        name: channel.name,
2255        parent_id: None,
2256    });
2257    for connection_id in session
2258        .connection_pool()
2259        .await
2260        .user_connection_ids(invitee_id)
2261    {
2262        session.peer.send(connection_id, update.clone())?;
2263    }
2264
2265    response.send(proto::Ack {})?;
2266    Ok(())
2267}
2268
2269async fn remove_channel_member(
2270    request: proto::RemoveChannelMember,
2271    response: Response<proto::RemoveChannelMember>,
2272    session: Session,
2273) -> Result<()> {
2274    let db = session.db().await;
2275    let channel_id = ChannelId::from_proto(request.channel_id);
2276    let member_id = UserId::from_proto(request.user_id);
2277
2278    db.remove_channel_member(channel_id, member_id, session.user_id)
2279        .await?;
2280
2281    let mut update = proto::UpdateChannels::default();
2282    update.remove_channels.push(channel_id.to_proto());
2283
2284    for connection_id in session
2285        .connection_pool()
2286        .await
2287        .user_connection_ids(member_id)
2288    {
2289        session.peer.send(connection_id, update.clone())?;
2290    }
2291
2292    response.send(proto::Ack {})?;
2293    Ok(())
2294}
2295
2296async fn set_channel_member_admin(
2297    request: proto::SetChannelMemberAdmin,
2298    response: Response<proto::SetChannelMemberAdmin>,
2299    session: Session,
2300) -> Result<()> {
2301    let db = session.db().await;
2302    let channel_id = ChannelId::from_proto(request.channel_id);
2303    let member_id = UserId::from_proto(request.user_id);
2304    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2305        .await?;
2306
2307    let (channel, has_accepted) = db
2308        .get_channel(channel_id, member_id)
2309        .await?
2310        .ok_or_else(|| anyhow!("channel not found"))?;
2311
2312    let mut update = proto::UpdateChannels::default();
2313    if has_accepted {
2314        update.channel_permissions.push(proto::ChannelPermission {
2315            channel_id: channel.id.to_proto(),
2316            is_admin: request.admin,
2317        });
2318    }
2319
2320    for connection_id in session
2321        .connection_pool()
2322        .await
2323        .user_connection_ids(member_id)
2324    {
2325        session.peer.send(connection_id, update.clone())?;
2326    }
2327
2328    response.send(proto::Ack {})?;
2329    Ok(())
2330}
2331
2332async fn rename_channel(
2333    request: proto::RenameChannel,
2334    response: Response<proto::RenameChannel>,
2335    session: Session,
2336) -> Result<()> {
2337    let db = session.db().await;
2338    let channel_id = ChannelId::from_proto(request.channel_id);
2339    let new_name = db
2340        .rename_channel(channel_id, session.user_id, &request.name)
2341        .await?;
2342
2343    let channel = proto::Channel {
2344        id: request.channel_id,
2345        name: new_name,
2346        parent_id: None,
2347    };
2348    response.send(proto::ChannelResponse {
2349        channel: Some(channel.clone()),
2350    })?;
2351    let mut update = proto::UpdateChannels::default();
2352    update.channels.push(channel);
2353
2354    let member_ids = db.get_channel_members(channel_id).await?;
2355
2356    let connection_pool = session.connection_pool().await;
2357    for member_id in member_ids {
2358        for connection_id in connection_pool.user_connection_ids(member_id) {
2359            session.peer.send(connection_id, update.clone())?;
2360        }
2361    }
2362
2363    Ok(())
2364}
2365
2366async fn get_channel_members(
2367    request: proto::GetChannelMembers,
2368    response: Response<proto::GetChannelMembers>,
2369    session: Session,
2370) -> Result<()> {
2371    let db = session.db().await;
2372    let channel_id = ChannelId::from_proto(request.channel_id);
2373    let members = db
2374        .get_channel_member_details(channel_id, session.user_id)
2375        .await?;
2376    response.send(proto::GetChannelMembersResponse { members })?;
2377    Ok(())
2378}
2379
2380async fn respond_to_channel_invite(
2381    request: proto::RespondToChannelInvite,
2382    response: Response<proto::RespondToChannelInvite>,
2383    session: Session,
2384) -> Result<()> {
2385    let db = session.db().await;
2386    let channel_id = ChannelId::from_proto(request.channel_id);
2387    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2388        .await?;
2389
2390    let mut update = proto::UpdateChannels::default();
2391    update
2392        .remove_channel_invitations
2393        .push(channel_id.to_proto());
2394    if request.accept {
2395        let result = db.get_channels_for_user(session.user_id).await?;
2396        update
2397            .channels
2398            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2399                id: channel.id.to_proto(),
2400                name: channel.name,
2401                parent_id: channel.parent_id.map(ChannelId::to_proto),
2402            }));
2403        update
2404            .channel_participants
2405            .extend(
2406                result
2407                    .channel_participants
2408                    .into_iter()
2409                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2410                        channel_id: channel_id.to_proto(),
2411                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2412                    }),
2413            );
2414        update
2415            .channel_permissions
2416            .extend(
2417                result
2418                    .channels_with_admin_privileges
2419                    .into_iter()
2420                    .map(|channel_id| proto::ChannelPermission {
2421                        channel_id: channel_id.to_proto(),
2422                        is_admin: true,
2423                    }),
2424            );
2425    }
2426    session.peer.send(session.connection_id, update)?;
2427    response.send(proto::Ack {})?;
2428
2429    Ok(())
2430}
2431
/// Moves the caller into the call room that belongs to a channel.
///
/// Steps, in order:
/// 1. Leaves the room this session is currently in (`leave_room_for_session`).
/// 2. Looks up the channel's room and joins it in the database.
/// 3. Replies with the joined room and its channel id. When a LiveKit client
///    is configured and a token could be minted, the reply also includes
///    LiveKit connection info.
/// 4. Broadcasts the room update, then the channel update, then refreshes
///    the caller's contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Scoped so that `db` (the database guard) is dropped at the end of the
    // block, before the channel and contact notifications below.
    let joined_room = {
        // Runs before `session.db()` is taken on the next line.
        // NOTE(review): presumably because this helper acquires the db itself — confirm.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. Also `None` when the
        // token cannot be minted: `trace_err()?` turns that error into `None`
        // (presumably logging it), so the join itself still succeeds.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Unwrap the value returned by `join_room` so it can outlive this scope.
        // NOTE(review): presumably this also releases a room lock held by the guard — confirm.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2486
2487async fn join_channel_buffer(
2488    request: proto::JoinChannelBuffer,
2489    response: Response<proto::JoinChannelBuffer>,
2490    session: Session,
2491) -> Result<()> {
2492    let db = session.db().await;
2493    let channel_id = ChannelId::from_proto(request.channel_id);
2494
2495    let open_response = db
2496        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2497        .await?;
2498
2499    let replica_id = open_response.replica_id;
2500    let collaborators = open_response.collaborators.clone();
2501
2502    response.send(open_response)?;
2503
2504    let update = AddChannelBufferCollaborator {
2505        channel_id: channel_id.to_proto(),
2506        collaborator: Some(proto::Collaborator {
2507            user_id: session.user_id.to_proto(),
2508            peer_id: Some(session.connection_id.into()),
2509            replica_id,
2510        }),
2511    };
2512    channel_buffer_updated(
2513        session.connection_id,
2514        collaborators
2515            .iter()
2516            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2517        &update,
2518        &session.peer,
2519    );
2520
2521    Ok(())
2522}
2523
2524async fn update_channel_buffer(
2525    request: proto::UpdateChannelBuffer,
2526    session: Session,
2527) -> Result<()> {
2528    let db = session.db().await;
2529    let channel_id = ChannelId::from_proto(request.channel_id);
2530
2531    let collaborators = db
2532        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2533        .await?;
2534
2535    channel_buffer_updated(
2536        session.connection_id,
2537        collaborators,
2538        &proto::UpdateChannelBuffer {
2539            channel_id: channel_id.to_proto(),
2540            operations: request.operations,
2541        },
2542        &session.peer,
2543    );
2544    Ok(())
2545}
2546
2547async fn leave_channel_buffer(
2548    request: proto::LeaveChannelBuffer,
2549    response: Response<proto::LeaveChannelBuffer>,
2550    session: Session,
2551) -> Result<()> {
2552    let db = session.db().await;
2553    let channel_id = ChannelId::from_proto(request.channel_id);
2554
2555    let collaborators_to_notify = db
2556        .leave_channel_buffer(channel_id, session.connection_id)
2557        .await?;
2558
2559    response.send(Ack {})?;
2560
2561    channel_buffer_updated(
2562        session.connection_id,
2563        collaborators_to_notify,
2564        &proto::RemoveChannelBufferCollaborator {
2565            channel_id: channel_id.to_proto(),
2566            peer_id: Some(session.connection_id.into()),
2567        },
2568        &session.peer,
2569    );
2570
2571    Ok(())
2572}
2573
2574fn channel_buffer_updated<T: EnvelopedMessage>(
2575    sender_id: ConnectionId,
2576    collaborators: impl IntoIterator<Item = ConnectionId>,
2577    message: &T,
2578    peer: &Peer,
2579) {
2580    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2581        peer.send(peer_id.into(), message.clone())
2582    });
2583}
2584
2585async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2586    let project_id = ProjectId::from_proto(request.project_id);
2587    let project_connection_ids = session
2588        .db()
2589        .await
2590        .project_connection_ids(project_id, session.connection_id)
2591        .await?;
2592    broadcast(
2593        Some(session.connection_id),
2594        project_connection_ids.iter().copied(),
2595        |connection_id| {
2596            session
2597                .peer
2598                .forward_send(session.connection_id, connection_id, request.clone())
2599        },
2600    );
2601    Ok(())
2602}
2603
2604async fn get_private_user_info(
2605    _request: proto::GetPrivateUserInfo,
2606    response: Response<proto::GetPrivateUserInfo>,
2607    session: Session,
2608) -> Result<()> {
2609    let metrics_id = session
2610        .db()
2611        .await
2612        .get_user_metrics_id(session.user_id)
2613        .await?;
2614    let user = session
2615        .db()
2616        .await
2617        .get_user_by_id(session.user_id)
2618        .await?
2619        .ok_or_else(|| anyhow!("user not found"))?;
2620    response.send(proto::GetPrivateUserInfoResponse {
2621        metrics_id,
2622        staff: user.admin,
2623    })?;
2624    Ok(())
2625}
2626
2627fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2628    match message {
2629        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2630        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2631        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2632        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2633        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2634            code: frame.code.into(),
2635            reason: frame.reason,
2636        })),
2637    }
2638}
2639
2640fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2641    match message {
2642        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2643        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2644        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2645        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2646        AxumMessage::Close(frame) => {
2647            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2648                code: frame.code.into(),
2649                reason: frame.reason,
2650            }))
2651        }
2652    }
2653}
2654
2655fn build_initial_channels_update(
2656    channels: ChannelsForUser,
2657    channel_invites: Vec<db::Channel>,
2658) -> proto::UpdateChannels {
2659    let mut update = proto::UpdateChannels::default();
2660
2661    for channel in channels.channels {
2662        update.channels.push(proto::Channel {
2663            id: channel.id.to_proto(),
2664            name: channel.name,
2665            parent_id: channel.parent_id.map(|id| id.to_proto()),
2666        });
2667    }
2668
2669    for (channel_id, participants) in channels.channel_participants {
2670        update
2671            .channel_participants
2672            .push(proto::ChannelParticipants {
2673                channel_id: channel_id.to_proto(),
2674                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2675            });
2676    }
2677
2678    update
2679        .channel_permissions
2680        .extend(
2681            channels
2682                .channels_with_admin_privileges
2683                .into_iter()
2684                .map(|id| proto::ChannelPermission {
2685                    channel_id: id.to_proto(),
2686                    is_admin: true,
2687                }),
2688        );
2689
2690    for channel in channel_invites {
2691        update.channel_invitations.push(proto::Channel {
2692            id: channel.id.to_proto(),
2693            name: channel.name,
2694            parent_id: None,
2695        });
2696    }
2697
2698    update
2699}
2700
2701fn build_initial_contacts_update(
2702    contacts: Vec<db::Contact>,
2703    pool: &ConnectionPool,
2704) -> proto::UpdateContacts {
2705    let mut update = proto::UpdateContacts::default();
2706
2707    for contact in contacts {
2708        match contact {
2709            db::Contact::Accepted {
2710                user_id,
2711                should_notify,
2712                busy,
2713            } => {
2714                update
2715                    .contacts
2716                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2717            }
2718            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2719            db::Contact::Incoming {
2720                user_id,
2721                should_notify,
2722            } => update
2723                .incoming_requests
2724                .push(proto::IncomingContactRequest {
2725                    requester_id: user_id.to_proto(),
2726                    should_notify,
2727                }),
2728        }
2729    }
2730
2731    update
2732}
2733
2734fn contact_for_user(
2735    user_id: UserId,
2736    should_notify: bool,
2737    busy: bool,
2738    pool: &ConnectionPool,
2739) -> proto::Contact {
2740    proto::Contact {
2741        user_id: user_id.to_proto(),
2742        online: pool.is_user_online(user_id),
2743        busy,
2744        should_notify,
2745    }
2746}
2747
2748fn room_updated(room: &proto::Room, peer: &Peer) {
2749    broadcast(
2750        None,
2751        room.participants
2752            .iter()
2753            .filter_map(|participant| Some(participant.peer_id?.into())),
2754        |peer_id| {
2755            peer.send(
2756                peer_id.into(),
2757                proto::RoomUpdated {
2758                    room: Some(room.clone()),
2759                },
2760            )
2761        },
2762    );
2763}
2764
2765fn channel_updated(
2766    channel_id: ChannelId,
2767    room: &proto::Room,
2768    channel_members: &[UserId],
2769    peer: &Peer,
2770    pool: &ConnectionPool,
2771) {
2772    let participants = room
2773        .participants
2774        .iter()
2775        .map(|p| p.user_id)
2776        .collect::<Vec<_>>();
2777
2778    broadcast(
2779        None,
2780        channel_members
2781            .iter()
2782            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2783        |peer_id| {
2784            peer.send(
2785                peer_id.into(),
2786                proto::UpdateChannels {
2787                    channel_participants: vec![proto::ChannelParticipants {
2788                        channel_id: channel_id.to_proto(),
2789                        participant_user_ids: participants.clone(),
2790                    }],
2791                    ..Default::default()
2792                },
2793            )
2794        },
2795    );
2796}
2797
2798async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2799    let db = session.db().await;
2800
2801    let contacts = db.get_contacts(user_id).await?;
2802    let busy = db.is_user_busy(user_id).await?;
2803
2804    let pool = session.connection_pool().await;
2805    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2806    for contact in contacts {
2807        if let db::Contact::Accepted {
2808            user_id: contact_user_id,
2809            ..
2810        } = contact
2811        {
2812            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2813                session
2814                    .peer
2815                    .send(
2816                        contact_conn_id,
2817                        proto::UpdateContacts {
2818                            contacts: vec![updated_contact.clone()],
2819                            remove_contacts: Default::default(),
2820                            incoming_requests: Default::default(),
2821                            remove_incoming_requests: Default::default(),
2822                            outgoing_requests: Default::default(),
2823                            remove_outgoing_requests: Default::default(),
2824                        },
2825                    )
2826                    .trace_err();
2827            }
2828        }
2829    }
2830    Ok(())
2831}
2832
/// Removes this session's connection from the room it is in, if any, and
/// propagates the consequences:
/// - notifies the projects it left;
/// - broadcasts room and channel participant updates;
/// - notifies users whose calls were canceled;
/// - refreshes the contact entries of affected users;
/// - cleans up the LiveKit participant/room.
///
/// Returns `Ok(())` immediately when `leave_room` reports no room for this
/// connection.
///
/// # Errors
/// Fails if the database `leave_room` call fails, or if refreshing any affected
/// user's contacts fails. Message-send and LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries must be refreshed: the leaving user, plus
    // every user whose pending call was canceled (added further down).
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialised locals. They are all assigned inside the `if let`
    // below; the `else` branch returns early, so they are definitely
    // initialised afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): under pre-2024-edition rules, the temporary produced by
    // `session.db().await` in this `if let` scrutinee lives until the end of the
    // whole `if let`/`else`. It is then still held while `project_left` and
    // `room_updated` run. Keep that in mind before restructuring this statement.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // the room itself is taken. The `room` broadcast below therefore carries
        // the default (empty) value for that field. Confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Rooms that belong to a channel also update the channel's participant list
    // for all channel members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts` below,
    // which calls `session.connection_pool()` again (presumably to avoid holding
    // or re-acquiring the same lock).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, never propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        // Only delete the LiveKit room when the database reported the room as deleted.
        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2909
2910async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2911    let left_channel_buffers = session
2912        .db()
2913        .await
2914        .leave_channel_buffers(session.connection_id)
2915        .await?;
2916
2917    for (channel_id, connections) in left_channel_buffers {
2918        channel_buffer_updated(
2919            session.connection_id,
2920            connections,
2921            &proto::RemoveChannelBufferCollaborator {
2922                channel_id: channel_id.to_proto(),
2923                peer_id: Some(session.connection_id.into()),
2924            },
2925            &session.peer,
2926        );
2927    }
2928
2929    Ok(())
2930}
2931
2932fn project_left(project: &db::LeftProject, session: &Session) {
2933    for connection_id in &project.connection_ids {
2934        if project.host_user_id == session.user_id {
2935            session
2936                .peer
2937                .send(
2938                    *connection_id,
2939                    proto::UnshareProject {
2940                        project_id: project.id.to_proto(),
2941                    },
2942                )
2943                .trace_err();
2944        } else {
2945            session
2946                .peer
2947                .send(
2948                    *connection_id,
2949                    proto::RemoveProjectCollaborator {
2950                        project_id: project.id.to_proto(),
2951                        peer_id: Some(session.connection_id.into()),
2952                    },
2953                )
2954                .trace_err();
2955        }
2956    }
2957}
2958
2959pub trait ResultExt {
2960    type Ok;
2961
2962    fn trace_err(self) -> Option<Self::Ok>;
2963}
2964
2965impl<T, E> ResultExt for Result<T, E>
2966where
2967    E: std::fmt::Debug,
2968{
2969    type Ok = T;
2970
2971    fn trace_err(self) -> Option<T> {
2972        match self {
2973            Ok(value) => Some(value),
2974            Err(error) => {
2975                tracing::error!("{:?}", error);
2976                None
2977            }
2978        }
2979    }
2980}