rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{
  38        self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, GetChannelBufferResponse,
  39        LiveKitConnectionInfo, RequestMessage,
  40    },
  41    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  42};
  43use serde::{Serialize, Serializer};
  44use std::{
  45    any::TypeId,
  46    fmt,
  47    future::Future,
  48    marker::PhantomData,
  49    mem,
  50    net::SocketAddr,
  51    ops::{Deref, DerefMut},
  52    rc::Rc,
  53    sync::{
  54        atomic::{AtomicBool, Ordering::SeqCst},
  55        Arc,
  56    },
  57    time::{Duration, Instant},
  58};
  59use tokio::sync::{watch, Semaphore};
  60use tower::ServiceBuilder;
  61use tracing::{info_span, instrument, Instrument};
  62
  63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  65
  66lazy_static! {
  67    static ref METRIC_CONNECTIONS: IntGauge =
  68        register_int_gauge!("connections", "number of connections").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
/// Per-connection context passed to every message handler.
///
/// `handle_connection` builds one `Session` per connection and clones it for each incoming
/// message, so all clones share the same `db` mutex. As a result, database access through
/// [`Session::db`] is serialized per connection.
#[derive(Clone)]
struct Session {
    /// The user authenticated on this connection.
    user_id: UserId,
    connection_id: ConnectionId,
    /// Per-connection database handle behind an async mutex; acquire via [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide connection pool (shared with `Server::connection_pool`); lock via
    /// [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured in the app state.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
/// The collaboration RPC server.
///
/// It owns the [`Peer`], the shared [`ConnectionPool`], and the table of message handlers
/// registered in [`Server::new`].
pub struct Server {
    /// Identifier of this server instance. It sits behind a mutex so test builds can change
    /// it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by [`Server::teardown`]; every `handle_connection` loop subscribes to it.
    teardown: watch::Sender<()>,
}

/// Lock guard over the [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Handler futures must be `Send`
/// (see `add_handler`), so this keeps a handler from holding the guard across an `.await`.
/// In test builds the pool's invariants are checked whenever the guard is dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// It holds the pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the pool itself (through the guard's `Deref`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(()).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_message_handler(refresh_inlay_hints)
 208            .add_request_handler(forward_project_request::<proto::GetHover>)
 209            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 210            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 211            .add_request_handler(forward_project_request::<proto::GetReferences>)
 212            .add_request_handler(forward_project_request::<proto::SearchProject>)
 213            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 214            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 215            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 216            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 217            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 218            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 219            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 220            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 221            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 222            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 223            .add_request_handler(forward_project_request::<proto::PerformRename>)
 224            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 225            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 226            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 227            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 230            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 231            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 232            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 233            .add_request_handler(forward_project_request::<proto::InlayHints>)
 234            .add_message_handler(create_buffer_for_peer)
 235            .add_request_handler(update_buffer)
 236            .add_message_handler(update_buffer_file)
 237            .add_message_handler(buffer_reloaded)
 238            .add_message_handler(buffer_saved)
 239            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 240            .add_request_handler(get_users)
 241            .add_request_handler(fuzzy_search_users)
 242            .add_request_handler(request_contact)
 243            .add_request_handler(remove_contact)
 244            .add_request_handler(respond_to_contact_request)
 245            .add_request_handler(create_channel)
 246            .add_request_handler(remove_channel)
 247            .add_request_handler(invite_channel_member)
 248            .add_request_handler(remove_channel_member)
 249            .add_request_handler(set_channel_member_admin)
 250            .add_request_handler(rename_channel)
 251            .add_request_handler(get_channel_buffer)
 252            .add_request_handler(get_channel_members)
 253            .add_request_handler(respond_to_channel_invite)
 254            .add_request_handler(join_channel)
 255            .add_request_handler(follow)
 256            .add_message_handler(unfollow)
 257            .add_message_handler(update_followers)
 258            .add_message_handler(update_diff_base)
 259            .add_request_handler(get_private_user_info);
 260
 261        Arc::new(server)
 262    }
 263
    /// Schedules the post-startup cleanup task and returns immediately; it always returns `Ok`.
    ///
    /// The detached task waits for [`CLEANUP_TIMEOUT`] and then asks the database for the
    /// stale room ids of this environment / server id. For each room it:
    /// - refreshes the room and broadcasts the update (plus a channel update if the room
    ///   belongs to a channel);
    /// - sends `CallCanceled` to every connection of the users whose calls were canceled;
    /// - pushes an updated contact entry to the accepted contacts of affected users;
    /// - deletes the LiveKit room if it has no participants left and a LiveKit client is
    ///   configured.
    ///
    /// Finally, the task deletes stale server records. Every failure is logged via
    /// `trace_err` and skipped rather than aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the task: the delay starts counting from this call.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Defaults used when `refresh_room` fails: nothing to notify and no
                        // LiveKit room deletion.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were canceled
                            // get their contact entries refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Synchronous section: the pool lock is released before the next
                        // `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries must succeed for this user to be broadcast.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after the awaits above; nothing below awaits.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching stale rooms failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 388
 389    pub fn teardown(&self) {
 390        self.peer.teardown();
 391        self.connection_pool.lock().reset();
 392        let _ = self.teardown.send(());
 393    }
 394
 395    #[cfg(test)]
 396    pub fn reset(&self, id: ServerId) {
 397        self.teardown();
 398        *self.id.lock() = id;
 399        self.peer.reset(id.0 as u32);
 400    }
 401
 402    #[cfg(test)]
 403    pub fn id(&self) -> ServerId {
 404        *self.id.lock()
 405    }
 406
 407    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 408    where
 409        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 410        Fut: 'static + Send + Future<Output = Result<()>>,
 411        M: EnvelopedMessage,
 412    {
 413        let prev_handler = self.handlers.insert(
 414            TypeId::of::<M>(),
 415            Box::new(move |envelope, session| {
 416                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 417                let span = info_span!(
 418                    "handle message",
 419                    payload_type = envelope.payload_type_name()
 420                );
 421                span.in_scope(|| {
 422                    tracing::info!(
 423                        payload_type = envelope.payload_type_name(),
 424                        "message received"
 425                    );
 426                });
 427                let start_time = Instant::now();
 428                let future = (handler)(*envelope, session);
 429                async move {
 430                    let result = future.await;
 431                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 432                    match result {
 433                        Err(error) => {
 434                            tracing::error!(%error, ?duration_ms, "error handling message")
 435                        }
 436                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 437                    }
 438                }
 439                .instrument(span)
 440                .boxed()
 441            }),
 442        );
 443        if prev_handler.is_some() {
 444            panic!("registered a handler for the same message twice");
 445        }
 446        self
 447    }
 448
 449    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 450    where
 451        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 452        Fut: 'static + Send + Future<Output = Result<()>>,
 453        M: EnvelopedMessage,
 454    {
 455        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 456        self
 457    }
 458
 459    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 460    where
 461        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 462        Fut: Send + Future<Output = Result<()>>,
 463        M: RequestMessage,
 464    {
 465        let handler = Arc::new(handler);
 466        self.add_handler(move |envelope, session| {
 467            let receipt = envelope.receipt();
 468            let handler = handler.clone();
 469            async move {
 470                let peer = session.peer.clone();
 471                let responded = Arc::new(AtomicBool::default());
 472                let response = Response {
 473                    peer: peer.clone(),
 474                    responded: responded.clone(),
 475                    receipt,
 476                };
 477                match (handler)(envelope.payload, response, session).await {
 478                    Ok(()) => {
 479                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 480                            Ok(())
 481                        } else {
 482                            Err(anyhow!("handler did not send a response"))?
 483                        }
 484                    }
 485                    Err(error) => {
 486                        peer.respond_with_error(
 487                            receipt,
 488                            proto::Error {
 489                                message: error.to_string(),
 490                            },
 491                        )?;
 492                        Err(error)
 493                    }
 494                }
 495            }
 496        })
 497    }
 498
    /// Serves one authenticated client connection until it closes, its I/O fails, or the
    /// server is torn down.
    ///
    /// Setup, in order:
    /// 1. Register the connection with the peer and send `Hello`.
    /// 2. Report the new `ConnectionId` through `send_connection_id`, if provided.
    /// 3. On the user's first-ever connection, send `ShowContacts` and record it.
    /// 4. Fetch contacts, invite code, channels and channel invites concurrently.
    /// 5. Add the connection to the pool and send the initial contacts, channels and
    ///    invite-info updates.
    /// 6. Forward any pending incoming call, then update the user's contacts.
    ///
    /// After setup it dispatches incoming messages to the registered handlers. When the loop
    /// ends through I/O completion or connection close, it calls `connection_lost` (errors
    /// there are logged) and returns `Ok(())`. On teardown it returns `Ok(())` immediately,
    /// without `connection_lost`; `Server::teardown` has already reset the whole pool.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future runs, so a later `Server::teardown` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; that is not an error here.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // NOTE(review): any `?` failure from `pool.add_connection` below until the message
            // loop returns early without reaching `connection_lost`. The connection id may then
            // stay registered in the pool (and the peer). Confirm whether cleanup happens
            // elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // One database mutex per connection, shared by every clone of this session.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before each message is read and released only when its handler
            // finishes. This caps in-flight handlers (foreground and background) for this
            // connection at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased select: branches are polled in the order listed (teardown, I/O,
                // progress on foreground handlers, then reading the next message).
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Keep the permit alive for the handler's full duration.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are cancelled here; background ones were spawned
            // detached and keep running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 641
 642    pub async fn invite_code_redeemed(
 643        self: &Arc<Self>,
 644        inviter_id: UserId,
 645        invitee_id: UserId,
 646    ) -> Result<()> {
 647        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 648            if let Some(code) = &user.invite_code {
 649                let pool = self.connection_pool.lock();
 650                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 651                for connection_id in pool.user_connection_ids(inviter_id) {
 652                    self.peer.send(
 653                        connection_id,
 654                        proto::UpdateContacts {
 655                            contacts: vec![invitee_contact.clone()],
 656                            ..Default::default()
 657                        },
 658                    )?;
 659                    self.peer.send(
 660                        connection_id,
 661                        proto::UpdateInviteInfo {
 662                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 663                            count: user.invite_count as u32,
 664                        },
 665                    )?;
 666                }
 667            }
 668        }
 669        Ok(())
 670    }
 671
 672    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 673        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 674            if let Some(invite_code) = &user.invite_code {
 675                let pool = self.connection_pool.lock();
 676                for connection_id in pool.user_connection_ids(user_id) {
 677                    self.peer.send(
 678                        connection_id,
 679                        proto::UpdateInviteInfo {
 680                            url: format!(
 681                                "{}{}",
 682                                self.app_state.config.invite_link_prefix, invite_code
 683                            ),
 684                            count: user.invite_count as u32,
 685                        },
 686                    )?;
 687                }
 688            }
 689        }
 690        Ok(())
 691    }
 692
 693    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 694        ServerSnapshot {
 695            connection_pool: ConnectionPoolGuard {
 696                guard: self.connection_pool.lock(),
 697                _not_send: PhantomData,
 698            },
 699            peer: &self.peer,
 700        }
 701    }
 702}
 703
 704impl<'a> Deref for ConnectionPoolGuard<'a> {
 705    type Target = ConnectionPool;
 706
 707    fn deref(&self) -> &Self::Target {
 708        &*self.guard
 709    }
 710}
 711
 712impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 713    fn deref_mut(&mut self) -> &mut Self::Target {
 714        &mut *self.guard
 715    }
 716}
 717
 718impl<'a> Drop for ConnectionPoolGuard<'a> {
 719    fn drop(&mut self) {
 720        #[cfg(test)]
 721        self.check_invariants();
 722    }
 723}
 724
 725fn broadcast<F>(
 726    sender_id: Option<ConnectionId>,
 727    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 728    mut f: F,
 729) where
 730    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 731{
 732    for receiver_id in receiver_ids {
 733        if Some(receiver_id) != sender_id {
 734            if let Err(error) = f(receiver_id) {
 735                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 736            }
 737        }
 738    }
 739}
 740
 741lazy_static! {
 742    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 743}
 744
 745pub struct ProtocolVersion(u32);
 746
 747impl Header for ProtocolVersion {
 748    fn name() -> &'static HeaderName {
 749        &ZED_PROTOCOL_VERSION
 750    }
 751
 752    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 753    where
 754        Self: Sized,
 755        I: Iterator<Item = &'i axum::http::HeaderValue>,
 756    {
 757        let version = values
 758            .next()
 759            .ok_or_else(axum::headers::Error::invalid)?
 760            .to_str()
 761            .map_err(|_| axum::headers::Error::invalid())?
 762            .parse()
 763            .map_err(|_| axum::headers::Error::invalid())?;
 764        Ok(Self(version))
 765    }
 766
 767    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 768        values.extend([self.0.to_string().parse().unwrap()]);
 769    }
 770}
 771
 772pub fn routes(server: Arc<Server>) -> Router<Body> {
 773    Router::new()
 774        .route("/rpc", get(handle_websocket_request))
 775        .layer(
 776            ServiceBuilder::new()
 777                .layer(Extension(server.app_state.clone()))
 778                .layer(middleware::from_fn(auth::validate_header)),
 779        )
 780        .route("/metrics", get(handle_metrics))
 781        .layer(Extension(server))
 782}
 783
 784pub async fn handle_websocket_request(
 785    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 786    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 787    Extension(server): Extension<Arc<Server>>,
 788    Extension(user): Extension<User>,
 789    ws: WebSocketUpgrade,
 790) -> axum::response::Response {
 791    if protocol_version != rpc::PROTOCOL_VERSION {
 792        return (
 793            StatusCode::UPGRADE_REQUIRED,
 794            "client must be upgraded".to_string(),
 795        )
 796            .into_response();
 797    }
 798    let socket_address = socket_address.to_string();
 799    ws.on_upgrade(move |socket| {
 800        use util::ResultExt;
 801        let socket = socket
 802            .map_ok(to_tungstenite_message)
 803            .err_into()
 804            .with(|message| async move { Ok(to_axum_message(message)) });
 805        let connection = Connection::new(Box::pin(socket));
 806        async move {
 807            server
 808                .handle_connection(connection, socket_address, user, None, Executor::Production)
 809                .await
 810                .log_err();
 811        }
 812    })
 813}
 814
 815pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 816    let connections = server
 817        .connection_pool
 818        .lock()
 819        .connections()
 820        .filter(|connection| !connection.admin)
 821        .count();
 822
 823    METRIC_CONNECTIONS.set(connections as _);
 824
 825    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 826    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 827
 828    let encoder = prometheus::TextEncoder::new();
 829    let metric_families = prometheus::gather();
 830    let encoded_metrics = encoder
 831        .encode_to_string(&metric_families)
 832        .map_err(|err| anyhow!("{}", err))?;
 833    Ok(encoded_metrics)
 834}
 835
/// Cleans up after a connection drops, giving the user a grace period to reconnect.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool (an error here is returned);
/// - report the lost connection to the database (errors are only traced).
///
/// It then waits for the first of two events. `select_biased!` polls the arms in the
/// order written, so the timeout arm wins when both are ready.
/// - `RECONNECT_TIMEOUT` elapses:
///   - leave the session's room;
///   - if the user has no connection left in the pool, call `decline_call(None, ..)`
///     and broadcast the returned room, if any;
///   - refresh the user's contacts.
/// - `teardown.changed()` resolves: return without further cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The grace period expired before any teardown signal: finish cleaning up.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only touch pending calls when no other connection for this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled first: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 876
 877async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 878    response.send(proto::Ack {})?;
 879    Ok(())
 880}
 881
/// Creates a new room for the session's user and replies with it.
///
/// A 30-character LiveKit room name is generated with `nanoid!`. If a LiveKit client is
/// configured, the function creates that LiveKit room and a token for this user. Any
/// failure in that step is traced and leaves `live_kit_connection_info` as `None`.
/// The database room is created either way, the response is sent, and the user's
/// contacts are refreshed.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    let live_kit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = {
        // A separate copy for the LiveKit step; the outer `live_kit_room` is still used
        // by the database call below. NOTE(review): presumably needed because
        // `async_iife!` moves what it captures — confirm against the macro definition.
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        util::async_iife!({
            // No LiveKit client configured: no connection info.
            let live_kit = live_kit?;

            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
 927
 928async fn join_room(
 929    request: proto::JoinRoom,
 930    response: Response<proto::JoinRoom>,
 931    session: Session,
 932) -> Result<()> {
 933    let room_id = RoomId::from_proto(request.id);
 934    let joined_room = {
 935        let room = session
 936            .db()
 937            .await
 938            .join_room(room_id, session.user_id, session.connection_id)
 939            .await?;
 940        room_updated(&room.room, &session.peer);
 941        room.into_inner()
 942    };
 943
 944    if let Some(channel_id) = joined_room.channel_id {
 945        channel_updated(
 946            channel_id,
 947            &joined_room.room,
 948            &joined_room.channel_members,
 949            &session.peer,
 950            &*session.connection_pool().await,
 951        )
 952    }
 953
 954    for connection_id in session
 955        .connection_pool()
 956        .await
 957        .user_connection_ids(session.user_id)
 958    {
 959        session
 960            .peer
 961            .send(
 962                connection_id,
 963                proto::CallCanceled {
 964                    room_id: room_id.to_proto(),
 965                },
 966            )
 967            .trace_err();
 968    }
 969
 970    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 971        if let Some(token) = live_kit
 972            .room_token(
 973                &joined_room.room.live_kit_room,
 974                &session.user_id.to_string(),
 975            )
 976            .trace_err()
 977        {
 978            Some(proto::LiveKitConnectionInfo {
 979                server_url: live_kit.url().into(),
 980                token,
 981            })
 982        } else {
 983            None
 984        }
 985    } else {
 986        None
 987    };
 988
 989    response.send(proto::JoinRoomResponse {
 990        room: Some(joined_room.room),
 991        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
 992        live_kit_connection_info,
 993    })?;
 994
 995    update_user_contacts(session.user_id, &session).await?;
 996    Ok(())
 997}
 998
 999async fn rejoin_room(
1000    request: proto::RejoinRoom,
1001    response: Response<proto::RejoinRoom>,
1002    session: Session,
1003) -> Result<()> {
1004    let room;
1005    let channel_id;
1006    let channel_members;
1007    {
1008        let mut rejoined_room = session
1009            .db()
1010            .await
1011            .rejoin_room(request, session.user_id, session.connection_id)
1012            .await?;
1013
1014        response.send(proto::RejoinRoomResponse {
1015            room: Some(rejoined_room.room.clone()),
1016            reshared_projects: rejoined_room
1017                .reshared_projects
1018                .iter()
1019                .map(|project| proto::ResharedProject {
1020                    id: project.id.to_proto(),
1021                    collaborators: project
1022                        .collaborators
1023                        .iter()
1024                        .map(|collaborator| collaborator.to_proto())
1025                        .collect(),
1026                })
1027                .collect(),
1028            rejoined_projects: rejoined_room
1029                .rejoined_projects
1030                .iter()
1031                .map(|rejoined_project| proto::RejoinedProject {
1032                    id: rejoined_project.id.to_proto(),
1033                    worktrees: rejoined_project
1034                        .worktrees
1035                        .iter()
1036                        .map(|worktree| proto::WorktreeMetadata {
1037                            id: worktree.id,
1038                            root_name: worktree.root_name.clone(),
1039                            visible: worktree.visible,
1040                            abs_path: worktree.abs_path.clone(),
1041                        })
1042                        .collect(),
1043                    collaborators: rejoined_project
1044                        .collaborators
1045                        .iter()
1046                        .map(|collaborator| collaborator.to_proto())
1047                        .collect(),
1048                    language_servers: rejoined_project.language_servers.clone(),
1049                })
1050                .collect(),
1051        })?;
1052        room_updated(&rejoined_room.room, &session.peer);
1053
1054        for project in &rejoined_room.reshared_projects {
1055            for collaborator in &project.collaborators {
1056                session
1057                    .peer
1058                    .send(
1059                        collaborator.connection_id,
1060                        proto::UpdateProjectCollaborator {
1061                            project_id: project.id.to_proto(),
1062                            old_peer_id: Some(project.old_connection_id.into()),
1063                            new_peer_id: Some(session.connection_id.into()),
1064                        },
1065                    )
1066                    .trace_err();
1067            }
1068
1069            broadcast(
1070                Some(session.connection_id),
1071                project
1072                    .collaborators
1073                    .iter()
1074                    .map(|collaborator| collaborator.connection_id),
1075                |connection_id| {
1076                    session.peer.forward_send(
1077                        session.connection_id,
1078                        connection_id,
1079                        proto::UpdateProject {
1080                            project_id: project.id.to_proto(),
1081                            worktrees: project.worktrees.clone(),
1082                        },
1083                    )
1084                },
1085            );
1086        }
1087
1088        for project in &rejoined_room.rejoined_projects {
1089            for collaborator in &project.collaborators {
1090                session
1091                    .peer
1092                    .send(
1093                        collaborator.connection_id,
1094                        proto::UpdateProjectCollaborator {
1095                            project_id: project.id.to_proto(),
1096                            old_peer_id: Some(project.old_connection_id.into()),
1097                            new_peer_id: Some(session.connection_id.into()),
1098                        },
1099                    )
1100                    .trace_err();
1101            }
1102        }
1103
1104        for project in &mut rejoined_room.rejoined_projects {
1105            for worktree in mem::take(&mut project.worktrees) {
1106                #[cfg(any(test, feature = "test-support"))]
1107                const MAX_CHUNK_SIZE: usize = 2;
1108                #[cfg(not(any(test, feature = "test-support")))]
1109                const MAX_CHUNK_SIZE: usize = 256;
1110
1111                // Stream this worktree's entries.
1112                let message = proto::UpdateWorktree {
1113                    project_id: project.id.to_proto(),
1114                    worktree_id: worktree.id,
1115                    abs_path: worktree.abs_path.clone(),
1116                    root_name: worktree.root_name,
1117                    updated_entries: worktree.updated_entries,
1118                    removed_entries: worktree.removed_entries,
1119                    scan_id: worktree.scan_id,
1120                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1121                    updated_repositories: worktree.updated_repositories,
1122                    removed_repositories: worktree.removed_repositories,
1123                };
1124                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1125                    session.peer.send(session.connection_id, update.clone())?;
1126                }
1127
1128                // Stream this worktree's diagnostics.
1129                for summary in worktree.diagnostic_summaries {
1130                    session.peer.send(
1131                        session.connection_id,
1132                        proto::UpdateDiagnosticSummary {
1133                            project_id: project.id.to_proto(),
1134                            worktree_id: worktree.id,
1135                            summary: Some(summary),
1136                        },
1137                    )?;
1138                }
1139
1140                for settings_file in worktree.settings_files {
1141                    session.peer.send(
1142                        session.connection_id,
1143                        proto::UpdateWorktreeSettings {
1144                            project_id: project.id.to_proto(),
1145                            worktree_id: worktree.id,
1146                            path: settings_file.path,
1147                            content: Some(settings_file.content),
1148                        },
1149                    )?;
1150                }
1151            }
1152
1153            for language_server in &project.language_servers {
1154                session.peer.send(
1155                    session.connection_id,
1156                    proto::UpdateLanguageServer {
1157                        project_id: project.id.to_proto(),
1158                        language_server_id: language_server.id,
1159                        variant: Some(
1160                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1161                                proto::LspDiskBasedDiagnosticsUpdated {},
1162                            ),
1163                        ),
1164                    },
1165                )?;
1166            }
1167        }
1168
1169        let rejoined_room = rejoined_room.into_inner();
1170
1171        room = rejoined_room.room;
1172        channel_id = rejoined_room.channel_id;
1173        channel_members = rejoined_room.channel_members;
1174    }
1175
1176    if let Some(channel_id) = channel_id {
1177        channel_updated(
1178            channel_id,
1179            &room,
1180            &channel_members,
1181            &session.peer,
1182            &*session.connection_pool().await,
1183        );
1184    }
1185
1186    update_user_contacts(session.user_id, &session).await?;
1187    Ok(())
1188}
1189
1190async fn leave_room(
1191    _: proto::LeaveRoom,
1192    response: Response<proto::LeaveRoom>,
1193    session: Session,
1194) -> Result<()> {
1195    leave_room_for_session(&session).await?;
1196    response.send(proto::Ack {})?;
1197    Ok(())
1198}
1199
1200async fn call(
1201    request: proto::Call,
1202    response: Response<proto::Call>,
1203    session: Session,
1204) -> Result<()> {
1205    let room_id = RoomId::from_proto(request.room_id);
1206    let calling_user_id = session.user_id;
1207    let calling_connection_id = session.connection_id;
1208    let called_user_id = UserId::from_proto(request.called_user_id);
1209    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1210    if !session
1211        .db()
1212        .await
1213        .has_contact(calling_user_id, called_user_id)
1214        .await?
1215    {
1216        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1217    }
1218
1219    let incoming_call = {
1220        let (room, incoming_call) = &mut *session
1221            .db()
1222            .await
1223            .call(
1224                room_id,
1225                calling_user_id,
1226                calling_connection_id,
1227                called_user_id,
1228                initial_project_id,
1229            )
1230            .await?;
1231        room_updated(&room, &session.peer);
1232        mem::take(incoming_call)
1233    };
1234    update_user_contacts(called_user_id, &session).await?;
1235
1236    let mut calls = session
1237        .connection_pool()
1238        .await
1239        .user_connection_ids(called_user_id)
1240        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1241        .collect::<FuturesUnordered<_>>();
1242
1243    while let Some(call_response) = calls.next().await {
1244        match call_response.as_ref() {
1245            Ok(_) => {
1246                response.send(proto::Ack {})?;
1247                return Ok(());
1248            }
1249            Err(_) => {
1250                call_response.trace_err();
1251            }
1252        }
1253    }
1254
1255    {
1256        let room = session
1257            .db()
1258            .await
1259            .call_failed(room_id, called_user_id)
1260            .await?;
1261        room_updated(&room, &session.peer);
1262    }
1263    update_user_contacts(called_user_id, &session).await?;
1264
1265    Err(anyhow!("failed to ring user"))?
1266}
1267
1268async fn cancel_call(
1269    request: proto::CancelCall,
1270    response: Response<proto::CancelCall>,
1271    session: Session,
1272) -> Result<()> {
1273    let called_user_id = UserId::from_proto(request.called_user_id);
1274    let room_id = RoomId::from_proto(request.room_id);
1275    {
1276        let room = session
1277            .db()
1278            .await
1279            .cancel_call(room_id, session.connection_id, called_user_id)
1280            .await?;
1281        room_updated(&room, &session.peer);
1282    }
1283
1284    for connection_id in session
1285        .connection_pool()
1286        .await
1287        .user_connection_ids(called_user_id)
1288    {
1289        session
1290            .peer
1291            .send(
1292                connection_id,
1293                proto::CallCanceled {
1294                    room_id: room_id.to_proto(),
1295                },
1296            )
1297            .trace_err();
1298    }
1299    response.send(proto::Ack {})?;
1300
1301    update_user_contacts(called_user_id, &session).await?;
1302    Ok(())
1303}
1304
1305async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1306    let room_id = RoomId::from_proto(message.room_id);
1307    {
1308        let room = session
1309            .db()
1310            .await
1311            .decline_call(Some(room_id), session.user_id)
1312            .await?
1313            .ok_or_else(|| anyhow!("failed to decline call"))?;
1314        room_updated(&room, &session.peer);
1315    }
1316
1317    for connection_id in session
1318        .connection_pool()
1319        .await
1320        .user_connection_ids(session.user_id)
1321    {
1322        session
1323            .peer
1324            .send(
1325                connection_id,
1326                proto::CallCanceled {
1327                    room_id: room_id.to_proto(),
1328                },
1329            )
1330            .trace_err();
1331    }
1332    update_user_contacts(session.user_id, &session).await?;
1333    Ok(())
1334}
1335
1336async fn update_participant_location(
1337    request: proto::UpdateParticipantLocation,
1338    response: Response<proto::UpdateParticipantLocation>,
1339    session: Session,
1340) -> Result<()> {
1341    let room_id = RoomId::from_proto(request.room_id);
1342    let location = request
1343        .location
1344        .ok_or_else(|| anyhow!("invalid location"))?;
1345
1346    let db = session.db().await;
1347    let room = db
1348        .update_room_participant_location(room_id, session.connection_id, location)
1349        .await?;
1350
1351    room_updated(&room, &session.peer);
1352    response.send(proto::Ack {})?;
1353    Ok(())
1354}
1355
1356async fn share_project(
1357    request: proto::ShareProject,
1358    response: Response<proto::ShareProject>,
1359    session: Session,
1360) -> Result<()> {
1361    let (project_id, room) = &*session
1362        .db()
1363        .await
1364        .share_project(
1365            RoomId::from_proto(request.room_id),
1366            session.connection_id,
1367            &request.worktrees,
1368        )
1369        .await?;
1370    response.send(proto::ShareProjectResponse {
1371        project_id: project_id.to_proto(),
1372    })?;
1373    room_updated(&room, &session.peer);
1374
1375    Ok(())
1376}
1377
1378async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1379    let project_id = ProjectId::from_proto(message.project_id);
1380
1381    let (room, guest_connection_ids) = &*session
1382        .db()
1383        .await
1384        .unshare_project(project_id, session.connection_id)
1385        .await?;
1386
1387    broadcast(
1388        Some(session.connection_id),
1389        guest_connection_ids.iter().copied(),
1390        |conn_id| session.peer.send(conn_id, message.clone()),
1391    );
1392    room_updated(&room, &session.peer);
1393
1394    Ok(())
1395}
1396
1397async fn join_project(
1398    request: proto::JoinProject,
1399    response: Response<proto::JoinProject>,
1400    session: Session,
1401) -> Result<()> {
1402    let project_id = ProjectId::from_proto(request.project_id);
1403    let guest_user_id = session.user_id;
1404
1405    tracing::info!(%project_id, "join project");
1406
1407    let (project, replica_id) = &mut *session
1408        .db()
1409        .await
1410        .join_project(project_id, session.connection_id)
1411        .await?;
1412
1413    let collaborators = project
1414        .collaborators
1415        .iter()
1416        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1417        .map(|collaborator| collaborator.to_proto())
1418        .collect::<Vec<_>>();
1419
1420    let worktrees = project
1421        .worktrees
1422        .iter()
1423        .map(|(id, worktree)| proto::WorktreeMetadata {
1424            id: *id,
1425            root_name: worktree.root_name.clone(),
1426            visible: worktree.visible,
1427            abs_path: worktree.abs_path.clone(),
1428        })
1429        .collect::<Vec<_>>();
1430
1431    for collaborator in &collaborators {
1432        session
1433            .peer
1434            .send(
1435                collaborator.peer_id.unwrap().into(),
1436                proto::AddProjectCollaborator {
1437                    project_id: project_id.to_proto(),
1438                    collaborator: Some(proto::Collaborator {
1439                        peer_id: Some(session.connection_id.into()),
1440                        replica_id: replica_id.0 as u32,
1441                        user_id: guest_user_id.to_proto(),
1442                    }),
1443                },
1444            )
1445            .trace_err();
1446    }
1447
1448    // First, we send the metadata associated with each worktree.
1449    response.send(proto::JoinProjectResponse {
1450        worktrees: worktrees.clone(),
1451        replica_id: replica_id.0 as u32,
1452        collaborators: collaborators.clone(),
1453        language_servers: project.language_servers.clone(),
1454    })?;
1455
1456    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1457        #[cfg(any(test, feature = "test-support"))]
1458        const MAX_CHUNK_SIZE: usize = 2;
1459        #[cfg(not(any(test, feature = "test-support")))]
1460        const MAX_CHUNK_SIZE: usize = 256;
1461
1462        // Stream this worktree's entries.
1463        let message = proto::UpdateWorktree {
1464            project_id: project_id.to_proto(),
1465            worktree_id,
1466            abs_path: worktree.abs_path.clone(),
1467            root_name: worktree.root_name,
1468            updated_entries: worktree.entries,
1469            removed_entries: Default::default(),
1470            scan_id: worktree.scan_id,
1471            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1472            updated_repositories: worktree.repository_entries.into_values().collect(),
1473            removed_repositories: Default::default(),
1474        };
1475        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1476            session.peer.send(session.connection_id, update.clone())?;
1477        }
1478
1479        // Stream this worktree's diagnostics.
1480        for summary in worktree.diagnostic_summaries {
1481            session.peer.send(
1482                session.connection_id,
1483                proto::UpdateDiagnosticSummary {
1484                    project_id: project_id.to_proto(),
1485                    worktree_id: worktree.id,
1486                    summary: Some(summary),
1487                },
1488            )?;
1489        }
1490
1491        for settings_file in worktree.settings_files {
1492            session.peer.send(
1493                session.connection_id,
1494                proto::UpdateWorktreeSettings {
1495                    project_id: project_id.to_proto(),
1496                    worktree_id: worktree.id,
1497                    path: settings_file.path,
1498                    content: Some(settings_file.content),
1499                },
1500            )?;
1501        }
1502    }
1503
1504    for language_server in &project.language_servers {
1505        session.peer.send(
1506            session.connection_id,
1507            proto::UpdateLanguageServer {
1508                project_id: project_id.to_proto(),
1509                language_server_id: language_server.id,
1510                variant: Some(
1511                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1512                        proto::LspDiskBasedDiagnosticsUpdated {},
1513                    ),
1514                ),
1515            },
1516        )?;
1517    }
1518
1519    Ok(())
1520}
1521
1522async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1523    let sender_id = session.connection_id;
1524    let project_id = ProjectId::from_proto(request.project_id);
1525
1526    let (room, project) = &*session
1527        .db()
1528        .await
1529        .leave_project(project_id, sender_id)
1530        .await?;
1531    tracing::info!(
1532        %project_id,
1533        host_user_id = %project.host_user_id,
1534        host_connection_id = %project.host_connection_id,
1535        "leave project"
1536    );
1537
1538    project_left(&project, &session);
1539    room_updated(&room, &session.peer);
1540
1541    Ok(())
1542}
1543
1544async fn update_project(
1545    request: proto::UpdateProject,
1546    response: Response<proto::UpdateProject>,
1547    session: Session,
1548) -> Result<()> {
1549    let project_id = ProjectId::from_proto(request.project_id);
1550    let (room, guest_connection_ids) = &*session
1551        .db()
1552        .await
1553        .update_project(project_id, session.connection_id, &request.worktrees)
1554        .await?;
1555    broadcast(
1556        Some(session.connection_id),
1557        guest_connection_ids.iter().copied(),
1558        |connection_id| {
1559            session
1560                .peer
1561                .forward_send(session.connection_id, connection_id, request.clone())
1562        },
1563    );
1564    room_updated(&room, &session.peer);
1565    response.send(proto::Ack {})?;
1566
1567    Ok(())
1568}
1569
1570async fn update_worktree(
1571    request: proto::UpdateWorktree,
1572    response: Response<proto::UpdateWorktree>,
1573    session: Session,
1574) -> Result<()> {
1575    let guest_connection_ids = session
1576        .db()
1577        .await
1578        .update_worktree(&request, session.connection_id)
1579        .await?;
1580
1581    broadcast(
1582        Some(session.connection_id),
1583        guest_connection_ids.iter().copied(),
1584        |connection_id| {
1585            session
1586                .peer
1587                .forward_send(session.connection_id, connection_id, request.clone())
1588        },
1589    );
1590    response.send(proto::Ack {})?;
1591    Ok(())
1592}
1593
1594async fn update_diagnostic_summary(
1595    message: proto::UpdateDiagnosticSummary,
1596    session: Session,
1597) -> Result<()> {
1598    let guest_connection_ids = session
1599        .db()
1600        .await
1601        .update_diagnostic_summary(&message, session.connection_id)
1602        .await?;
1603
1604    broadcast(
1605        Some(session.connection_id),
1606        guest_connection_ids.iter().copied(),
1607        |connection_id| {
1608            session
1609                .peer
1610                .forward_send(session.connection_id, connection_id, message.clone())
1611        },
1612    );
1613
1614    Ok(())
1615}
1616
1617async fn update_worktree_settings(
1618    message: proto::UpdateWorktreeSettings,
1619    session: Session,
1620) -> Result<()> {
1621    let guest_connection_ids = session
1622        .db()
1623        .await
1624        .update_worktree_settings(&message, session.connection_id)
1625        .await?;
1626
1627    broadcast(
1628        Some(session.connection_id),
1629        guest_connection_ids.iter().copied(),
1630        |connection_id| {
1631            session
1632                .peer
1633                .forward_send(session.connection_id, connection_id, message.clone())
1634        },
1635    );
1636
1637    Ok(())
1638}
1639
1640async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1641    broadcast_project_message(request.project_id, request, session).await
1642}
1643
1644async fn start_language_server(
1645    request: proto::StartLanguageServer,
1646    session: Session,
1647) -> Result<()> {
1648    let guest_connection_ids = session
1649        .db()
1650        .await
1651        .start_language_server(&request, session.connection_id)
1652        .await?;
1653
1654    broadcast(
1655        Some(session.connection_id),
1656        guest_connection_ids.iter().copied(),
1657        |connection_id| {
1658            session
1659                .peer
1660                .forward_send(session.connection_id, connection_id, request.clone())
1661        },
1662    );
1663    Ok(())
1664}
1665
1666async fn update_language_server(
1667    request: proto::UpdateLanguageServer,
1668    session: Session,
1669) -> Result<()> {
1670    session.executor.record_backtrace();
1671    let project_id = ProjectId::from_proto(request.project_id);
1672    let project_connection_ids = session
1673        .db()
1674        .await
1675        .project_connection_ids(project_id, session.connection_id)
1676        .await?;
1677    broadcast(
1678        Some(session.connection_id),
1679        project_connection_ids.iter().copied(),
1680        |connection_id| {
1681            session
1682                .peer
1683                .forward_send(session.connection_id, connection_id, request.clone())
1684        },
1685    );
1686    Ok(())
1687}
1688
1689async fn forward_project_request<T>(
1690    request: T,
1691    response: Response<T>,
1692    session: Session,
1693) -> Result<()>
1694where
1695    T: EntityMessage + RequestMessage,
1696{
1697    session.executor.record_backtrace();
1698    let project_id = ProjectId::from_proto(request.remote_entity_id());
1699    let host_connection_id = {
1700        let collaborators = session
1701            .db()
1702            .await
1703            .project_collaborators(project_id, session.connection_id)
1704            .await?;
1705        collaborators
1706            .iter()
1707            .find(|collaborator| collaborator.is_host)
1708            .ok_or_else(|| anyhow!("host not found"))?
1709            .connection_id
1710    };
1711
1712    let payload = session
1713        .peer
1714        .forward_request(session.connection_id, host_connection_id, request)
1715        .await?;
1716
1717    response.send(payload)?;
1718    Ok(())
1719}
1720
1721async fn create_buffer_for_peer(
1722    request: proto::CreateBufferForPeer,
1723    session: Session,
1724) -> Result<()> {
1725    session.executor.record_backtrace();
1726    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1727    session
1728        .peer
1729        .forward_send(session.connection_id, peer_id.into(), request)?;
1730    Ok(())
1731}
1732
1733async fn update_buffer(
1734    request: proto::UpdateBuffer,
1735    response: Response<proto::UpdateBuffer>,
1736    session: Session,
1737) -> Result<()> {
1738    session.executor.record_backtrace();
1739    let project_id = ProjectId::from_proto(request.project_id);
1740    let mut guest_connection_ids;
1741    let mut host_connection_id = None;
1742    {
1743        let collaborators = session
1744            .db()
1745            .await
1746            .project_collaborators(project_id, session.connection_id)
1747            .await?;
1748        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1749        for collaborator in collaborators.iter() {
1750            if collaborator.is_host {
1751                host_connection_id = Some(collaborator.connection_id);
1752            } else {
1753                guest_connection_ids.push(collaborator.connection_id);
1754            }
1755        }
1756    }
1757    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1758
1759    session.executor.record_backtrace();
1760    broadcast(
1761        Some(session.connection_id),
1762        guest_connection_ids,
1763        |connection_id| {
1764            session
1765                .peer
1766                .forward_send(session.connection_id, connection_id, request.clone())
1767        },
1768    );
1769    if host_connection_id != session.connection_id {
1770        session
1771            .peer
1772            .forward_request(session.connection_id, host_connection_id, request.clone())
1773            .await?;
1774    }
1775
1776    response.send(proto::Ack {})?;
1777    Ok(())
1778}
1779
1780async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1781    let project_id = ProjectId::from_proto(request.project_id);
1782    let project_connection_ids = session
1783        .db()
1784        .await
1785        .project_connection_ids(project_id, session.connection_id)
1786        .await?;
1787
1788    broadcast(
1789        Some(session.connection_id),
1790        project_connection_ids.iter().copied(),
1791        |connection_id| {
1792            session
1793                .peer
1794                .forward_send(session.connection_id, connection_id, request.clone())
1795        },
1796    );
1797    Ok(())
1798}
1799
1800async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1801    let project_id = ProjectId::from_proto(request.project_id);
1802    let project_connection_ids = session
1803        .db()
1804        .await
1805        .project_connection_ids(project_id, session.connection_id)
1806        .await?;
1807    broadcast(
1808        Some(session.connection_id),
1809        project_connection_ids.iter().copied(),
1810        |connection_id| {
1811            session
1812                .peer
1813                .forward_send(session.connection_id, connection_id, request.clone())
1814        },
1815    );
1816    Ok(())
1817}
1818
1819async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1820    broadcast_project_message(request.project_id, request, session).await
1821}
1822
1823async fn broadcast_project_message<T: EnvelopedMessage>(
1824    project_id: u64,
1825    request: T,
1826    session: Session,
1827) -> Result<()> {
1828    let project_id = ProjectId::from_proto(project_id);
1829    let project_connection_ids = session
1830        .db()
1831        .await
1832        .project_connection_ids(project_id, session.connection_id)
1833        .await?;
1834    broadcast(
1835        Some(session.connection_id),
1836        project_connection_ids.iter().copied(),
1837        |connection_id| {
1838            session
1839                .peer
1840                .forward_send(session.connection_id, connection_id, request.clone())
1841        },
1842    );
1843    Ok(())
1844}
1845
1846async fn follow(
1847    request: proto::Follow,
1848    response: Response<proto::Follow>,
1849    session: Session,
1850) -> Result<()> {
1851    let project_id = ProjectId::from_proto(request.project_id);
1852    let leader_id = request
1853        .leader_id
1854        .ok_or_else(|| anyhow!("invalid leader id"))?
1855        .into();
1856    let follower_id = session.connection_id;
1857
1858    {
1859        let project_connection_ids = session
1860            .db()
1861            .await
1862            .project_connection_ids(project_id, session.connection_id)
1863            .await?;
1864
1865        if !project_connection_ids.contains(&leader_id) {
1866            Err(anyhow!("no such peer"))?;
1867        }
1868    }
1869
1870    let mut response_payload = session
1871        .peer
1872        .forward_request(session.connection_id, leader_id, request)
1873        .await?;
1874    response_payload
1875        .views
1876        .retain(|view| view.leader_id != Some(follower_id.into()));
1877    response.send(response_payload)?;
1878
1879    let room = session
1880        .db()
1881        .await
1882        .follow(project_id, leader_id, follower_id)
1883        .await?;
1884    room_updated(&room, &session.peer);
1885
1886    Ok(())
1887}
1888
1889async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1890    let project_id = ProjectId::from_proto(request.project_id);
1891    let leader_id = request
1892        .leader_id
1893        .ok_or_else(|| anyhow!("invalid leader id"))?
1894        .into();
1895    let follower_id = session.connection_id;
1896
1897    if !session
1898        .db()
1899        .await
1900        .project_connection_ids(project_id, session.connection_id)
1901        .await?
1902        .contains(&leader_id)
1903    {
1904        Err(anyhow!("no such peer"))?;
1905    }
1906
1907    session
1908        .peer
1909        .forward_send(session.connection_id, leader_id, request)?;
1910
1911    let room = session
1912        .db()
1913        .await
1914        .unfollow(project_id, leader_id, follower_id)
1915        .await?;
1916    room_updated(&room, &session.peer);
1917
1918    Ok(())
1919}
1920
1921async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1922    let project_id = ProjectId::from_proto(request.project_id);
1923    let project_connection_ids = session
1924        .db
1925        .lock()
1926        .await
1927        .project_connection_ids(project_id, session.connection_id)
1928        .await?;
1929
1930    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1931        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1932        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1933        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1934    });
1935    for follower_peer_id in request.follower_ids.iter().copied() {
1936        let follower_connection_id = follower_peer_id.into();
1937        if project_connection_ids.contains(&follower_connection_id)
1938            && Some(follower_peer_id) != leader_id
1939        {
1940            session.peer.forward_send(
1941                session.connection_id,
1942                follower_connection_id,
1943                request.clone(),
1944            )?;
1945        }
1946    }
1947    Ok(())
1948}
1949
1950async fn get_users(
1951    request: proto::GetUsers,
1952    response: Response<proto::GetUsers>,
1953    session: Session,
1954) -> Result<()> {
1955    let user_ids = request
1956        .user_ids
1957        .into_iter()
1958        .map(UserId::from_proto)
1959        .collect();
1960    let users = session
1961        .db()
1962        .await
1963        .get_users_by_ids(user_ids)
1964        .await?
1965        .into_iter()
1966        .map(|user| proto::User {
1967            id: user.id.to_proto(),
1968            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1969            github_login: user.github_login,
1970        })
1971        .collect();
1972    response.send(proto::UsersResponse { users })?;
1973    Ok(())
1974}
1975
1976async fn fuzzy_search_users(
1977    request: proto::FuzzySearchUsers,
1978    response: Response<proto::FuzzySearchUsers>,
1979    session: Session,
1980) -> Result<()> {
1981    let query = request.query;
1982    let users = match query.len() {
1983        0 => vec![],
1984        1 | 2 => session
1985            .db()
1986            .await
1987            .get_user_by_github_login(&query)
1988            .await?
1989            .into_iter()
1990            .collect(),
1991        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1992    };
1993    let users = users
1994        .into_iter()
1995        .filter(|user| user.id != session.user_id)
1996        .map(|user| proto::User {
1997            id: user.id.to_proto(),
1998            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1999            github_login: user.github_login,
2000        })
2001        .collect();
2002    response.send(proto::UsersResponse { users })?;
2003    Ok(())
2004}
2005
2006async fn request_contact(
2007    request: proto::RequestContact,
2008    response: Response<proto::RequestContact>,
2009    session: Session,
2010) -> Result<()> {
2011    let requester_id = session.user_id;
2012    let responder_id = UserId::from_proto(request.responder_id);
2013    if requester_id == responder_id {
2014        return Err(anyhow!("cannot add yourself as a contact"))?;
2015    }
2016
2017    session
2018        .db()
2019        .await
2020        .send_contact_request(requester_id, responder_id)
2021        .await?;
2022
2023    // Update outgoing contact requests of requester
2024    let mut update = proto::UpdateContacts::default();
2025    update.outgoing_requests.push(responder_id.to_proto());
2026    for connection_id in session
2027        .connection_pool()
2028        .await
2029        .user_connection_ids(requester_id)
2030    {
2031        session.peer.send(connection_id, update.clone())?;
2032    }
2033
2034    // Update incoming contact requests of responder
2035    let mut update = proto::UpdateContacts::default();
2036    update
2037        .incoming_requests
2038        .push(proto::IncomingContactRequest {
2039            requester_id: requester_id.to_proto(),
2040            should_notify: true,
2041        });
2042    for connection_id in session
2043        .connection_pool()
2044        .await
2045        .user_connection_ids(responder_id)
2046    {
2047        session.peer.send(connection_id, update.clone())?;
2048    }
2049
2050    response.send(proto::Ack {})?;
2051    Ok(())
2052}
2053
2054async fn respond_to_contact_request(
2055    request: proto::RespondToContactRequest,
2056    response: Response<proto::RespondToContactRequest>,
2057    session: Session,
2058) -> Result<()> {
2059    let responder_id = session.user_id;
2060    let requester_id = UserId::from_proto(request.requester_id);
2061    let db = session.db().await;
2062    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2063        db.dismiss_contact_notification(responder_id, requester_id)
2064            .await?;
2065    } else {
2066        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2067
2068        db.respond_to_contact_request(responder_id, requester_id, accept)
2069            .await?;
2070        let requester_busy = db.is_user_busy(requester_id).await?;
2071        let responder_busy = db.is_user_busy(responder_id).await?;
2072
2073        let pool = session.connection_pool().await;
2074        // Update responder with new contact
2075        let mut update = proto::UpdateContacts::default();
2076        if accept {
2077            update
2078                .contacts
2079                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2080        }
2081        update
2082            .remove_incoming_requests
2083            .push(requester_id.to_proto());
2084        for connection_id in pool.user_connection_ids(responder_id) {
2085            session.peer.send(connection_id, update.clone())?;
2086        }
2087
2088        // Update requester with new contact
2089        let mut update = proto::UpdateContacts::default();
2090        if accept {
2091            update
2092                .contacts
2093                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2094        }
2095        update
2096            .remove_outgoing_requests
2097            .push(responder_id.to_proto());
2098        for connection_id in pool.user_connection_ids(requester_id) {
2099            session.peer.send(connection_id, update.clone())?;
2100        }
2101    }
2102
2103    response.send(proto::Ack {})?;
2104    Ok(())
2105}
2106
2107async fn remove_contact(
2108    request: proto::RemoveContact,
2109    response: Response<proto::RemoveContact>,
2110    session: Session,
2111) -> Result<()> {
2112    let requester_id = session.user_id;
2113    let responder_id = UserId::from_proto(request.user_id);
2114    let db = session.db().await;
2115    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2116
2117    let pool = session.connection_pool().await;
2118    // Update outgoing contact requests of requester
2119    let mut update = proto::UpdateContacts::default();
2120    if contact_accepted {
2121        update.remove_contacts.push(responder_id.to_proto());
2122    } else {
2123        update
2124            .remove_outgoing_requests
2125            .push(responder_id.to_proto());
2126    }
2127    for connection_id in pool.user_connection_ids(requester_id) {
2128        session.peer.send(connection_id, update.clone())?;
2129    }
2130
2131    // Update incoming contact requests of responder
2132    let mut update = proto::UpdateContacts::default();
2133    if contact_accepted {
2134        update.remove_contacts.push(requester_id.to_proto());
2135    } else {
2136        update
2137            .remove_incoming_requests
2138            .push(requester_id.to_proto());
2139    }
2140    for connection_id in pool.user_connection_ids(responder_id) {
2141        session.peer.send(connection_id, update.clone())?;
2142    }
2143
2144    response.send(proto::Ack {})?;
2145    Ok(())
2146}
2147
2148async fn create_channel(
2149    request: proto::CreateChannel,
2150    response: Response<proto::CreateChannel>,
2151    session: Session,
2152) -> Result<()> {
2153    let db = session.db().await;
2154    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2155
2156    if let Some(live_kit) = session.live_kit_client.as_ref() {
2157        live_kit.create_room(live_kit_room.clone()).await?;
2158    }
2159
2160    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2161    let id = db
2162        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2163        .await?;
2164
2165    let channel = proto::Channel {
2166        id: id.to_proto(),
2167        name: request.name,
2168        parent_id: request.parent_id,
2169    };
2170
2171    response.send(proto::ChannelResponse {
2172        channel: Some(channel.clone()),
2173    })?;
2174
2175    let mut update = proto::UpdateChannels::default();
2176    update.channels.push(channel);
2177
2178    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2179        db.get_channel_members(parent_id).await?
2180    } else {
2181        vec![session.user_id]
2182    };
2183
2184    let connection_pool = session.connection_pool().await;
2185    for user_id in user_ids_to_notify {
2186        for connection_id in connection_pool.user_connection_ids(user_id) {
2187            let mut update = update.clone();
2188            if user_id == session.user_id {
2189                update.channel_permissions.push(proto::ChannelPermission {
2190                    channel_id: id.to_proto(),
2191                    is_admin: true,
2192                });
2193            }
2194            session.peer.send(connection_id, update)?;
2195        }
2196    }
2197
2198    Ok(())
2199}
2200
2201async fn remove_channel(
2202    request: proto::RemoveChannel,
2203    response: Response<proto::RemoveChannel>,
2204    session: Session,
2205) -> Result<()> {
2206    let db = session.db().await;
2207
2208    let channel_id = request.channel_id;
2209    let (removed_channels, member_ids) = db
2210        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2211        .await?;
2212    response.send(proto::Ack {})?;
2213
2214    // Notify members of removed channels
2215    let mut update = proto::UpdateChannels::default();
2216    update
2217        .remove_channels
2218        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2219
2220    let connection_pool = session.connection_pool().await;
2221    for member_id in member_ids {
2222        for connection_id in connection_pool.user_connection_ids(member_id) {
2223            session.peer.send(connection_id, update.clone())?;
2224        }
2225    }
2226
2227    Ok(())
2228}
2229
2230async fn invite_channel_member(
2231    request: proto::InviteChannelMember,
2232    response: Response<proto::InviteChannelMember>,
2233    session: Session,
2234) -> Result<()> {
2235    let db = session.db().await;
2236    let channel_id = ChannelId::from_proto(request.channel_id);
2237    let invitee_id = UserId::from_proto(request.user_id);
2238    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2239        .await?;
2240
2241    let (channel, _) = db
2242        .get_channel(channel_id, session.user_id)
2243        .await?
2244        .ok_or_else(|| anyhow!("channel not found"))?;
2245
2246    let mut update = proto::UpdateChannels::default();
2247    update.channel_invitations.push(proto::Channel {
2248        id: channel.id.to_proto(),
2249        name: channel.name,
2250        parent_id: None,
2251    });
2252    for connection_id in session
2253        .connection_pool()
2254        .await
2255        .user_connection_ids(invitee_id)
2256    {
2257        session.peer.send(connection_id, update.clone())?;
2258    }
2259
2260    response.send(proto::Ack {})?;
2261    Ok(())
2262}
2263
2264async fn remove_channel_member(
2265    request: proto::RemoveChannelMember,
2266    response: Response<proto::RemoveChannelMember>,
2267    session: Session,
2268) -> Result<()> {
2269    let db = session.db().await;
2270    let channel_id = ChannelId::from_proto(request.channel_id);
2271    let member_id = UserId::from_proto(request.user_id);
2272
2273    db.remove_channel_member(channel_id, member_id, session.user_id)
2274        .await?;
2275
2276    let mut update = proto::UpdateChannels::default();
2277    update.remove_channels.push(channel_id.to_proto());
2278
2279    for connection_id in session
2280        .connection_pool()
2281        .await
2282        .user_connection_ids(member_id)
2283    {
2284        session.peer.send(connection_id, update.clone())?;
2285    }
2286
2287    response.send(proto::Ack {})?;
2288    Ok(())
2289}
2290
2291async fn set_channel_member_admin(
2292    request: proto::SetChannelMemberAdmin,
2293    response: Response<proto::SetChannelMemberAdmin>,
2294    session: Session,
2295) -> Result<()> {
2296    let db = session.db().await;
2297    let channel_id = ChannelId::from_proto(request.channel_id);
2298    let member_id = UserId::from_proto(request.user_id);
2299    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2300        .await?;
2301
2302    let (channel, has_accepted) = db
2303        .get_channel(channel_id, member_id)
2304        .await?
2305        .ok_or_else(|| anyhow!("channel not found"))?;
2306
2307    let mut update = proto::UpdateChannels::default();
2308    if has_accepted {
2309        update.channel_permissions.push(proto::ChannelPermission {
2310            channel_id: channel.id.to_proto(),
2311            is_admin: request.admin,
2312        });
2313    }
2314
2315    for connection_id in session
2316        .connection_pool()
2317        .await
2318        .user_connection_ids(member_id)
2319    {
2320        session.peer.send(connection_id, update.clone())?;
2321    }
2322
2323    response.send(proto::Ack {})?;
2324    Ok(())
2325}
2326
2327async fn rename_channel(
2328    request: proto::RenameChannel,
2329    response: Response<proto::RenameChannel>,
2330    session: Session,
2331) -> Result<()> {
2332    let db = session.db().await;
2333    let channel_id = ChannelId::from_proto(request.channel_id);
2334    let new_name = db
2335        .rename_channel(channel_id, session.user_id, &request.name)
2336        .await?;
2337
2338    let channel = proto::Channel {
2339        id: request.channel_id,
2340        name: new_name,
2341        parent_id: None,
2342    };
2343    response.send(proto::ChannelResponse {
2344        channel: Some(channel.clone()),
2345    })?;
2346    let mut update = proto::UpdateChannels::default();
2347    update.channels.push(channel);
2348
2349    let member_ids = db.get_channel_members(channel_id).await?;
2350
2351    let connection_pool = session.connection_pool().await;
2352    for member_id in member_ids {
2353        for connection_id in connection_pool.user_connection_ids(member_id) {
2354            session.peer.send(connection_id, update.clone())?;
2355        }
2356    }
2357
2358    Ok(())
2359}
2360
2361async fn get_channel_members(
2362    request: proto::GetChannelMembers,
2363    response: Response<proto::GetChannelMembers>,
2364    session: Session,
2365) -> Result<()> {
2366    let db = session.db().await;
2367    let channel_id = ChannelId::from_proto(request.channel_id);
2368    let members = db
2369        .get_channel_member_details(channel_id, session.user_id)
2370        .await?;
2371    response.send(proto::GetChannelMembersResponse { members })?;
2372    Ok(())
2373}
2374
2375async fn respond_to_channel_invite(
2376    request: proto::RespondToChannelInvite,
2377    response: Response<proto::RespondToChannelInvite>,
2378    session: Session,
2379) -> Result<()> {
2380    let db = session.db().await;
2381    let channel_id = ChannelId::from_proto(request.channel_id);
2382    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2383        .await?;
2384
2385    let mut update = proto::UpdateChannels::default();
2386    update
2387        .remove_channel_invitations
2388        .push(channel_id.to_proto());
2389    if request.accept {
2390        let result = db.get_channels_for_user(session.user_id).await?;
2391        update
2392            .channels
2393            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2394                id: channel.id.to_proto(),
2395                name: channel.name,
2396                parent_id: channel.parent_id.map(ChannelId::to_proto),
2397            }));
2398        update
2399            .channel_participants
2400            .extend(
2401                result
2402                    .channel_participants
2403                    .into_iter()
2404                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2405                        channel_id: channel_id.to_proto(),
2406                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2407                    }),
2408            );
2409        update
2410            .channel_permissions
2411            .extend(
2412                result
2413                    .channels_with_admin_privileges
2414                    .into_iter()
2415                    .map(|channel_id| proto::ChannelPermission {
2416                        channel_id: channel_id.to_proto(),
2417                        is_admin: true,
2418                    }),
2419            );
2420    }
2421    session.peer.send(session.connection_id, update)?;
2422    response.send(proto::Ack {})?;
2423
2424    Ok(())
2425}
2426
/// Moves the current session into the room associated with `request.channel_id`.
///
/// The steps are:
/// 1. Leave any room the session is currently in (`leave_room_for_session`).
/// 2. Join the channel's room.
/// 3. Reply with the room state, plus LiveKit connection info when a LiveKit
///    client is configured and a token can be issued.
/// 4. Call `room_updated`, `channel_updated`, and `update_user_contacts` to
///    propagate the change.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // This block scopes `db`, which is dropped at the end of the block, before
    // the connection pool is locked for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // `trace_err()?` turns a token error into `None` (presumably logging it),
        // so a LiveKit failure omits the connection info instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Unwrap the value returned by `join_room` so it outlives this block.
        // NOTE(review): presumably this releases a guard held by that value; confirm.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2481
2482async fn get_channel_buffer(
2483    request: proto::GetChannelBuffer,
2484    response: Response<proto::GetChannelBuffer>,
2485    session: Session,
2486) -> Result<()> {
2487    let db = session.db().await;
2488    let channel_id = ChannelId::from_proto(request.channel_id);
2489
2490    let buffer_id = db.get_or_create_buffer_for_channel(channel_id).await?;
2491
2492    let buffer = db.get_buffer(buffer_id).await?;
2493
2494    response.send(GetChannelBufferResponse {
2495        base_text: buffer.base_text,
2496        operations: buffer.operations,
2497    })?;
2498
2499    Ok(())
2500}
2501
2502async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2503    let project_id = ProjectId::from_proto(request.project_id);
2504    let project_connection_ids = session
2505        .db()
2506        .await
2507        .project_connection_ids(project_id, session.connection_id)
2508        .await?;
2509    broadcast(
2510        Some(session.connection_id),
2511        project_connection_ids.iter().copied(),
2512        |connection_id| {
2513            session
2514                .peer
2515                .forward_send(session.connection_id, connection_id, request.clone())
2516        },
2517    );
2518    Ok(())
2519}
2520
2521async fn get_private_user_info(
2522    _request: proto::GetPrivateUserInfo,
2523    response: Response<proto::GetPrivateUserInfo>,
2524    session: Session,
2525) -> Result<()> {
2526    let metrics_id = session
2527        .db()
2528        .await
2529        .get_user_metrics_id(session.user_id)
2530        .await?;
2531    let user = session
2532        .db()
2533        .await
2534        .get_user_by_id(session.user_id)
2535        .await?
2536        .ok_or_else(|| anyhow!("user not found"))?;
2537    response.send(proto::GetPrivateUserInfoResponse {
2538        metrics_id,
2539        staff: user.admin,
2540    })?;
2541    Ok(())
2542}
2543
2544fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2545    match message {
2546        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2547        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2548        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2549        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2550        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2551            code: frame.code.into(),
2552            reason: frame.reason,
2553        })),
2554    }
2555}
2556
2557fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2558    match message {
2559        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2560        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2561        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2562        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2563        AxumMessage::Close(frame) => {
2564            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2565                code: frame.code.into(),
2566                reason: frame.reason,
2567            }))
2568        }
2569    }
2570}
2571
2572fn build_initial_channels_update(
2573    channels: ChannelsForUser,
2574    channel_invites: Vec<db::Channel>,
2575) -> proto::UpdateChannels {
2576    let mut update = proto::UpdateChannels::default();
2577
2578    for channel in channels.channels {
2579        update.channels.push(proto::Channel {
2580            id: channel.id.to_proto(),
2581            name: channel.name,
2582            parent_id: channel.parent_id.map(|id| id.to_proto()),
2583        });
2584    }
2585
2586    for (channel_id, participants) in channels.channel_participants {
2587        update
2588            .channel_participants
2589            .push(proto::ChannelParticipants {
2590                channel_id: channel_id.to_proto(),
2591                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2592            });
2593    }
2594
2595    update
2596        .channel_permissions
2597        .extend(
2598            channels
2599                .channels_with_admin_privileges
2600                .into_iter()
2601                .map(|id| proto::ChannelPermission {
2602                    channel_id: id.to_proto(),
2603                    is_admin: true,
2604                }),
2605        );
2606
2607    for channel in channel_invites {
2608        update.channel_invitations.push(proto::Channel {
2609            id: channel.id.to_proto(),
2610            name: channel.name,
2611            parent_id: None,
2612        });
2613    }
2614
2615    update
2616}
2617
2618fn build_initial_contacts_update(
2619    contacts: Vec<db::Contact>,
2620    pool: &ConnectionPool,
2621) -> proto::UpdateContacts {
2622    let mut update = proto::UpdateContacts::default();
2623
2624    for contact in contacts {
2625        match contact {
2626            db::Contact::Accepted {
2627                user_id,
2628                should_notify,
2629                busy,
2630            } => {
2631                update
2632                    .contacts
2633                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2634            }
2635            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2636            db::Contact::Incoming {
2637                user_id,
2638                should_notify,
2639            } => update
2640                .incoming_requests
2641                .push(proto::IncomingContactRequest {
2642                    requester_id: user_id.to_proto(),
2643                    should_notify,
2644                }),
2645        }
2646    }
2647
2648    update
2649}
2650
2651fn contact_for_user(
2652    user_id: UserId,
2653    should_notify: bool,
2654    busy: bool,
2655    pool: &ConnectionPool,
2656) -> proto::Contact {
2657    proto::Contact {
2658        user_id: user_id.to_proto(),
2659        online: pool.is_user_online(user_id),
2660        busy,
2661        should_notify,
2662    }
2663}
2664
2665fn room_updated(room: &proto::Room, peer: &Peer) {
2666    broadcast(
2667        None,
2668        room.participants
2669            .iter()
2670            .filter_map(|participant| Some(participant.peer_id?.into())),
2671        |peer_id| {
2672            peer.send(
2673                peer_id.into(),
2674                proto::RoomUpdated {
2675                    room: Some(room.clone()),
2676                },
2677            )
2678        },
2679    );
2680}
2681
2682fn channel_updated(
2683    channel_id: ChannelId,
2684    room: &proto::Room,
2685    channel_members: &[UserId],
2686    peer: &Peer,
2687    pool: &ConnectionPool,
2688) {
2689    let participants = room
2690        .participants
2691        .iter()
2692        .map(|p| p.user_id)
2693        .collect::<Vec<_>>();
2694
2695    broadcast(
2696        None,
2697        channel_members
2698            .iter()
2699            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2700        |peer_id| {
2701            peer.send(
2702                peer_id.into(),
2703                proto::UpdateChannels {
2704                    channel_participants: vec![proto::ChannelParticipants {
2705                        channel_id: channel_id.to_proto(),
2706                        participant_user_ids: participants.clone(),
2707                    }],
2708                    ..Default::default()
2709                },
2710            )
2711        },
2712    );
2713}
2714
2715async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2716    let db = session.db().await;
2717
2718    let contacts = db.get_contacts(user_id).await?;
2719    let busy = db.is_user_busy(user_id).await?;
2720
2721    let pool = session.connection_pool().await;
2722    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2723    for contact in contacts {
2724        if let db::Contact::Accepted {
2725            user_id: contact_user_id,
2726            ..
2727        } = contact
2728        {
2729            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2730                session
2731                    .peer
2732                    .send(
2733                        contact_conn_id,
2734                        proto::UpdateContacts {
2735                            contacts: vec![updated_contact.clone()],
2736                            remove_contacts: Default::default(),
2737                            incoming_requests: Default::default(),
2738                            remove_incoming_requests: Default::default(),
2739                            outgoing_requests: Default::default(),
2740                            remove_outgoing_requests: Default::default(),
2741                        },
2742                    )
2743                    .trace_err();
2744            }
2745        }
2746    }
2747    Ok(())
2748}
2749
2750async fn leave_room_for_session(session: &Session) -> Result<()> {
2751    let mut contacts_to_update = HashSet::default();
2752
2753    let room_id;
2754    let canceled_calls_to_user_ids;
2755    let live_kit_room;
2756    let delete_live_kit_room;
2757    let room;
2758    let channel_members;
2759    let channel_id;
2760
2761    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2762        contacts_to_update.insert(session.user_id);
2763
2764        for project in left_room.left_projects.values() {
2765            project_left(project, session);
2766        }
2767
2768        room_id = RoomId::from_proto(left_room.room.id);
2769        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2770        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2771        delete_live_kit_room = left_room.deleted;
2772        room = mem::take(&mut left_room.room);
2773        channel_members = mem::take(&mut left_room.channel_members);
2774        channel_id = left_room.channel_id;
2775
2776        room_updated(&room, &session.peer);
2777    } else {
2778        return Ok(());
2779    }
2780
2781    if let Some(channel_id) = channel_id {
2782        channel_updated(
2783            channel_id,
2784            &room,
2785            &channel_members,
2786            &session.peer,
2787            &*session.connection_pool().await,
2788        );
2789    }
2790
2791    {
2792        let pool = session.connection_pool().await;
2793        for canceled_user_id in canceled_calls_to_user_ids {
2794            for connection_id in pool.user_connection_ids(canceled_user_id) {
2795                session
2796                    .peer
2797                    .send(
2798                        connection_id,
2799                        proto::CallCanceled {
2800                            room_id: room_id.to_proto(),
2801                        },
2802                    )
2803                    .trace_err();
2804            }
2805            contacts_to_update.insert(canceled_user_id);
2806        }
2807    }
2808
2809    for contact_user_id in contacts_to_update {
2810        update_user_contacts(contact_user_id, &session).await?;
2811    }
2812
2813    if let Some(live_kit) = session.live_kit_client.as_ref() {
2814        live_kit
2815            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2816            .await
2817            .trace_err();
2818
2819        if delete_live_kit_room {
2820            live_kit.delete_room(live_kit_room).await.trace_err();
2821        }
2822    }
2823
2824    Ok(())
2825}
2826
2827fn project_left(project: &db::LeftProject, session: &Session) {
2828    for connection_id in &project.connection_ids {
2829        if project.host_user_id == session.user_id {
2830            session
2831                .peer
2832                .send(
2833                    *connection_id,
2834                    proto::UnshareProject {
2835                        project_id: project.id.to_proto(),
2836                    },
2837                )
2838                .trace_err();
2839        } else {
2840            session
2841                .peer
2842                .send(
2843                    *connection_id,
2844                    proto::RemoveProjectCollaborator {
2845                        project_id: project.id.to_proto(),
2846                        peer_id: Some(session.connection_id.into()),
2847                    },
2848                )
2849                .trace_err();
2850        }
2851    }
2852}
2853
2854pub trait ResultExt {
2855    type Ok;
2856
2857    fn trace_err(self) -> Option<Self::Ok>;
2858}
2859
2860impl<T, E> ResultExt for Result<T, E>
2861where
2862    E: std::fmt::Debug,
2863{
2864    type Ok = T;
2865
2866    fn trace_err(self) -> Option<T> {
2867        match self {
2868            Ok(value) => Some(value),
2869            Err(error) => {
2870                tracing::error!("{:?}", error);
2871                None
2872            }
2873        }
2874    }
2875}