rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
  39        LiveKitConnectionInfo, RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
/// Reply handle given to a request handler for a single request.
///
/// `responded` is shared with the dispatcher built in `add_request_handler`, which reads it
/// after the handler finishes to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
/// Per-connection state; a clone is passed to every message handler for that connection.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // A fresh async mutex per connection (see `handle_connection`), so this connection's
    // handlers take turns using the database handle via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Cloned from `Server::connection_pool`, i.e. shared by all sessions.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// The collaboration RPC server.
///
/// Owns the `Peer`, the pool of live connections, and the table of message handlers keyed
/// by payload type.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by `TypeId::of::<M>()` of the message type; filled in by `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown()`; each connection loop subscribes and returns when it fires.
    teardown: watch::Sender<()>,
}
 153
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` makes the guard `!Send`, so a `Send` future cannot hold it
/// across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 158
/// Serializable view of the server: the peer plus the connection pool.
///
/// The pool stays locked for as long as the snapshot is alive, because it holds the guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` behind the guard (see `serialize_deref`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(set_channel_member_admin)
 250            .add_request_handler(rename_channel)
 251            .add_request_handler(join_channel_buffer)
 252            .add_request_handler(leave_channel_buffer)
 253            .add_message_handler(update_channel_buffer)
 254            .add_request_handler(get_channel_members)
 255            .add_request_handler(respond_to_channel_invite)
 256            .add_request_handler(join_channel)
 257            .add_request_handler(follow)
 258            .add_message_handler(unfollow)
 259            .add_message_handler(update_followers)
 260            .add_message_handler(update_diff_base)
 261            .add_request_handler(get_private_user_info);
 262
 263        Arc::new(server)
 264    }
 265
    /// Spawns a detached, one-off cleanup task for this server id and returns `Ok(())` at once.
    ///
    /// The task does the following:
    /// 1. Waits `CLEANUP_TIMEOUT`, then asks the database for `stale_room_ids` for this server
    ///    id in the configured `zed_environment`.
    /// 2. For each room, calls `refresh_room` and broadcasts the refreshed room (and its
    ///    channel, if any).
    /// 3. Sends `CallCanceled` to users whose calls were canceled, and pushes updated contact
    ///    entries to the accepted contacts of affected users.
    /// 4. Deletes the LiveKit room when no participants remain.
    /// 5. Finally calls `delete_stale_servers`.
    ///
    /// Errors inside the task are logged through `trace_err()` and never propagated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the detached task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Defaults used when `refresh_room` fails: nobody to notify, and no
                        // LiveKit room deletion.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both stale participants and canceled callees get a contact refresh.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only delete the LiveKit room once the room has no participants.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose incoming call was canceled.
                        // The pool lock is scoped to this block; nothing here awaits.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (including busy
                        // status) to all connections of their accepted contacts. The DB calls
                        // are awaited before the pool lock is taken.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs whether or not the stale-room lookup succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 390
 391    pub fn teardown(&self) {
 392        self.peer.teardown();
 393        self.connection_pool.lock().reset();
 394        let _ = self.teardown.send(());
 395    }
 396
 397    #[cfg(test)]
 398    pub fn reset(&self, id: ServerId) {
 399        self.teardown();
 400        *self.id.lock() = id;
 401        self.peer.reset(id.0 as u32);
 402    }
 403
 404    #[cfg(test)]
 405    pub fn id(&self) -> ServerId {
 406        *self.id.lock()
 407    }
 408
 409    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 410    where
 411        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 412        Fut: 'static + Send + Future<Output = Result<()>>,
 413        M: EnvelopedMessage,
 414    {
 415        let prev_handler = self.handlers.insert(
 416            TypeId::of::<M>(),
 417            Box::new(move |envelope, session| {
 418                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 419                let span = info_span!(
 420                    "handle message",
 421                    payload_type = envelope.payload_type_name()
 422                );
 423                span.in_scope(|| {
 424                    tracing::info!(
 425                        payload_type = envelope.payload_type_name(),
 426                        "message received"
 427                    );
 428                });
 429                let start_time = Instant::now();
 430                let future = (handler)(*envelope, session);
 431                async move {
 432                    let result = future.await;
 433                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 434                    match result {
 435                        Err(error) => {
 436                            tracing::error!(%error, ?duration_ms, "error handling message")
 437                        }
 438                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 439                    }
 440                }
 441                .instrument(span)
 442                .boxed()
 443            }),
 444        );
 445        if prev_handler.is_some() {
 446            panic!("registered a handler for the same message twice");
 447        }
 448        self
 449    }
 450
 451    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 452    where
 453        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 454        Fut: 'static + Send + Future<Output = Result<()>>,
 455        M: EnvelopedMessage,
 456    {
 457        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 458        self
 459    }
 460
 461    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 462    where
 463        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 464        Fut: Send + Future<Output = Result<()>>,
 465        M: RequestMessage,
 466    {
 467        let handler = Arc::new(handler);
 468        self.add_handler(move |envelope, session| {
 469            let receipt = envelope.receipt();
 470            let handler = handler.clone();
 471            async move {
 472                let peer = session.peer.clone();
 473                let responded = Arc::new(AtomicBool::default());
 474                let response = Response {
 475                    peer: peer.clone(),
 476                    responded: responded.clone(),
 477                    receipt,
 478                };
 479                match (handler)(envelope.payload, response, session).await {
 480                    Ok(()) => {
 481                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 482                            Ok(())
 483                        } else {
 484                            Err(anyhow!("handler did not send a response"))?
 485                        }
 486                    }
 487                    Err(error) => {
 488                        peer.respond_with_error(
 489                            receipt,
 490                            proto::Error {
 491                                message: error.to_string(),
 492                            },
 493                        )?;
 494                        Err(error)
 495                    }
 496                }
 497            }
 498        })
 499    }
 500
    /// Returns a future that drives one authenticated client connection to completion.
    ///
    /// The future does the following:
    /// 1. Registers the connection with the peer and sends `Hello`.
    /// 2. Reports the new `ConnectionId` through `send_connection_id`, if one was given.
    /// 3. Sends the initial contacts, channels, invite info and any pending incoming call.
    /// 4. Dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O task finishes, the client closes the stream, or the
    /// server is torn down. After an I/O end or client close, it runs `connection_lost`. After
    /// a teardown it returns `Ok(())` immediately, without calling `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future runs so a teardown issued in the meantime is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Best effort: the receiver may already be gone.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Run the four initial-state queries concurrently; the first error aborts.
            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // NOTE(review): from here on, any `?` early return exits after `pool.add_connection`
            // without reaching `connection_lost`. Confirm the pool entry is cleaned up elsewhere
            // on those paths.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight at once. A permit
            // is taken before the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is never closed here, so `acquire_owned` cannot fail.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O, in-flight
                // foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            // Entered only while building the handler future, never across an await.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still running; background
            // handlers were spawned detached and keep running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 643
 644    pub async fn invite_code_redeemed(
 645        self: &Arc<Self>,
 646        inviter_id: UserId,
 647        invitee_id: UserId,
 648    ) -> Result<()> {
 649        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 650            if let Some(code) = &user.invite_code {
 651                let pool = self.connection_pool.lock();
 652                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 653                for connection_id in pool.user_connection_ids(inviter_id) {
 654                    self.peer.send(
 655                        connection_id,
 656                        proto::UpdateContacts {
 657                            contacts: vec![invitee_contact.clone()],
 658                            ..Default::default()
 659                        },
 660                    )?;
 661                    self.peer.send(
 662                        connection_id,
 663                        proto::UpdateInviteInfo {
 664                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 665                            count: user.invite_count as u32,
 666                        },
 667                    )?;
 668                }
 669            }
 670        }
 671        Ok(())
 672    }
 673
 674    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 675        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 676            if let Some(invite_code) = &user.invite_code {
 677                let pool = self.connection_pool.lock();
 678                for connection_id in pool.user_connection_ids(user_id) {
 679                    self.peer.send(
 680                        connection_id,
 681                        proto::UpdateInviteInfo {
 682                            url: format!(
 683                                "{}{}",
 684                                self.app_state.config.invite_link_prefix, invite_code
 685                            ),
 686                            count: user.invite_count as u32,
 687                        },
 688                    )?;
 689                }
 690            }
 691        }
 692        Ok(())
 693    }
 694
 695    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 696        ServerSnapshot {
 697            connection_pool: ConnectionPoolGuard {
 698                guard: self.connection_pool.lock(),
 699                _not_send: PhantomData,
 700            },
 701            peer: &self.peer,
 702        }
 703    }
 704}
 705
 706impl<'a> Deref for ConnectionPoolGuard<'a> {
 707    type Target = ConnectionPool;
 708
 709    fn deref(&self) -> &Self::Target {
 710        &*self.guard
 711    }
 712}
 713
 714impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 715    fn deref_mut(&mut self) -> &mut Self::Target {
 716        &mut *self.guard
 717    }
 718}
 719
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` on the pool each time a guard
    /// is released. In non-test builds the call is compiled out and dropping
    /// the guard only releases the lock.
    fn drop(&mut self) {
        // `#[cfg(test)]` on the statement (not `cfg!`) so `check_invariants`
        // need not exist outside test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
 726
 727fn broadcast<F>(
 728    sender_id: Option<ConnectionId>,
 729    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 730    mut f: F,
 731) where
 732    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 733{
 734    for receiver_id in receiver_ids {
 735        if Some(receiver_id) != sender_id {
 736            if let Err(error) = f(receiver_id) {
 737                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 738            }
 739        }
 740    }
 741}
 742
lazy_static! {
    /// Name of the request header in which clients report the RPC protocol
    /// version they speak; used as `ProtocolVersion::name()`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 746
 747pub struct ProtocolVersion(u32);
 748
 749impl Header for ProtocolVersion {
 750    fn name() -> &'static HeaderName {
 751        &ZED_PROTOCOL_VERSION
 752    }
 753
 754    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 755    where
 756        Self: Sized,
 757        I: Iterator<Item = &'i axum::http::HeaderValue>,
 758    {
 759        let version = values
 760            .next()
 761            .ok_or_else(axum::headers::Error::invalid)?
 762            .to_str()
 763            .map_err(|_| axum::headers::Error::invalid())?
 764            .parse()
 765            .map_err(|_| axum::headers::Error::invalid())?;
 766        Ok(Self(version))
 767    }
 768
 769    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 770        values.extend([self.0.to_string().parse().unwrap()]);
 771    }
 772}
 773
/// Builds the RPC-related HTTP routes.
///
/// - `/rpc` is registered before the `ServiceBuilder` layer, so (in axum,
///   `Router::layer` wraps only routes added before it) it alone gets the
///   `AppState` extension and runs through `auth::validate_header`.
/// - `/metrics` is registered after that layer and is therefore NOT behind
///   header validation.
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 785
/// Axum handler for `GET /rpc`: upgrades the request to a websocket and
/// serves it with `Server::handle_connection`.
///
/// Clients whose `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION` receive `426 Upgrade Required` instead of an
/// upgrade. The `User` extension is presumably inserted by
/// `auth::validate_header`, which `routes` layers over `/rpc` — verify if the
/// middleware changes.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket to the tungstenite message types that
        // `Connection` works with: incoming messages go through
        // `to_tungstenite_message`, outgoing ones through `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // After the upgrade there is no HTTP response left to report an
            // error on, so connection errors are only logged.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 816
 817pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 818    let connections = server
 819        .connection_pool
 820        .lock()
 821        .connections()
 822        .filter(|connection| !connection.admin)
 823        .count();
 824
 825    METRIC_CONNECTIONS.set(connections as _);
 826
 827    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 828    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 829
 830    let encoder = prometheus::TextEncoder::new();
 831    let metric_families = prometheus::gather();
 832    let encoded_metrics = encoder
 833        .encode_to_string(&metric_families)
 834        .map_err(|err| anyhow!("{}", err))?;
 835    Ok(encoded_metrics)
 836}
 837
/// Cleans up after a session's connection drops.
///
/// Immediately disconnects the peer, removes the connection from the pool,
/// reports the loss to the database (`connection_lost`) and leaves any
/// channel buffers. It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapsing: the session leaves its room; if the user
///   has no other connection online, `decline_call(None, user_id)` is issued
///   and any returned room is rebroadcast; finally the user's contacts are
///   refreshed.
/// - `teardown` changing: returns without that extra cleanup.
///
/// Failures in the database, channel-buffer, room-leave and decline steps
/// are only traced. Failures removing the connection from the pool or
/// updating contacts are returned.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    leave_channel_buffers_for_session(&session)
        .await
        .trace_err();

    // `select_biased!` polls arms in declaration order, so if both futures
    // are ready at once the timeout arm wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time: release its room seat.
            leave_room_for_session(&session).await.trace_err();

            // Only decline a call if no other connection of this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;


        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 884
 885async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 886    response.send(proto::Ack {})?;
 887    Ok(())
 888}
 889
 890async fn create_room(
 891    _request: proto::CreateRoom,
 892    response: Response<proto::CreateRoom>,
 893    session: Session,
 894) -> Result<()> {
 895    let live_kit_room = nanoid::nanoid!(30);
 896
 897    let live_kit_connection_info = {
 898        let live_kit_room = live_kit_room.clone();
 899        let live_kit = session.live_kit_client.as_ref();
 900
 901        util::async_iife!({
 902            let live_kit = live_kit?;
 903
 904            live_kit
 905                .create_room(live_kit_room.clone())
 906                .await
 907                .trace_err()?;
 908
 909            let token = live_kit
 910                .room_token(&live_kit_room, &session.user_id.to_string())
 911                .trace_err()?;
 912
 913            Some(proto::LiveKitConnectionInfo {
 914                server_url: live_kit.url().into(),
 915                token,
 916            })
 917        })
 918    }
 919    .await;
 920
 921    let room = session
 922        .db()
 923        .await
 924        .create_room(session.user_id, session.connection_id, &live_kit_room)
 925        .await?;
 926
 927    response.send(proto::CreateRoomResponse {
 928        room: Some(room.clone()),
 929        live_kit_connection_info,
 930    })?;
 931
 932    update_user_contacts(session.user_id, &session).await?;
 933    Ok(())
 934}
 935
 936async fn join_room(
 937    request: proto::JoinRoom,
 938    response: Response<proto::JoinRoom>,
 939    session: Session,
 940) -> Result<()> {
 941    let room_id = RoomId::from_proto(request.id);
 942    let joined_room = {
 943        let room = session
 944            .db()
 945            .await
 946            .join_room(room_id, session.user_id, session.connection_id)
 947            .await?;
 948        room_updated(&room.room, &session.peer);
 949        room.into_inner()
 950    };
 951
 952    if let Some(channel_id) = joined_room.channel_id {
 953        channel_updated(
 954            channel_id,
 955            &joined_room.room,
 956            &joined_room.channel_members,
 957            &session.peer,
 958            &*session.connection_pool().await,
 959        )
 960    }
 961
 962    for connection_id in session
 963        .connection_pool()
 964        .await
 965        .user_connection_ids(session.user_id)
 966    {
 967        session
 968            .peer
 969            .send(
 970                connection_id,
 971                proto::CallCanceled {
 972                    room_id: room_id.to_proto(),
 973                },
 974            )
 975            .trace_err();
 976    }
 977
 978    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 979        if let Some(token) = live_kit
 980            .room_token(
 981                &joined_room.room.live_kit_room,
 982                &session.user_id.to_string(),
 983            )
 984            .trace_err()
 985        {
 986            Some(proto::LiveKitConnectionInfo {
 987                server_url: live_kit.url().into(),
 988                token,
 989            })
 990        } else {
 991            None
 992        }
 993    } else {
 994        None
 995    };
 996
 997    response.send(proto::JoinRoomResponse {
 998        room: Some(joined_room.room),
 999        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1000        live_kit_connection_info,
1001    })?;
1002
1003    update_user_contacts(session.user_id, &session).await?;
1004    Ok(())
1005}
1006
1007async fn rejoin_room(
1008    request: proto::RejoinRoom,
1009    response: Response<proto::RejoinRoom>,
1010    session: Session,
1011) -> Result<()> {
1012    let room;
1013    let channel_id;
1014    let channel_members;
1015    {
1016        let mut rejoined_room = session
1017            .db()
1018            .await
1019            .rejoin_room(request, session.user_id, session.connection_id)
1020            .await?;
1021
1022        response.send(proto::RejoinRoomResponse {
1023            room: Some(rejoined_room.room.clone()),
1024            reshared_projects: rejoined_room
1025                .reshared_projects
1026                .iter()
1027                .map(|project| proto::ResharedProject {
1028                    id: project.id.to_proto(),
1029                    collaborators: project
1030                        .collaborators
1031                        .iter()
1032                        .map(|collaborator| collaborator.to_proto())
1033                        .collect(),
1034                })
1035                .collect(),
1036            rejoined_projects: rejoined_room
1037                .rejoined_projects
1038                .iter()
1039                .map(|rejoined_project| proto::RejoinedProject {
1040                    id: rejoined_project.id.to_proto(),
1041                    worktrees: rejoined_project
1042                        .worktrees
1043                        .iter()
1044                        .map(|worktree| proto::WorktreeMetadata {
1045                            id: worktree.id,
1046                            root_name: worktree.root_name.clone(),
1047                            visible: worktree.visible,
1048                            abs_path: worktree.abs_path.clone(),
1049                        })
1050                        .collect(),
1051                    collaborators: rejoined_project
1052                        .collaborators
1053                        .iter()
1054                        .map(|collaborator| collaborator.to_proto())
1055                        .collect(),
1056                    language_servers: rejoined_project.language_servers.clone(),
1057                })
1058                .collect(),
1059        })?;
1060        room_updated(&rejoined_room.room, &session.peer);
1061
1062        for project in &rejoined_room.reshared_projects {
1063            for collaborator in &project.collaborators {
1064                session
1065                    .peer
1066                    .send(
1067                        collaborator.connection_id,
1068                        proto::UpdateProjectCollaborator {
1069                            project_id: project.id.to_proto(),
1070                            old_peer_id: Some(project.old_connection_id.into()),
1071                            new_peer_id: Some(session.connection_id.into()),
1072                        },
1073                    )
1074                    .trace_err();
1075            }
1076
1077            broadcast(
1078                Some(session.connection_id),
1079                project
1080                    .collaborators
1081                    .iter()
1082                    .map(|collaborator| collaborator.connection_id),
1083                |connection_id| {
1084                    session.peer.forward_send(
1085                        session.connection_id,
1086                        connection_id,
1087                        proto::UpdateProject {
1088                            project_id: project.id.to_proto(),
1089                            worktrees: project.worktrees.clone(),
1090                        },
1091                    )
1092                },
1093            );
1094        }
1095
1096        for project in &rejoined_room.rejoined_projects {
1097            for collaborator in &project.collaborators {
1098                session
1099                    .peer
1100                    .send(
1101                        collaborator.connection_id,
1102                        proto::UpdateProjectCollaborator {
1103                            project_id: project.id.to_proto(),
1104                            old_peer_id: Some(project.old_connection_id.into()),
1105                            new_peer_id: Some(session.connection_id.into()),
1106                        },
1107                    )
1108                    .trace_err();
1109            }
1110        }
1111
1112        for project in &mut rejoined_room.rejoined_projects {
1113            for worktree in mem::take(&mut project.worktrees) {
1114                #[cfg(any(test, feature = "test-support"))]
1115                const MAX_CHUNK_SIZE: usize = 2;
1116                #[cfg(not(any(test, feature = "test-support")))]
1117                const MAX_CHUNK_SIZE: usize = 256;
1118
1119                // Stream this worktree's entries.
1120                let message = proto::UpdateWorktree {
1121                    project_id: project.id.to_proto(),
1122                    worktree_id: worktree.id,
1123                    abs_path: worktree.abs_path.clone(),
1124                    root_name: worktree.root_name,
1125                    updated_entries: worktree.updated_entries,
1126                    removed_entries: worktree.removed_entries,
1127                    scan_id: worktree.scan_id,
1128                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1129                    updated_repositories: worktree.updated_repositories,
1130                    removed_repositories: worktree.removed_repositories,
1131                };
1132                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1133                    session.peer.send(session.connection_id, update.clone())?;
1134                }
1135
1136                // Stream this worktree's diagnostics.
1137                for summary in worktree.diagnostic_summaries {
1138                    session.peer.send(
1139                        session.connection_id,
1140                        proto::UpdateDiagnosticSummary {
1141                            project_id: project.id.to_proto(),
1142                            worktree_id: worktree.id,
1143                            summary: Some(summary),
1144                        },
1145                    )?;
1146                }
1147
1148                for settings_file in worktree.settings_files {
1149                    session.peer.send(
1150                        session.connection_id,
1151                        proto::UpdateWorktreeSettings {
1152                            project_id: project.id.to_proto(),
1153                            worktree_id: worktree.id,
1154                            path: settings_file.path,
1155                            content: Some(settings_file.content),
1156                        },
1157                    )?;
1158                }
1159            }
1160
1161            for language_server in &project.language_servers {
1162                session.peer.send(
1163                    session.connection_id,
1164                    proto::UpdateLanguageServer {
1165                        project_id: project.id.to_proto(),
1166                        language_server_id: language_server.id,
1167                        variant: Some(
1168                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1169                                proto::LspDiskBasedDiagnosticsUpdated {},
1170                            ),
1171                        ),
1172                    },
1173                )?;
1174            }
1175        }
1176
1177        let rejoined_room = rejoined_room.into_inner();
1178
1179        room = rejoined_room.room;
1180        channel_id = rejoined_room.channel_id;
1181        channel_members = rejoined_room.channel_members;
1182    }
1183
1184    if let Some(channel_id) = channel_id {
1185        channel_updated(
1186            channel_id,
1187            &room,
1188            &channel_members,
1189            &session.peer,
1190            &*session.connection_pool().await,
1191        );
1192    }
1193
1194    update_user_contacts(session.user_id, &session).await?;
1195    Ok(())
1196}
1197
1198async fn leave_room(
1199    _: proto::LeaveRoom,
1200    response: Response<proto::LeaveRoom>,
1201    session: Session,
1202) -> Result<()> {
1203    leave_room_for_session(&session).await?;
1204    response.send(proto::Ack {})?;
1205    Ok(())
1206}
1207
1208async fn call(
1209    request: proto::Call,
1210    response: Response<proto::Call>,
1211    session: Session,
1212) -> Result<()> {
1213    let room_id = RoomId::from_proto(request.room_id);
1214    let calling_user_id = session.user_id;
1215    let calling_connection_id = session.connection_id;
1216    let called_user_id = UserId::from_proto(request.called_user_id);
1217    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1218    if !session
1219        .db()
1220        .await
1221        .has_contact(calling_user_id, called_user_id)
1222        .await?
1223    {
1224        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1225    }
1226
1227    let incoming_call = {
1228        let (room, incoming_call) = &mut *session
1229            .db()
1230            .await
1231            .call(
1232                room_id,
1233                calling_user_id,
1234                calling_connection_id,
1235                called_user_id,
1236                initial_project_id,
1237            )
1238            .await?;
1239        room_updated(&room, &session.peer);
1240        mem::take(incoming_call)
1241    };
1242    update_user_contacts(called_user_id, &session).await?;
1243
1244    let mut calls = session
1245        .connection_pool()
1246        .await
1247        .user_connection_ids(called_user_id)
1248        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1249        .collect::<FuturesUnordered<_>>();
1250
1251    while let Some(call_response) = calls.next().await {
1252        match call_response.as_ref() {
1253            Ok(_) => {
1254                response.send(proto::Ack {})?;
1255                return Ok(());
1256            }
1257            Err(_) => {
1258                call_response.trace_err();
1259            }
1260        }
1261    }
1262
1263    {
1264        let room = session
1265            .db()
1266            .await
1267            .call_failed(room_id, called_user_id)
1268            .await?;
1269        room_updated(&room, &session.peer);
1270    }
1271    update_user_contacts(called_user_id, &session).await?;
1272
1273    Err(anyhow!("failed to ring user"))?
1274}
1275
1276async fn cancel_call(
1277    request: proto::CancelCall,
1278    response: Response<proto::CancelCall>,
1279    session: Session,
1280) -> Result<()> {
1281    let called_user_id = UserId::from_proto(request.called_user_id);
1282    let room_id = RoomId::from_proto(request.room_id);
1283    {
1284        let room = session
1285            .db()
1286            .await
1287            .cancel_call(room_id, session.connection_id, called_user_id)
1288            .await?;
1289        room_updated(&room, &session.peer);
1290    }
1291
1292    for connection_id in session
1293        .connection_pool()
1294        .await
1295        .user_connection_ids(called_user_id)
1296    {
1297        session
1298            .peer
1299            .send(
1300                connection_id,
1301                proto::CallCanceled {
1302                    room_id: room_id.to_proto(),
1303                },
1304            )
1305            .trace_err();
1306    }
1307    response.send(proto::Ack {})?;
1308
1309    update_user_contacts(called_user_id, &session).await?;
1310    Ok(())
1311}
1312
1313async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1314    let room_id = RoomId::from_proto(message.room_id);
1315    {
1316        let room = session
1317            .db()
1318            .await
1319            .decline_call(Some(room_id), session.user_id)
1320            .await?
1321            .ok_or_else(|| anyhow!("failed to decline call"))?;
1322        room_updated(&room, &session.peer);
1323    }
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id)
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340    update_user_contacts(session.user_id, &session).await?;
1341    Ok(())
1342}
1343
1344async fn update_participant_location(
1345    request: proto::UpdateParticipantLocation,
1346    response: Response<proto::UpdateParticipantLocation>,
1347    session: Session,
1348) -> Result<()> {
1349    let room_id = RoomId::from_proto(request.room_id);
1350    let location = request
1351        .location
1352        .ok_or_else(|| anyhow!("invalid location"))?;
1353
1354    let db = session.db().await;
1355    let room = db
1356        .update_room_participant_location(room_id, session.connection_id, location)
1357        .await?;
1358
1359    room_updated(&room, &session.peer);
1360    response.send(proto::Ack {})?;
1361    Ok(())
1362}
1363
1364async fn share_project(
1365    request: proto::ShareProject,
1366    response: Response<proto::ShareProject>,
1367    session: Session,
1368) -> Result<()> {
1369    let (project_id, room) = &*session
1370        .db()
1371        .await
1372        .share_project(
1373            RoomId::from_proto(request.room_id),
1374            session.connection_id,
1375            &request.worktrees,
1376        )
1377        .await?;
1378    response.send(proto::ShareProjectResponse {
1379        project_id: project_id.to_proto(),
1380    })?;
1381    room_updated(&room, &session.peer);
1382
1383    Ok(())
1384}
1385
1386async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1387    let project_id = ProjectId::from_proto(message.project_id);
1388
1389    let (room, guest_connection_ids) = &*session
1390        .db()
1391        .await
1392        .unshare_project(project_id, session.connection_id)
1393        .await?;
1394
1395    broadcast(
1396        Some(session.connection_id),
1397        guest_connection_ids.iter().copied(),
1398        |conn_id| session.peer.send(conn_id, message.clone()),
1399    );
1400    room_updated(&room, &session.peer);
1401
1402    Ok(())
1403}
1404
1405async fn join_project(
1406    request: proto::JoinProject,
1407    response: Response<proto::JoinProject>,
1408    session: Session,
1409) -> Result<()> {
1410    let project_id = ProjectId::from_proto(request.project_id);
1411    let guest_user_id = session.user_id;
1412
1413    tracing::info!(%project_id, "join project");
1414
1415    let (project, replica_id) = &mut *session
1416        .db()
1417        .await
1418        .join_project(project_id, session.connection_id)
1419        .await?;
1420
1421    let collaborators = project
1422        .collaborators
1423        .iter()
1424        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1425        .map(|collaborator| collaborator.to_proto())
1426        .collect::<Vec<_>>();
1427
1428    let worktrees = project
1429        .worktrees
1430        .iter()
1431        .map(|(id, worktree)| proto::WorktreeMetadata {
1432            id: *id,
1433            root_name: worktree.root_name.clone(),
1434            visible: worktree.visible,
1435            abs_path: worktree.abs_path.clone(),
1436        })
1437        .collect::<Vec<_>>();
1438
1439    for collaborator in &collaborators {
1440        session
1441            .peer
1442            .send(
1443                collaborator.peer_id.unwrap().into(),
1444                proto::AddProjectCollaborator {
1445                    project_id: project_id.to_proto(),
1446                    collaborator: Some(proto::Collaborator {
1447                        peer_id: Some(session.connection_id.into()),
1448                        replica_id: replica_id.0 as u32,
1449                        user_id: guest_user_id.to_proto(),
1450                    }),
1451                },
1452            )
1453            .trace_err();
1454    }
1455
1456    // First, we send the metadata associated with each worktree.
1457    response.send(proto::JoinProjectResponse {
1458        worktrees: worktrees.clone(),
1459        replica_id: replica_id.0 as u32,
1460        collaborators: collaborators.clone(),
1461        language_servers: project.language_servers.clone(),
1462    })?;
1463
1464    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1465        #[cfg(any(test, feature = "test-support"))]
1466        const MAX_CHUNK_SIZE: usize = 2;
1467        #[cfg(not(any(test, feature = "test-support")))]
1468        const MAX_CHUNK_SIZE: usize = 256;
1469
1470        // Stream this worktree's entries.
1471        let message = proto::UpdateWorktree {
1472            project_id: project_id.to_proto(),
1473            worktree_id,
1474            abs_path: worktree.abs_path.clone(),
1475            root_name: worktree.root_name,
1476            updated_entries: worktree.entries,
1477            removed_entries: Default::default(),
1478            scan_id: worktree.scan_id,
1479            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1480            updated_repositories: worktree.repository_entries.into_values().collect(),
1481            removed_repositories: Default::default(),
1482        };
1483        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1484            session.peer.send(session.connection_id, update.clone())?;
1485        }
1486
1487        // Stream this worktree's diagnostics.
1488        for summary in worktree.diagnostic_summaries {
1489            session.peer.send(
1490                session.connection_id,
1491                proto::UpdateDiagnosticSummary {
1492                    project_id: project_id.to_proto(),
1493                    worktree_id: worktree.id,
1494                    summary: Some(summary),
1495                },
1496            )?;
1497        }
1498
1499        for settings_file in worktree.settings_files {
1500            session.peer.send(
1501                session.connection_id,
1502                proto::UpdateWorktreeSettings {
1503                    project_id: project_id.to_proto(),
1504                    worktree_id: worktree.id,
1505                    path: settings_file.path,
1506                    content: Some(settings_file.content),
1507                },
1508            )?;
1509        }
1510    }
1511
1512    for language_server in &project.language_servers {
1513        session.peer.send(
1514            session.connection_id,
1515            proto::UpdateLanguageServer {
1516                project_id: project_id.to_proto(),
1517                language_server_id: language_server.id,
1518                variant: Some(
1519                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1520                        proto::LspDiskBasedDiagnosticsUpdated {},
1521                    ),
1522                ),
1523            },
1524        )?;
1525    }
1526
1527    Ok(())
1528}
1529
1530async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1531    let sender_id = session.connection_id;
1532    let project_id = ProjectId::from_proto(request.project_id);
1533
1534    let (room, project) = &*session
1535        .db()
1536        .await
1537        .leave_project(project_id, sender_id)
1538        .await?;
1539    tracing::info!(
1540        %project_id,
1541        host_user_id = %project.host_user_id,
1542        host_connection_id = %project.host_connection_id,
1543        "leave project"
1544    );
1545
1546    project_left(&project, &session);
1547    room_updated(&room, &session.peer);
1548
1549    Ok(())
1550}
1551
1552async fn update_project(
1553    request: proto::UpdateProject,
1554    response: Response<proto::UpdateProject>,
1555    session: Session,
1556) -> Result<()> {
1557    let project_id = ProjectId::from_proto(request.project_id);
1558    let (room, guest_connection_ids) = &*session
1559        .db()
1560        .await
1561        .update_project(project_id, session.connection_id, &request.worktrees)
1562        .await?;
1563    broadcast(
1564        Some(session.connection_id),
1565        guest_connection_ids.iter().copied(),
1566        |connection_id| {
1567            session
1568                .peer
1569                .forward_send(session.connection_id, connection_id, request.clone())
1570        },
1571    );
1572    room_updated(&room, &session.peer);
1573    response.send(proto::Ack {})?;
1574
1575    Ok(())
1576}
1577
1578async fn update_worktree(
1579    request: proto::UpdateWorktree,
1580    response: Response<proto::UpdateWorktree>,
1581    session: Session,
1582) -> Result<()> {
1583    let guest_connection_ids = session
1584        .db()
1585        .await
1586        .update_worktree(&request, session.connection_id)
1587        .await?;
1588
1589    broadcast(
1590        Some(session.connection_id),
1591        guest_connection_ids.iter().copied(),
1592        |connection_id| {
1593            session
1594                .peer
1595                .forward_send(session.connection_id, connection_id, request.clone())
1596        },
1597    );
1598    response.send(proto::Ack {})?;
1599    Ok(())
1600}
1601
1602async fn update_diagnostic_summary(
1603    message: proto::UpdateDiagnosticSummary,
1604    session: Session,
1605) -> Result<()> {
1606    let guest_connection_ids = session
1607        .db()
1608        .await
1609        .update_diagnostic_summary(&message, session.connection_id)
1610        .await?;
1611
1612    broadcast(
1613        Some(session.connection_id),
1614        guest_connection_ids.iter().copied(),
1615        |connection_id| {
1616            session
1617                .peer
1618                .forward_send(session.connection_id, connection_id, message.clone())
1619        },
1620    );
1621
1622    Ok(())
1623}
1624
1625async fn update_worktree_settings(
1626    message: proto::UpdateWorktreeSettings,
1627    session: Session,
1628) -> Result<()> {
1629    let guest_connection_ids = session
1630        .db()
1631        .await
1632        .update_worktree_settings(&message, session.connection_id)
1633        .await?;
1634
1635    broadcast(
1636        Some(session.connection_id),
1637        guest_connection_ids.iter().copied(),
1638        |connection_id| {
1639            session
1640                .peer
1641                .forward_send(session.connection_id, connection_id, message.clone())
1642        },
1643    );
1644
1645    Ok(())
1646}
1647
1648async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1649    broadcast_project_message(request.project_id, request, session).await
1650}
1651
1652async fn start_language_server(
1653    request: proto::StartLanguageServer,
1654    session: Session,
1655) -> Result<()> {
1656    let guest_connection_ids = session
1657        .db()
1658        .await
1659        .start_language_server(&request, session.connection_id)
1660        .await?;
1661
1662    broadcast(
1663        Some(session.connection_id),
1664        guest_connection_ids.iter().copied(),
1665        |connection_id| {
1666            session
1667                .peer
1668                .forward_send(session.connection_id, connection_id, request.clone())
1669        },
1670    );
1671    Ok(())
1672}
1673
1674async fn update_language_server(
1675    request: proto::UpdateLanguageServer,
1676    session: Session,
1677) -> Result<()> {
1678    session.executor.record_backtrace();
1679    let project_id = ProjectId::from_proto(request.project_id);
1680    let project_connection_ids = session
1681        .db()
1682        .await
1683        .project_connection_ids(project_id, session.connection_id)
1684        .await?;
1685    broadcast(
1686        Some(session.connection_id),
1687        project_connection_ids.iter().copied(),
1688        |connection_id| {
1689            session
1690                .peer
1691                .forward_send(session.connection_id, connection_id, request.clone())
1692        },
1693    );
1694    Ok(())
1695}
1696
1697async fn forward_project_request<T>(
1698    request: T,
1699    response: Response<T>,
1700    session: Session,
1701) -> Result<()>
1702where
1703    T: EntityMessage + RequestMessage,
1704{
1705    session.executor.record_backtrace();
1706    let project_id = ProjectId::from_proto(request.remote_entity_id());
1707    let host_connection_id = {
1708        let collaborators = session
1709            .db()
1710            .await
1711            .project_collaborators(project_id, session.connection_id)
1712            .await?;
1713        collaborators
1714            .iter()
1715            .find(|collaborator| collaborator.is_host)
1716            .ok_or_else(|| anyhow!("host not found"))?
1717            .connection_id
1718    };
1719
1720    let payload = session
1721        .peer
1722        .forward_request(session.connection_id, host_connection_id, request)
1723        .await?;
1724
1725    response.send(payload)?;
1726    Ok(())
1727}
1728
1729async fn create_buffer_for_peer(
1730    request: proto::CreateBufferForPeer,
1731    session: Session,
1732) -> Result<()> {
1733    session.executor.record_backtrace();
1734    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1735    session
1736        .peer
1737        .forward_send(session.connection_id, peer_id.into(), request)?;
1738    Ok(())
1739}
1740
1741async fn update_buffer(
1742    request: proto::UpdateBuffer,
1743    response: Response<proto::UpdateBuffer>,
1744    session: Session,
1745) -> Result<()> {
1746    session.executor.record_backtrace();
1747    let project_id = ProjectId::from_proto(request.project_id);
1748    let mut guest_connection_ids;
1749    let mut host_connection_id = None;
1750    {
1751        let collaborators = session
1752            .db()
1753            .await
1754            .project_collaborators(project_id, session.connection_id)
1755            .await?;
1756        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1757        for collaborator in collaborators.iter() {
1758            if collaborator.is_host {
1759                host_connection_id = Some(collaborator.connection_id);
1760            } else {
1761                guest_connection_ids.push(collaborator.connection_id);
1762            }
1763        }
1764    }
1765    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1766
1767    session.executor.record_backtrace();
1768    broadcast(
1769        Some(session.connection_id),
1770        guest_connection_ids,
1771        |connection_id| {
1772            session
1773                .peer
1774                .forward_send(session.connection_id, connection_id, request.clone())
1775        },
1776    );
1777    if host_connection_id != session.connection_id {
1778        session
1779            .peer
1780            .forward_request(session.connection_id, host_connection_id, request.clone())
1781            .await?;
1782    }
1783
1784    response.send(proto::Ack {})?;
1785    Ok(())
1786}
1787
1788async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1789    let project_id = ProjectId::from_proto(request.project_id);
1790    let project_connection_ids = session
1791        .db()
1792        .await
1793        .project_connection_ids(project_id, session.connection_id)
1794        .await?;
1795
1796    broadcast(
1797        Some(session.connection_id),
1798        project_connection_ids.iter().copied(),
1799        |connection_id| {
1800            session
1801                .peer
1802                .forward_send(session.connection_id, connection_id, request.clone())
1803        },
1804    );
1805    Ok(())
1806}
1807
1808async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1809    let project_id = ProjectId::from_proto(request.project_id);
1810    let project_connection_ids = session
1811        .db()
1812        .await
1813        .project_connection_ids(project_id, session.connection_id)
1814        .await?;
1815    broadcast(
1816        Some(session.connection_id),
1817        project_connection_ids.iter().copied(),
1818        |connection_id| {
1819            session
1820                .peer
1821                .forward_send(session.connection_id, connection_id, request.clone())
1822        },
1823    );
1824    Ok(())
1825}
1826
1827async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1828    broadcast_project_message(request.project_id, request, session).await
1829}
1830
1831async fn broadcast_project_message<T: EnvelopedMessage>(
1832    project_id: u64,
1833    request: T,
1834    session: Session,
1835) -> Result<()> {
1836    let project_id = ProjectId::from_proto(project_id);
1837    let project_connection_ids = session
1838        .db()
1839        .await
1840        .project_connection_ids(project_id, session.connection_id)
1841        .await?;
1842    broadcast(
1843        Some(session.connection_id),
1844        project_connection_ids.iter().copied(),
1845        |connection_id| {
1846            session
1847                .peer
1848                .forward_send(session.connection_id, connection_id, request.clone())
1849        },
1850    );
1851    Ok(())
1852}
1853
1854async fn follow(
1855    request: proto::Follow,
1856    response: Response<proto::Follow>,
1857    session: Session,
1858) -> Result<()> {
1859    let project_id = ProjectId::from_proto(request.project_id);
1860    let leader_id = request
1861        .leader_id
1862        .ok_or_else(|| anyhow!("invalid leader id"))?
1863        .into();
1864    let follower_id = session.connection_id;
1865
1866    {
1867        let project_connection_ids = session
1868            .db()
1869            .await
1870            .project_connection_ids(project_id, session.connection_id)
1871            .await?;
1872
1873        if !project_connection_ids.contains(&leader_id) {
1874            Err(anyhow!("no such peer"))?;
1875        }
1876    }
1877
1878    let mut response_payload = session
1879        .peer
1880        .forward_request(session.connection_id, leader_id, request)
1881        .await?;
1882    response_payload
1883        .views
1884        .retain(|view| view.leader_id != Some(follower_id.into()));
1885    response.send(response_payload)?;
1886
1887    let room = session
1888        .db()
1889        .await
1890        .follow(project_id, leader_id, follower_id)
1891        .await?;
1892    room_updated(&room, &session.peer);
1893
1894    Ok(())
1895}
1896
1897async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1898    let project_id = ProjectId::from_proto(request.project_id);
1899    let leader_id = request
1900        .leader_id
1901        .ok_or_else(|| anyhow!("invalid leader id"))?
1902        .into();
1903    let follower_id = session.connection_id;
1904
1905    if !session
1906        .db()
1907        .await
1908        .project_connection_ids(project_id, session.connection_id)
1909        .await?
1910        .contains(&leader_id)
1911    {
1912        Err(anyhow!("no such peer"))?;
1913    }
1914
1915    session
1916        .peer
1917        .forward_send(session.connection_id, leader_id, request)?;
1918
1919    let room = session
1920        .db()
1921        .await
1922        .unfollow(project_id, leader_id, follower_id)
1923        .await?;
1924    room_updated(&room, &session.peer);
1925
1926    Ok(())
1927}
1928
1929async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1930    let project_id = ProjectId::from_proto(request.project_id);
1931    let project_connection_ids = session
1932        .db
1933        .lock()
1934        .await
1935        .project_connection_ids(project_id, session.connection_id)
1936        .await?;
1937
1938    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1939        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1940        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1941        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1942    });
1943    for follower_peer_id in request.follower_ids.iter().copied() {
1944        let follower_connection_id = follower_peer_id.into();
1945        if project_connection_ids.contains(&follower_connection_id)
1946            && Some(follower_peer_id) != leader_id
1947        {
1948            session.peer.forward_send(
1949                session.connection_id,
1950                follower_connection_id,
1951                request.clone(),
1952            )?;
1953        }
1954    }
1955    Ok(())
1956}
1957
1958async fn get_users(
1959    request: proto::GetUsers,
1960    response: Response<proto::GetUsers>,
1961    session: Session,
1962) -> Result<()> {
1963    let user_ids = request
1964        .user_ids
1965        .into_iter()
1966        .map(UserId::from_proto)
1967        .collect();
1968    let users = session
1969        .db()
1970        .await
1971        .get_users_by_ids(user_ids)
1972        .await?
1973        .into_iter()
1974        .map(|user| proto::User {
1975            id: user.id.to_proto(),
1976            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1977            github_login: user.github_login,
1978        })
1979        .collect();
1980    response.send(proto::UsersResponse { users })?;
1981    Ok(())
1982}
1983
1984async fn fuzzy_search_users(
1985    request: proto::FuzzySearchUsers,
1986    response: Response<proto::FuzzySearchUsers>,
1987    session: Session,
1988) -> Result<()> {
1989    let query = request.query;
1990    let users = match query.len() {
1991        0 => vec![],
1992        1 | 2 => session
1993            .db()
1994            .await
1995            .get_user_by_github_login(&query)
1996            .await?
1997            .into_iter()
1998            .collect(),
1999        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2000    };
2001    let users = users
2002        .into_iter()
2003        .filter(|user| user.id != session.user_id)
2004        .map(|user| proto::User {
2005            id: user.id.to_proto(),
2006            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2007            github_login: user.github_login,
2008        })
2009        .collect();
2010    response.send(proto::UsersResponse { users })?;
2011    Ok(())
2012}
2013
2014async fn request_contact(
2015    request: proto::RequestContact,
2016    response: Response<proto::RequestContact>,
2017    session: Session,
2018) -> Result<()> {
2019    let requester_id = session.user_id;
2020    let responder_id = UserId::from_proto(request.responder_id);
2021    if requester_id == responder_id {
2022        return Err(anyhow!("cannot add yourself as a contact"))?;
2023    }
2024
2025    session
2026        .db()
2027        .await
2028        .send_contact_request(requester_id, responder_id)
2029        .await?;
2030
2031    // Update outgoing contact requests of requester
2032    let mut update = proto::UpdateContacts::default();
2033    update.outgoing_requests.push(responder_id.to_proto());
2034    for connection_id in session
2035        .connection_pool()
2036        .await
2037        .user_connection_ids(requester_id)
2038    {
2039        session.peer.send(connection_id, update.clone())?;
2040    }
2041
2042    // Update incoming contact requests of responder
2043    let mut update = proto::UpdateContacts::default();
2044    update
2045        .incoming_requests
2046        .push(proto::IncomingContactRequest {
2047            requester_id: requester_id.to_proto(),
2048            should_notify: true,
2049        });
2050    for connection_id in session
2051        .connection_pool()
2052        .await
2053        .user_connection_ids(responder_id)
2054    {
2055        session.peer.send(connection_id, update.clone())?;
2056    }
2057
2058    response.send(proto::Ack {})?;
2059    Ok(())
2060}
2061
2062async fn respond_to_contact_request(
2063    request: proto::RespondToContactRequest,
2064    response: Response<proto::RespondToContactRequest>,
2065    session: Session,
2066) -> Result<()> {
2067    let responder_id = session.user_id;
2068    let requester_id = UserId::from_proto(request.requester_id);
2069    let db = session.db().await;
2070    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2071        db.dismiss_contact_notification(responder_id, requester_id)
2072            .await?;
2073    } else {
2074        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2075
2076        db.respond_to_contact_request(responder_id, requester_id, accept)
2077            .await?;
2078        let requester_busy = db.is_user_busy(requester_id).await?;
2079        let responder_busy = db.is_user_busy(responder_id).await?;
2080
2081        let pool = session.connection_pool().await;
2082        // Update responder with new contact
2083        let mut update = proto::UpdateContacts::default();
2084        if accept {
2085            update
2086                .contacts
2087                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2088        }
2089        update
2090            .remove_incoming_requests
2091            .push(requester_id.to_proto());
2092        for connection_id in pool.user_connection_ids(responder_id) {
2093            session.peer.send(connection_id, update.clone())?;
2094        }
2095
2096        // Update requester with new contact
2097        let mut update = proto::UpdateContacts::default();
2098        if accept {
2099            update
2100                .contacts
2101                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2102        }
2103        update
2104            .remove_outgoing_requests
2105            .push(responder_id.to_proto());
2106        for connection_id in pool.user_connection_ids(requester_id) {
2107            session.peer.send(connection_id, update.clone())?;
2108        }
2109    }
2110
2111    response.send(proto::Ack {})?;
2112    Ok(())
2113}
2114
2115async fn remove_contact(
2116    request: proto::RemoveContact,
2117    response: Response<proto::RemoveContact>,
2118    session: Session,
2119) -> Result<()> {
2120    let requester_id = session.user_id;
2121    let responder_id = UserId::from_proto(request.user_id);
2122    let db = session.db().await;
2123    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2124
2125    let pool = session.connection_pool().await;
2126    // Update outgoing contact requests of requester
2127    let mut update = proto::UpdateContacts::default();
2128    if contact_accepted {
2129        update.remove_contacts.push(responder_id.to_proto());
2130    } else {
2131        update
2132            .remove_outgoing_requests
2133            .push(responder_id.to_proto());
2134    }
2135    for connection_id in pool.user_connection_ids(requester_id) {
2136        session.peer.send(connection_id, update.clone())?;
2137    }
2138
2139    // Update incoming contact requests of responder
2140    let mut update = proto::UpdateContacts::default();
2141    if contact_accepted {
2142        update.remove_contacts.push(requester_id.to_proto());
2143    } else {
2144        update
2145            .remove_incoming_requests
2146            .push(requester_id.to_proto());
2147    }
2148    for connection_id in pool.user_connection_ids(responder_id) {
2149        session.peer.send(connection_id, update.clone())?;
2150    }
2151
2152    response.send(proto::Ack {})?;
2153    Ok(())
2154}
2155
2156async fn create_channel(
2157    request: proto::CreateChannel,
2158    response: Response<proto::CreateChannel>,
2159    session: Session,
2160) -> Result<()> {
2161    let db = session.db().await;
2162    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2163
2164    if let Some(live_kit) = session.live_kit_client.as_ref() {
2165        live_kit.create_room(live_kit_room.clone()).await?;
2166    }
2167
2168    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2169    let id = db
2170        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2171        .await?;
2172
2173    let channel = proto::Channel {
2174        id: id.to_proto(),
2175        name: request.name,
2176        parent_id: request.parent_id,
2177    };
2178
2179    response.send(proto::ChannelResponse {
2180        channel: Some(channel.clone()),
2181    })?;
2182
2183    let mut update = proto::UpdateChannels::default();
2184    update.channels.push(channel);
2185
2186    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2187        db.get_channel_members(parent_id).await?
2188    } else {
2189        vec![session.user_id]
2190    };
2191
2192    let connection_pool = session.connection_pool().await;
2193    for user_id in user_ids_to_notify {
2194        for connection_id in connection_pool.user_connection_ids(user_id) {
2195            let mut update = update.clone();
2196            if user_id == session.user_id {
2197                update.channel_permissions.push(proto::ChannelPermission {
2198                    channel_id: id.to_proto(),
2199                    is_admin: true,
2200                });
2201            }
2202            session.peer.send(connection_id, update)?;
2203        }
2204    }
2205
2206    Ok(())
2207}
2208
2209async fn remove_channel(
2210    request: proto::RemoveChannel,
2211    response: Response<proto::RemoveChannel>,
2212    session: Session,
2213) -> Result<()> {
2214    let db = session.db().await;
2215
2216    let channel_id = request.channel_id;
2217    let (removed_channels, member_ids) = db
2218        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2219        .await?;
2220    response.send(proto::Ack {})?;
2221
2222    // Notify members of removed channels
2223    let mut update = proto::UpdateChannels::default();
2224    update
2225        .remove_channels
2226        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2227
2228    let connection_pool = session.connection_pool().await;
2229    for member_id in member_ids {
2230        for connection_id in connection_pool.user_connection_ids(member_id) {
2231            session.peer.send(connection_id, update.clone())?;
2232        }
2233    }
2234
2235    Ok(())
2236}
2237
2238async fn invite_channel_member(
2239    request: proto::InviteChannelMember,
2240    response: Response<proto::InviteChannelMember>,
2241    session: Session,
2242) -> Result<()> {
2243    let db = session.db().await;
2244    let channel_id = ChannelId::from_proto(request.channel_id);
2245    let invitee_id = UserId::from_proto(request.user_id);
2246    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2247        .await?;
2248
2249    let (channel, _) = db
2250        .get_channel(channel_id, session.user_id)
2251        .await?
2252        .ok_or_else(|| anyhow!("channel not found"))?;
2253
2254    let mut update = proto::UpdateChannels::default();
2255    update.channel_invitations.push(proto::Channel {
2256        id: channel.id.to_proto(),
2257        name: channel.name,
2258        parent_id: None,
2259    });
2260    for connection_id in session
2261        .connection_pool()
2262        .await
2263        .user_connection_ids(invitee_id)
2264    {
2265        session.peer.send(connection_id, update.clone())?;
2266    }
2267
2268    response.send(proto::Ack {})?;
2269    Ok(())
2270}
2271
2272async fn remove_channel_member(
2273    request: proto::RemoveChannelMember,
2274    response: Response<proto::RemoveChannelMember>,
2275    session: Session,
2276) -> Result<()> {
2277    let db = session.db().await;
2278    let channel_id = ChannelId::from_proto(request.channel_id);
2279    let member_id = UserId::from_proto(request.user_id);
2280
2281    db.remove_channel_member(channel_id, member_id, session.user_id)
2282        .await?;
2283
2284    let mut update = proto::UpdateChannels::default();
2285    update.remove_channels.push(channel_id.to_proto());
2286
2287    for connection_id in session
2288        .connection_pool()
2289        .await
2290        .user_connection_ids(member_id)
2291    {
2292        session.peer.send(connection_id, update.clone())?;
2293    }
2294
2295    response.send(proto::Ack {})?;
2296    Ok(())
2297}
2298
2299async fn set_channel_member_admin(
2300    request: proto::SetChannelMemberAdmin,
2301    response: Response<proto::SetChannelMemberAdmin>,
2302    session: Session,
2303) -> Result<()> {
2304    let db = session.db().await;
2305    let channel_id = ChannelId::from_proto(request.channel_id);
2306    let member_id = UserId::from_proto(request.user_id);
2307    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2308        .await?;
2309
2310    let (channel, has_accepted) = db
2311        .get_channel(channel_id, member_id)
2312        .await?
2313        .ok_or_else(|| anyhow!("channel not found"))?;
2314
2315    let mut update = proto::UpdateChannels::default();
2316    if has_accepted {
2317        update.channel_permissions.push(proto::ChannelPermission {
2318            channel_id: channel.id.to_proto(),
2319            is_admin: request.admin,
2320        });
2321    }
2322
2323    for connection_id in session
2324        .connection_pool()
2325        .await
2326        .user_connection_ids(member_id)
2327    {
2328        session.peer.send(connection_id, update.clone())?;
2329    }
2330
2331    response.send(proto::Ack {})?;
2332    Ok(())
2333}
2334
2335async fn rename_channel(
2336    request: proto::RenameChannel,
2337    response: Response<proto::RenameChannel>,
2338    session: Session,
2339) -> Result<()> {
2340    let db = session.db().await;
2341    let channel_id = ChannelId::from_proto(request.channel_id);
2342    let new_name = db
2343        .rename_channel(channel_id, session.user_id, &request.name)
2344        .await?;
2345
2346    let channel = proto::Channel {
2347        id: request.channel_id,
2348        name: new_name,
2349        parent_id: None,
2350    };
2351    response.send(proto::ChannelResponse {
2352        channel: Some(channel.clone()),
2353    })?;
2354    let mut update = proto::UpdateChannels::default();
2355    update.channels.push(channel);
2356
2357    let member_ids = db.get_channel_members(channel_id).await?;
2358
2359    let connection_pool = session.connection_pool().await;
2360    for member_id in member_ids {
2361        for connection_id in connection_pool.user_connection_ids(member_id) {
2362            session.peer.send(connection_id, update.clone())?;
2363        }
2364    }
2365
2366    Ok(())
2367}
2368
2369async fn get_channel_members(
2370    request: proto::GetChannelMembers,
2371    response: Response<proto::GetChannelMembers>,
2372    session: Session,
2373) -> Result<()> {
2374    let db = session.db().await;
2375    let channel_id = ChannelId::from_proto(request.channel_id);
2376    let members = db
2377        .get_channel_member_details(channel_id, session.user_id)
2378        .await?;
2379    response.send(proto::GetChannelMembersResponse { members })?;
2380    Ok(())
2381}
2382
2383async fn respond_to_channel_invite(
2384    request: proto::RespondToChannelInvite,
2385    response: Response<proto::RespondToChannelInvite>,
2386    session: Session,
2387) -> Result<()> {
2388    let db = session.db().await;
2389    let channel_id = ChannelId::from_proto(request.channel_id);
2390    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2391        .await?;
2392
2393    let mut update = proto::UpdateChannels::default();
2394    update
2395        .remove_channel_invitations
2396        .push(channel_id.to_proto());
2397    if request.accept {
2398        let result = db.get_channels_for_user(session.user_id).await?;
2399        update
2400            .channels
2401            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2402                id: channel.id.to_proto(),
2403                name: channel.name,
2404                parent_id: channel.parent_id.map(ChannelId::to_proto),
2405            }));
2406        update
2407            .channel_participants
2408            .extend(
2409                result
2410                    .channel_participants
2411                    .into_iter()
2412                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2413                        channel_id: channel_id.to_proto(),
2414                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2415                    }),
2416            );
2417        update
2418            .channel_permissions
2419            .extend(
2420                result
2421                    .channels_with_admin_privileges
2422                    .into_iter()
2423                    .map(|channel_id| proto::ChannelPermission {
2424                        channel_id: channel_id.to_proto(),
2425                        is_admin: true,
2426                    }),
2427            );
2428    }
2429    session.peer.send(session.connection_id, update)?;
2430    response.send(proto::Ack {})?;
2431
2432    Ok(())
2433}
2434
/// Moves the session's user into the room associated with `request.channel_id`.
///
/// In order, it:
/// 1. leaves any room the session is currently in (`leave_room_for_session`);
/// 2. looks up the channel's room and joins it;
/// 3. replies with the room, its channel id, and LiveKit connection info (only
///    when a LiveKit client is configured and a token could be issued);
/// 4. broadcasts the room update (`room_updated`);
/// 5. calls `channel_updated` with the joined room's channel members;
/// 6. refreshes the user's contacts (`update_user_contacts`).
///
/// # Errors
/// Fails if leaving the previous room, any database call, sending the reply,
/// or updating contacts fails. A LiveKit token failure does not fail the join.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The database guard `db` acquired in this block is dropped when the block
    // ends, before the connection pool is locked for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // `trace_err()` turns a token error into `None`, so the reply simply
        // omits LiveKit info instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Extract the joined-room data from the wrapper returned by
        // `join_room`. NOTE(review): presumably this also releases a room
        // guard held by that wrapper; confirm.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2489
2490async fn join_channel_buffer(
2491    request: proto::JoinChannelBuffer,
2492    response: Response<proto::JoinChannelBuffer>,
2493    session: Session,
2494) -> Result<()> {
2495    let db = session.db().await;
2496    let channel_id = ChannelId::from_proto(request.channel_id);
2497
2498    let open_response = db
2499        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2500        .await?;
2501
2502    let replica_id = open_response.replica_id;
2503    let collaborators = open_response.collaborators.clone();
2504
2505    response.send(open_response)?;
2506
2507    let update = AddChannelBufferCollaborator {
2508        channel_id: channel_id.to_proto(),
2509        collaborator: Some(proto::Collaborator {
2510            user_id: session.user_id.to_proto(),
2511            peer_id: Some(session.connection_id.into()),
2512            replica_id,
2513        }),
2514    };
2515    channel_buffer_updated(
2516        session.connection_id,
2517        collaborators
2518            .iter()
2519            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2520        &update,
2521        &session.peer,
2522    );
2523
2524    Ok(())
2525}
2526
2527async fn update_channel_buffer(
2528    request: proto::UpdateChannelBuffer,
2529    session: Session,
2530) -> Result<()> {
2531    let db = session.db().await;
2532    let channel_id = ChannelId::from_proto(request.channel_id);
2533
2534    let collaborators = db
2535        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2536        .await?;
2537
2538    channel_buffer_updated(
2539        session.connection_id,
2540        collaborators,
2541        &proto::UpdateChannelBuffer {
2542            channel_id: channel_id.to_proto(),
2543            operations: request.operations,
2544        },
2545        &session.peer,
2546    );
2547    Ok(())
2548}
2549
2550async fn leave_channel_buffer(
2551    request: proto::LeaveChannelBuffer,
2552    response: Response<proto::LeaveChannelBuffer>,
2553    session: Session,
2554) -> Result<()> {
2555    let db = session.db().await;
2556    let channel_id = ChannelId::from_proto(request.channel_id);
2557
2558    let collaborators_to_notify = db
2559        .leave_channel_buffer(channel_id, session.connection_id)
2560        .await?;
2561
2562    response.send(Ack {})?;
2563
2564    channel_buffer_updated(
2565        session.connection_id,
2566        collaborators_to_notify,
2567        &proto::RemoveChannelBufferCollaborator {
2568            channel_id: channel_id.to_proto(),
2569            peer_id: Some(session.connection_id.into()),
2570        },
2571        &session.peer,
2572    );
2573
2574    Ok(())
2575}
2576
2577fn channel_buffer_updated<T: EnvelopedMessage>(
2578    sender_id: ConnectionId,
2579    collaborators: impl IntoIterator<Item = ConnectionId>,
2580    message: &T,
2581    peer: &Peer,
2582) {
2583    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2584        peer.send(peer_id.into(), message.clone())
2585    });
2586}
2587
2588async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2589    let project_id = ProjectId::from_proto(request.project_id);
2590    let project_connection_ids = session
2591        .db()
2592        .await
2593        .project_connection_ids(project_id, session.connection_id)
2594        .await?;
2595    broadcast(
2596        Some(session.connection_id),
2597        project_connection_ids.iter().copied(),
2598        |connection_id| {
2599            session
2600                .peer
2601                .forward_send(session.connection_id, connection_id, request.clone())
2602        },
2603    );
2604    Ok(())
2605}
2606
2607async fn get_private_user_info(
2608    _request: proto::GetPrivateUserInfo,
2609    response: Response<proto::GetPrivateUserInfo>,
2610    session: Session,
2611) -> Result<()> {
2612    let db = session.db().await;
2613
2614    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2615    let user = db
2616        .get_user_by_id(session.user_id)
2617        .await?
2618        .ok_or_else(|| anyhow!("user not found"))?;
2619    let flags = db.get_user_flags(session.user_id).await?;
2620
2621    response.send(proto::GetPrivateUserInfoResponse {
2622        metrics_id,
2623        staff: user.admin,
2624        flags,
2625    })?;
2626    Ok(())
2627}
2628
2629fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2630    match message {
2631        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2632        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2633        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2634        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2635        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2636            code: frame.code.into(),
2637            reason: frame.reason,
2638        })),
2639    }
2640}
2641
2642fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2643    match message {
2644        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2645        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2646        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2647        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2648        AxumMessage::Close(frame) => {
2649            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2650                code: frame.code.into(),
2651                reason: frame.reason,
2652            }))
2653        }
2654    }
2655}
2656
2657fn build_initial_channels_update(
2658    channels: ChannelsForUser,
2659    channel_invites: Vec<db::Channel>,
2660) -> proto::UpdateChannels {
2661    let mut update = proto::UpdateChannels::default();
2662
2663    for channel in channels.channels {
2664        update.channels.push(proto::Channel {
2665            id: channel.id.to_proto(),
2666            name: channel.name,
2667            parent_id: channel.parent_id.map(|id| id.to_proto()),
2668        });
2669    }
2670
2671    for (channel_id, participants) in channels.channel_participants {
2672        update
2673            .channel_participants
2674            .push(proto::ChannelParticipants {
2675                channel_id: channel_id.to_proto(),
2676                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2677            });
2678    }
2679
2680    update
2681        .channel_permissions
2682        .extend(
2683            channels
2684                .channels_with_admin_privileges
2685                .into_iter()
2686                .map(|id| proto::ChannelPermission {
2687                    channel_id: id.to_proto(),
2688                    is_admin: true,
2689                }),
2690        );
2691
2692    for channel in channel_invites {
2693        update.channel_invitations.push(proto::Channel {
2694            id: channel.id.to_proto(),
2695            name: channel.name,
2696            parent_id: None,
2697        });
2698    }
2699
2700    update
2701}
2702
2703fn build_initial_contacts_update(
2704    contacts: Vec<db::Contact>,
2705    pool: &ConnectionPool,
2706) -> proto::UpdateContacts {
2707    let mut update = proto::UpdateContacts::default();
2708
2709    for contact in contacts {
2710        match contact {
2711            db::Contact::Accepted {
2712                user_id,
2713                should_notify,
2714                busy,
2715            } => {
2716                update
2717                    .contacts
2718                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2719            }
2720            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2721            db::Contact::Incoming {
2722                user_id,
2723                should_notify,
2724            } => update
2725                .incoming_requests
2726                .push(proto::IncomingContactRequest {
2727                    requester_id: user_id.to_proto(),
2728                    should_notify,
2729                }),
2730        }
2731    }
2732
2733    update
2734}
2735
2736fn contact_for_user(
2737    user_id: UserId,
2738    should_notify: bool,
2739    busy: bool,
2740    pool: &ConnectionPool,
2741) -> proto::Contact {
2742    proto::Contact {
2743        user_id: user_id.to_proto(),
2744        online: pool.is_user_online(user_id),
2745        busy,
2746        should_notify,
2747    }
2748}
2749
2750fn room_updated(room: &proto::Room, peer: &Peer) {
2751    broadcast(
2752        None,
2753        room.participants
2754            .iter()
2755            .filter_map(|participant| Some(participant.peer_id?.into())),
2756        |peer_id| {
2757            peer.send(
2758                peer_id.into(),
2759                proto::RoomUpdated {
2760                    room: Some(room.clone()),
2761                },
2762            )
2763        },
2764    );
2765}
2766
2767fn channel_updated(
2768    channel_id: ChannelId,
2769    room: &proto::Room,
2770    channel_members: &[UserId],
2771    peer: &Peer,
2772    pool: &ConnectionPool,
2773) {
2774    let participants = room
2775        .participants
2776        .iter()
2777        .map(|p| p.user_id)
2778        .collect::<Vec<_>>();
2779
2780    broadcast(
2781        None,
2782        channel_members
2783            .iter()
2784            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2785        |peer_id| {
2786            peer.send(
2787                peer_id.into(),
2788                proto::UpdateChannels {
2789                    channel_participants: vec![proto::ChannelParticipants {
2790                        channel_id: channel_id.to_proto(),
2791                        participant_user_ids: participants.clone(),
2792                    }],
2793                    ..Default::default()
2794                },
2795            )
2796        },
2797    );
2798}
2799
2800async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2801    let db = session.db().await;
2802
2803    let contacts = db.get_contacts(user_id).await?;
2804    let busy = db.is_user_busy(user_id).await?;
2805
2806    let pool = session.connection_pool().await;
2807    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2808    for contact in contacts {
2809        if let db::Contact::Accepted {
2810            user_id: contact_user_id,
2811            ..
2812        } = contact
2813        {
2814            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2815                session
2816                    .peer
2817                    .send(
2818                        contact_conn_id,
2819                        proto::UpdateContacts {
2820                            contacts: vec![updated_contact.clone()],
2821                            remove_contacts: Default::default(),
2822                            incoming_requests: Default::default(),
2823                            remove_incoming_requests: Default::default(),
2824                            outgoing_requests: Default::default(),
2825                            remove_outgoing_requests: Default::default(),
2826                        },
2827                    )
2828                    .trace_err();
2829            }
2830        }
2831    }
2832    Ok(())
2833}
2834
2835async fn leave_room_for_session(session: &Session) -> Result<()> {
2836    let mut contacts_to_update = HashSet::default();
2837
2838    let room_id;
2839    let canceled_calls_to_user_ids;
2840    let live_kit_room;
2841    let delete_live_kit_room;
2842    let room;
2843    let channel_members;
2844    let channel_id;
2845
2846    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2847        contacts_to_update.insert(session.user_id);
2848
2849        for project in left_room.left_projects.values() {
2850            project_left(project, session);
2851        }
2852
2853        room_id = RoomId::from_proto(left_room.room.id);
2854        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2855        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2856        delete_live_kit_room = left_room.deleted;
2857        room = mem::take(&mut left_room.room);
2858        channel_members = mem::take(&mut left_room.channel_members);
2859        channel_id = left_room.channel_id;
2860
2861        room_updated(&room, &session.peer);
2862    } else {
2863        return Ok(());
2864    }
2865
2866    if let Some(channel_id) = channel_id {
2867        channel_updated(
2868            channel_id,
2869            &room,
2870            &channel_members,
2871            &session.peer,
2872            &*session.connection_pool().await,
2873        );
2874    }
2875
2876    {
2877        let pool = session.connection_pool().await;
2878        for canceled_user_id in canceled_calls_to_user_ids {
2879            for connection_id in pool.user_connection_ids(canceled_user_id) {
2880                session
2881                    .peer
2882                    .send(
2883                        connection_id,
2884                        proto::CallCanceled {
2885                            room_id: room_id.to_proto(),
2886                        },
2887                    )
2888                    .trace_err();
2889            }
2890            contacts_to_update.insert(canceled_user_id);
2891        }
2892    }
2893
2894    for contact_user_id in contacts_to_update {
2895        update_user_contacts(contact_user_id, &session).await?;
2896    }
2897
2898    if let Some(live_kit) = session.live_kit_client.as_ref() {
2899        live_kit
2900            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2901            .await
2902            .trace_err();
2903
2904        if delete_live_kit_room {
2905            live_kit.delete_room(live_kit_room).await.trace_err();
2906        }
2907    }
2908
2909    Ok(())
2910}
2911
2912async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2913    let left_channel_buffers = session
2914        .db()
2915        .await
2916        .leave_channel_buffers(session.connection_id)
2917        .await?;
2918
2919    for (channel_id, connections) in left_channel_buffers {
2920        channel_buffer_updated(
2921            session.connection_id,
2922            connections,
2923            &proto::RemoveChannelBufferCollaborator {
2924                channel_id: channel_id.to_proto(),
2925                peer_id: Some(session.connection_id.into()),
2926            },
2927            &session.peer,
2928        );
2929    }
2930
2931    Ok(())
2932}
2933
2934fn project_left(project: &db::LeftProject, session: &Session) {
2935    for connection_id in &project.connection_ids {
2936        if project.host_user_id == session.user_id {
2937            session
2938                .peer
2939                .send(
2940                    *connection_id,
2941                    proto::UnshareProject {
2942                        project_id: project.id.to_proto(),
2943                    },
2944                )
2945                .trace_err();
2946        } else {
2947            session
2948                .peer
2949                .send(
2950                    *connection_id,
2951                    proto::RemoveProjectCollaborator {
2952                        project_id: project.id.to_proto(),
2953                        peer_id: Some(session.connection_id.into()),
2954                    },
2955                )
2956                .trace_err();
2957        }
2958    }
2959}
2960
2961pub trait ResultExt {
2962    type Ok;
2963
2964    fn trace_err(self) -> Option<Self::Ok>;
2965}
2966
2967impl<T, E> ResultExt for Result<T, E>
2968where
2969    E: std::fmt::Debug,
2970{
2971    type Ok = T;
2972
2973    fn trace_err(self) -> Option<T> {
2974        match self {
2975            Ok(value) => Some(value),
2976            Err(error) => {
2977                tracing::error!("{:?}", error);
2978                None
2979            }
2980        }
2981    }
2982}