rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
   7        ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73
  74lazy_static! {
  75    static ref METRIC_CONNECTIONS: IntGauge =
  76        register_int_gauge!("connections", "number of connections").unwrap();
  77    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  78        "shared_projects",
  79        "number of open projects with one or more guests"
  80    )
  81    .unwrap();
  82}
  83
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
 101#[derive(Clone)]
 102struct Session {
 103    user_id: UserId,
 104    connection_id: ConnectionId,
 105    db: Arc<tokio::sync::Mutex<DbHandle>>,
 106    peer: Arc<Peer>,
 107    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 108    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 109    executor: Executor,
 110}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
 152pub struct Server {
 153    id: parking_lot::Mutex<ServerId>,
 154    peer: Arc<Peer>,
 155    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 156    app_state: Arc<AppState>,
 157    executor: Executor,
 158    handlers: HashMap<TypeId, MessageHandler>,
 159    teardown: watch::Sender<()>,
 160}
 161
 162pub(crate) struct ConnectionPoolGuard<'a> {
 163    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 164    _not_send: PhantomData<Rc<()>>,
 165}
 166
 167#[derive(Serialize)]
 168pub struct ServerSnapshot<'a> {
 169    peer: &'a Peer,
 170    #[serde(serialize_with = "serialize_deref")]
 171    connection_pool: ConnectionPoolGuard<'a>,
 172}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
 184    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 185        let mut server = Self {
 186            id: parking_lot::Mutex::new(id),
 187            peer: Peer::new(id.0 as u32),
 188            app_state,
 189            executor,
 190            connection_pool: Default::default(),
 191            handlers: Default::default(),
 192            teardown: watch::channel(()).0,
 193        };
 194
 195        server
 196            .add_request_handler(ping)
 197            .add_request_handler(create_room)
 198            .add_request_handler(join_room)
 199            .add_request_handler(rejoin_room)
 200            .add_request_handler(leave_room)
 201            .add_request_handler(call)
 202            .add_request_handler(cancel_call)
 203            .add_message_handler(decline_call)
 204            .add_request_handler(update_participant_location)
 205            .add_request_handler(share_project)
 206            .add_message_handler(unshare_project)
 207            .add_request_handler(join_project)
 208            .add_message_handler(leave_project)
 209            .add_request_handler(update_project)
 210            .add_request_handler(update_worktree)
 211            .add_message_handler(start_language_server)
 212            .add_message_handler(update_language_server)
 213            .add_message_handler(update_diagnostic_summary)
 214            .add_message_handler(update_worktree_settings)
 215            .add_message_handler(refresh_inlay_hints)
 216            .add_request_handler(forward_project_request::<proto::GetHover>)
 217            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetReferences>)
 220            .add_request_handler(forward_project_request::<proto::SearchProject>)
 221            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 222            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 226            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 227            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 228            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 229            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 230            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 231            .add_request_handler(forward_project_request::<proto::PerformRename>)
 232            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 233            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 234            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 235            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 236            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 241            .add_request_handler(forward_project_request::<proto::InlayHints>)
 242            .add_message_handler(create_buffer_for_peer)
 243            .add_request_handler(update_buffer)
 244            .add_message_handler(update_buffer_file)
 245            .add_message_handler(buffer_reloaded)
 246            .add_message_handler(buffer_saved)
 247            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 248            .add_request_handler(get_users)
 249            .add_request_handler(fuzzy_search_users)
 250            .add_request_handler(request_contact)
 251            .add_request_handler(remove_contact)
 252            .add_request_handler(respond_to_contact_request)
 253            .add_request_handler(create_channel)
 254            .add_request_handler(delete_channel)
 255            .add_request_handler(invite_channel_member)
 256            .add_request_handler(remove_channel_member)
 257            .add_request_handler(set_channel_member_admin)
 258            .add_request_handler(rename_channel)
 259            .add_request_handler(join_channel_buffer)
 260            .add_request_handler(leave_channel_buffer)
 261            .add_message_handler(update_channel_buffer)
 262            .add_request_handler(rejoin_channel_buffers)
 263            .add_request_handler(get_channel_members)
 264            .add_request_handler(respond_to_channel_invite)
 265            .add_request_handler(join_channel)
 266            .add_request_handler(join_channel_chat)
 267            .add_message_handler(leave_channel_chat)
 268            .add_request_handler(send_channel_message)
 269            .add_request_handler(remove_channel_message)
 270            .add_request_handler(get_channel_messages)
 271            .add_request_handler(link_channel)
 272            .add_request_handler(unlink_channel)
 273            .add_request_handler(move_channel)
 274            .add_request_handler(follow)
 275            .add_message_handler(unfollow)
 276            .add_message_handler(update_followers)
 277            .add_message_handler(update_diff_base)
 278            .add_request_handler(get_private_user_info)
 279            .add_message_handler(acknowledge_channel_message)
 280            .add_message_handler(acknowledge_buffer_version);
 281
 282        Arc::new(server)
 283    }
 284
 285    pub async fn start(&self) -> Result<()> {
 286        let server_id = *self.id.lock();
 287        let app_state = self.app_state.clone();
 288        let peer = self.peer.clone();
 289        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 290        let pool = self.connection_pool.clone();
 291        let live_kit_client = self.app_state.live_kit_client.clone();
 292
 293        let span = info_span!("start server");
 294        self.executor.spawn_detached(
 295            async move {
 296                tracing::info!("waiting for cleanup timeout");
 297                timeout.await;
 298                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 299                if let Some((room_ids, channel_ids)) = app_state
 300                    .db
 301                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 302                    .await
 303                    .trace_err()
 304                {
 305                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 306                    tracing::info!(
 307                        stale_channel_buffer_count = channel_ids.len(),
 308                        "retrieved stale channel buffers"
 309                    );
 310
 311                    for channel_id in channel_ids {
 312                        if let Some(refreshed_channel_buffer) = app_state
 313                            .db
 314                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 315                            .await
 316                            .trace_err()
 317                        {
 318                            for connection_id in refreshed_channel_buffer.connection_ids {
 319                                peer.send(
 320                                    connection_id,
 321                                    proto::UpdateChannelBufferCollaborators {
 322                                        channel_id: channel_id.to_proto(),
 323                                        collaborators: refreshed_channel_buffer
 324                                            .collaborators
 325                                            .clone(),
 326                                    },
 327                                )
 328                                .trace_err();
 329                            }
 330                        }
 331                    }
 332
 333                    for room_id in room_ids {
 334                        let mut contacts_to_update = HashSet::default();
 335                        let mut canceled_calls_to_user_ids = Vec::new();
 336                        let mut live_kit_room = String::new();
 337                        let mut delete_live_kit_room = false;
 338
 339                        if let Some(mut refreshed_room) = app_state
 340                            .db
 341                            .clear_stale_room_participants(room_id, server_id)
 342                            .await
 343                            .trace_err()
 344                        {
 345                            tracing::info!(
 346                                room_id = room_id.0,
 347                                new_participant_count = refreshed_room.room.participants.len(),
 348                                "refreshed room"
 349                            );
 350                            room_updated(&refreshed_room.room, &peer);
 351                            if let Some(channel_id) = refreshed_room.channel_id {
 352                                channel_updated(
 353                                    channel_id,
 354                                    &refreshed_room.room,
 355                                    &refreshed_room.channel_members,
 356                                    &peer,
 357                                    &*pool.lock(),
 358                                );
 359                            }
 360                            contacts_to_update
 361                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 362                            contacts_to_update
 363                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 364                            canceled_calls_to_user_ids =
 365                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 366                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 367                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 368                        }
 369
 370                        {
 371                            let pool = pool.lock();
 372                            for canceled_user_id in canceled_calls_to_user_ids {
 373                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 374                                    peer.send(
 375                                        connection_id,
 376                                        proto::CallCanceled {
 377                                            room_id: room_id.to_proto(),
 378                                        },
 379                                    )
 380                                    .trace_err();
 381                                }
 382                            }
 383                        }
 384
 385                        for user_id in contacts_to_update {
 386                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 387                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 388                            if let Some((busy, contacts)) = busy.zip(contacts) {
 389                                let pool = pool.lock();
 390                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 391                                for contact in contacts {
 392                                    if let db::Contact::Accepted {
 393                                        user_id: contact_user_id,
 394                                        ..
 395                                    } = contact
 396                                    {
 397                                        for contact_conn_id in
 398                                            pool.user_connection_ids(contact_user_id)
 399                                        {
 400                                            peer.send(
 401                                                contact_conn_id,
 402                                                proto::UpdateContacts {
 403                                                    contacts: vec![updated_contact.clone()],
 404                                                    remove_contacts: Default::default(),
 405                                                    incoming_requests: Default::default(),
 406                                                    remove_incoming_requests: Default::default(),
 407                                                    outgoing_requests: Default::default(),
 408                                                    remove_outgoing_requests: Default::default(),
 409                                                },
 410                                            )
 411                                            .trace_err();
 412                                        }
 413                                    }
 414                                }
 415                            }
 416                        }
 417
 418                        if let Some(live_kit) = live_kit_client.as_ref() {
 419                            if delete_live_kit_room {
 420                                live_kit.delete_room(live_kit_room).await.trace_err();
 421                            }
 422                        }
 423                    }
 424                }
 425
 426                app_state
 427                    .db
 428                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 429                    .await
 430                    .trace_err();
 431            }
 432            .instrument(span),
 433        );
 434        Ok(())
 435    }
 436
 437    pub fn teardown(&self) {
 438        self.peer.teardown();
 439        self.connection_pool.lock().reset();
 440        let _ = self.teardown.send(());
 441    }
 442
 443    #[cfg(test)]
 444    pub fn reset(&self, id: ServerId) {
 445        self.teardown();
 446        *self.id.lock() = id;
 447        self.peer.reset(id.0 as u32);
 448    }
 449
 450    #[cfg(test)]
 451    pub fn id(&self) -> ServerId {
 452        *self.id.lock()
 453    }
 454
 455    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 458        Fut: 'static + Send + Future<Output = Result<()>>,
 459        M: EnvelopedMessage,
 460    {
 461        let prev_handler = self.handlers.insert(
 462            TypeId::of::<M>(),
 463            Box::new(move |envelope, session| {
 464                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 465                let span = info_span!(
 466                    "handle message",
 467                    payload_type = envelope.payload_type_name()
 468                );
 469                span.in_scope(|| {
 470                    tracing::info!(
 471                        payload_type = envelope.payload_type_name(),
 472                        "message received"
 473                    );
 474                });
 475                let start_time = Instant::now();
 476                let future = (handler)(*envelope, session);
 477                async move {
 478                    let result = future.await;
 479                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 480                    match result {
 481                        Err(error) => {
 482                            tracing::error!(%error, ?duration_ms, "error handling message")
 483                        }
 484                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 485                    }
 486                }
 487                .instrument(span)
 488                .boxed()
 489            }),
 490        );
 491        if prev_handler.is_some() {
 492            panic!("registered a handler for the same message twice");
 493        }
 494        self
 495    }
 496
 497    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 498    where
 499        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 500        Fut: 'static + Send + Future<Output = Result<()>>,
 501        M: EnvelopedMessage,
 502    {
 503        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 504        self
 505    }
 506
 507    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 510        Fut: Send + Future<Output = Result<()>>,
 511        M: RequestMessage,
 512    {
 513        let handler = Arc::new(handler);
 514        self.add_handler(move |envelope, session| {
 515            let receipt = envelope.receipt();
 516            let handler = handler.clone();
 517            async move {
 518                let peer = session.peer.clone();
 519                let responded = Arc::new(AtomicBool::default());
 520                let response = Response {
 521                    peer: peer.clone(),
 522                    responded: responded.clone(),
 523                    receipt,
 524                };
 525                match (handler)(envelope.payload, response, session).await {
 526                    Ok(()) => {
 527                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 528                            Ok(())
 529                        } else {
 530                            Err(anyhow!("handler did not send a response"))?
 531                        }
 532                    }
 533                    Err(error) => {
 534                        peer.respond_with_error(
 535                            receipt,
 536                            proto::Error {
 537                                message: error.to_string(),
 538                            },
 539                        )?;
 540                        Err(error)
 541                    }
 542                }
 543            }
 544        })
 545    }
 546
 547    pub fn handle_connection(
 548        self: &Arc<Self>,
 549        connection: Connection,
 550        address: String,
 551        user: User,
 552        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 553        executor: Executor,
 554    ) -> impl Future<Output = Result<()>> {
 555        let this = self.clone();
 556        let user_id = user.id;
 557        let login = user.github_login;
 558        let span = info_span!("handle connection", %user_id, %login, %address);
 559        let mut teardown = self.teardown.subscribe();
 560        async move {
 561            let (connection_id, handle_io, mut incoming_rx) = this
 562                .peer
 563                .add_connection(connection, {
 564                    let executor = executor.clone();
 565                    move |duration| executor.sleep(duration)
 566                });
 567
 568            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 569            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 570            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 571
 572            if let Some(send_connection_id) = send_connection_id.take() {
 573                let _ = send_connection_id.send(connection_id);
 574            }
 575
 576            if !user.connected_once {
 577                this.peer.send(connection_id, proto::ShowContacts {})?;
 578                this.app_state.db.set_user_connected_once(user_id, true).await?;
 579            }
 580
 581            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 582                this.app_state.db.get_contacts(user_id),
 583                this.app_state.db.get_channels_for_user(user_id),
 584                this.app_state.db.get_channel_invites_for_user(user_id)
 585            ).await?;
 586
 587            {
 588                let mut pool = this.connection_pool.lock();
 589                pool.add_connection(connection_id, user_id, user.admin);
 590                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 591                this.peer.send(connection_id, build_initial_channels_update(
 592                    channels_for_user,
 593                    channel_invites
 594                ))?;
 595            }
 596
 597            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 598                this.peer.send(connection_id, incoming_call)?;
 599            }
 600
 601            let session = Session {
 602                user_id,
 603                connection_id,
 604                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 605                peer: this.peer.clone(),
 606                connection_pool: this.connection_pool.clone(),
 607                live_kit_client: this.app_state.live_kit_client.clone(),
 608                executor: executor.clone(),
 609            };
 610            update_user_contacts(user_id, &session).await?;
 611
 612            let handle_io = handle_io.fuse();
 613            futures::pin_mut!(handle_io);
 614
 615            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 616            // This prevents deadlocks when e.g., client A performs a request to client B and
 617            // client B performs a request to client A. If both clients stop processing further
 618            // messages until their respective request completes, they won't have a chance to
 619            // respond to the other client's request and cause a deadlock.
 620            //
 621            // This arrangement ensures we will attempt to process earlier messages first, but fall
 622            // back to processing messages arrived later in the spirit of making progress.
 623            let mut foreground_message_handlers = FuturesUnordered::new();
 624            let concurrent_handlers = Arc::new(Semaphore::new(256));
 625            loop {
 626                let next_message = async {
 627                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 628                    let message = incoming_rx.next().await;
 629                    (permit, message)
 630                }.fuse();
 631                futures::pin_mut!(next_message);
 632                futures::select_biased! {
 633                    _ = teardown.changed().fuse() => return Ok(()),
 634                    result = handle_io => {
 635                        if let Err(error) = result {
 636                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 637                        }
 638                        break;
 639                    }
 640                    _ = foreground_message_handlers.next() => {}
 641                    next_message = next_message => {
 642                        let (permit, message) = next_message;
 643                        if let Some(message) = message {
 644                            let type_name = message.payload_type_name();
 645                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 646                            let span_enter = span.enter();
 647                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 648                                let is_background = message.is_background();
 649                                let handle_message = (handler)(message, session.clone());
 650                                drop(span_enter);
 651
 652                                let handle_message = async move {
 653                                    handle_message.await;
 654                                    drop(permit);
 655                                }.instrument(span);
 656                                if is_background {
 657                                    executor.spawn_detached(handle_message);
 658                                } else {
 659                                    foreground_message_handlers.push(handle_message);
 660                                }
 661                            } else {
 662                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 663                            }
 664                        } else {
 665                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 666                            break;
 667                        }
 668                    }
 669                }
 670            }
 671
 672            drop(foreground_message_handlers);
 673            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 674            if let Err(error) = connection_lost(session, teardown, executor).await {
 675                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 676            }
 677
 678            Ok(())
 679        }.instrument(span)
 680    }
 681
 682    pub async fn invite_code_redeemed(
 683        self: &Arc<Self>,
 684        inviter_id: UserId,
 685        invitee_id: UserId,
 686    ) -> Result<()> {
 687        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 688            if let Some(code) = &user.invite_code {
 689                let pool = self.connection_pool.lock();
 690                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 691                for connection_id in pool.user_connection_ids(inviter_id) {
 692                    self.peer.send(
 693                        connection_id,
 694                        proto::UpdateContacts {
 695                            contacts: vec![invitee_contact.clone()],
 696                            ..Default::default()
 697                        },
 698                    )?;
 699                    self.peer.send(
 700                        connection_id,
 701                        proto::UpdateInviteInfo {
 702                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 703                            count: user.invite_count as u32,
 704                        },
 705                    )?;
 706                }
 707            }
 708        }
 709        Ok(())
 710    }
 711
 712    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 713        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 714            if let Some(invite_code) = &user.invite_code {
 715                let pool = self.connection_pool.lock();
 716                for connection_id in pool.user_connection_ids(user_id) {
 717                    self.peer.send(
 718                        connection_id,
 719                        proto::UpdateInviteInfo {
 720                            url: format!(
 721                                "{}{}",
 722                                self.app_state.config.invite_link_prefix, invite_code
 723                            ),
 724                            count: user.invite_count as u32,
 725                        },
 726                    )?;
 727                }
 728            }
 729        }
 730        Ok(())
 731    }
 732
 733    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 734        ServerSnapshot {
 735            connection_pool: ConnectionPoolGuard {
 736                guard: self.connection_pool.lock(),
 737                _not_send: PhantomData,
 738            },
 739            peer: &self.peer,
 740        }
 741    }
 742}
 743
 744impl<'a> Deref for ConnectionPoolGuard<'a> {
 745    type Target = ConnectionPool;
 746
 747    fn deref(&self) -> &Self::Target {
 748        &*self.guard
 749    }
 750}
 751
 752impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 753    fn deref_mut(&mut self) -> &mut Self::Target {
 754        &mut *self.guard
 755    }
 756}
 757
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants each time a guard is dropped.
    ///
    /// `drop` runs before the `guard` field itself is dropped, so the check sees the pool
    /// while the lock is still held.
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; non-test builds do nothing here. `check_invariants`
        // presumably resolves to `ConnectionPool::check_invariants` via this guard's `Deref`
        // impl — TODO confirm.
        #[cfg(test)]
        self.check_invariants();
    }
}
 764
 765fn broadcast<F>(
 766    sender_id: Option<ConnectionId>,
 767    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 768    mut f: F,
 769) where
 770    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 771{
 772    for receiver_id in receiver_ids {
 773        if Some(receiver_id) != sender_id {
 774            if let Err(error) = f(receiver_id) {
 775                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 776            }
 777        }
 778    }
 779}
 780
lazy_static! {
    /// Name of the request header from which `ProtocolVersion` is decoded and to which it is
    /// encoded.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 784
 785pub struct ProtocolVersion(u32);
 786
 787impl Header for ProtocolVersion {
 788    fn name() -> &'static HeaderName {
 789        &ZED_PROTOCOL_VERSION
 790    }
 791
 792    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 793    where
 794        Self: Sized,
 795        I: Iterator<Item = &'i axum::http::HeaderValue>,
 796    {
 797        let version = values
 798            .next()
 799            .ok_or_else(axum::headers::Error::invalid)?
 800            .to_str()
 801            .map_err(|_| axum::headers::Error::invalid())?
 802            .parse()
 803            .map_err(|_| axum::headers::Error::invalid())?;
 804        Ok(Self(version))
 805    }
 806
 807    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 808        values.extend([self.0.to_string().parse().unwrap()]);
 809    }
 810}
 811
 812pub fn routes(server: Arc<Server>) -> Router<Body> {
 813    Router::new()
 814        .route("/rpc", get(handle_websocket_request))
 815        .layer(
 816            ServiceBuilder::new()
 817                .layer(Extension(server.app_state.clone()))
 818                .layer(middleware::from_fn(auth::validate_header)),
 819        )
 820        .route("/metrics", get(handle_metrics))
 821        .layer(Extension(server))
 822}
 823
 824pub async fn handle_websocket_request(
 825    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 826    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 827    Extension(server): Extension<Arc<Server>>,
 828    Extension(user): Extension<User>,
 829    ws: WebSocketUpgrade,
 830) -> axum::response::Response {
 831    if protocol_version != rpc::PROTOCOL_VERSION {
 832        return (
 833            StatusCode::UPGRADE_REQUIRED,
 834            "client must be upgraded".to_string(),
 835        )
 836            .into_response();
 837    }
 838    let socket_address = socket_address.to_string();
 839    ws.on_upgrade(move |socket| {
 840        use util::ResultExt;
 841        let socket = socket
 842            .map_ok(to_tungstenite_message)
 843            .err_into()
 844            .with(|message| async move { Ok(to_axum_message(message)) });
 845        let connection = Connection::new(Box::pin(socket));
 846        async move {
 847            server
 848                .handle_connection(connection, socket_address, user, None, Executor::Production)
 849                .await
 850                .log_err();
 851        }
 852    })
 853}
 854
 855pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 856    let connections = server
 857        .connection_pool
 858        .lock()
 859        .connections()
 860        .filter(|connection| !connection.admin)
 861        .count();
 862
 863    METRIC_CONNECTIONS.set(connections as _);
 864
 865    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 866    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 867
 868    let encoder = prometheus::TextEncoder::new();
 869    let metric_families = prometheus::gather();
 870    let encoded_metrics = encoder
 871        .encode_to_string(&metric_families)
 872        .map_err(|err| anyhow!("{}", err))?;
 873    Ok(encoded_metrics)
 874}
 875
 876#[instrument(err, skip(executor))]
 877async fn connection_lost(
 878    session: Session,
 879    mut teardown: watch::Receiver<()>,
 880    executor: Executor,
 881) -> Result<()> {
 882    session.peer.disconnect(session.connection_id);
 883    session
 884        .connection_pool()
 885        .await
 886        .remove_connection(session.connection_id)?;
 887
 888    session
 889        .db()
 890        .await
 891        .connection_lost(session.connection_id)
 892        .await
 893        .trace_err();
 894
 895    futures::select_biased! {
 896        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 897            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 898            leave_room_for_session(&session).await.trace_err();
 899            leave_channel_buffers_for_session(&session)
 900                .await
 901                .trace_err();
 902
 903            if !session
 904                .connection_pool()
 905                .await
 906                .is_user_online(session.user_id)
 907            {
 908                let db = session.db().await;
 909                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 910                    room_updated(&room, &session.peer);
 911                }
 912            }
 913
 914            update_user_contacts(session.user_id, &session).await?;
 915        }
 916        _ = teardown.changed().fuse() => {}
 917    }
 918
 919    Ok(())
 920}
 921
 922async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 923    response.send(proto::Ack {})?;
 924    Ok(())
 925}
 926
 927async fn create_room(
 928    _request: proto::CreateRoom,
 929    response: Response<proto::CreateRoom>,
 930    session: Session,
 931) -> Result<()> {
 932    let live_kit_room = nanoid::nanoid!(30);
 933
 934    let live_kit_connection_info = {
 935        let live_kit_room = live_kit_room.clone();
 936        let live_kit = session.live_kit_client.as_ref();
 937
 938        util::async_iife!({
 939            let live_kit = live_kit?;
 940
 941            let token = live_kit
 942                .room_token(&live_kit_room, &session.user_id.to_string())
 943                .trace_err()?;
 944
 945            Some(proto::LiveKitConnectionInfo {
 946                server_url: live_kit.url().into(),
 947                token,
 948            })
 949        })
 950    }
 951    .await;
 952
 953    let room = session
 954        .db()
 955        .await
 956        .create_room(
 957            session.user_id,
 958            session.connection_id,
 959            &live_kit_room,
 960            RELEASE_CHANNEL_NAME.as_str(),
 961        )
 962        .await?;
 963
 964    response.send(proto::CreateRoomResponse {
 965        room: Some(room.clone()),
 966        live_kit_connection_info,
 967    })?;
 968
 969    update_user_contacts(session.user_id, &session).await?;
 970    Ok(())
 971}
 972
 973async fn join_room(
 974    request: proto::JoinRoom,
 975    response: Response<proto::JoinRoom>,
 976    session: Session,
 977) -> Result<()> {
 978    let room_id = RoomId::from_proto(request.id);
 979    let joined_room = {
 980        let room = session
 981            .db()
 982            .await
 983            .join_room(
 984                room_id,
 985                session.user_id,
 986                session.connection_id,
 987                RELEASE_CHANNEL_NAME.as_str(),
 988            )
 989            .await?;
 990        room_updated(&room.room, &session.peer);
 991        room.into_inner()
 992    };
 993
 994    if let Some(channel_id) = joined_room.channel_id {
 995        channel_updated(
 996            channel_id,
 997            &joined_room.room,
 998            &joined_room.channel_members,
 999            &session.peer,
1000            &*session.connection_pool().await,
1001        )
1002    }
1003
1004    for connection_id in session
1005        .connection_pool()
1006        .await
1007        .user_connection_ids(session.user_id)
1008    {
1009        session
1010            .peer
1011            .send(
1012                connection_id,
1013                proto::CallCanceled {
1014                    room_id: room_id.to_proto(),
1015                },
1016            )
1017            .trace_err();
1018    }
1019
1020    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1021        if let Some(token) = live_kit
1022            .room_token(
1023                &joined_room.room.live_kit_room,
1024                &session.user_id.to_string(),
1025            )
1026            .trace_err()
1027        {
1028            Some(proto::LiveKitConnectionInfo {
1029                server_url: live_kit.url().into(),
1030                token,
1031            })
1032        } else {
1033            None
1034        }
1035    } else {
1036        None
1037    };
1038
1039    response.send(proto::JoinRoomResponse {
1040        room: Some(joined_room.room),
1041        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1042        live_kit_connection_info,
1043    })?;
1044
1045    update_user_contacts(session.user_id, &session).await?;
1046    Ok(())
1047}
1048
1049async fn rejoin_room(
1050    request: proto::RejoinRoom,
1051    response: Response<proto::RejoinRoom>,
1052    session: Session,
1053) -> Result<()> {
1054    let room;
1055    let channel_id;
1056    let channel_members;
1057    {
1058        let mut rejoined_room = session
1059            .db()
1060            .await
1061            .rejoin_room(request, session.user_id, session.connection_id)
1062            .await?;
1063
1064        response.send(proto::RejoinRoomResponse {
1065            room: Some(rejoined_room.room.clone()),
1066            reshared_projects: rejoined_room
1067                .reshared_projects
1068                .iter()
1069                .map(|project| proto::ResharedProject {
1070                    id: project.id.to_proto(),
1071                    collaborators: project
1072                        .collaborators
1073                        .iter()
1074                        .map(|collaborator| collaborator.to_proto())
1075                        .collect(),
1076                })
1077                .collect(),
1078            rejoined_projects: rejoined_room
1079                .rejoined_projects
1080                .iter()
1081                .map(|rejoined_project| proto::RejoinedProject {
1082                    id: rejoined_project.id.to_proto(),
1083                    worktrees: rejoined_project
1084                        .worktrees
1085                        .iter()
1086                        .map(|worktree| proto::WorktreeMetadata {
1087                            id: worktree.id,
1088                            root_name: worktree.root_name.clone(),
1089                            visible: worktree.visible,
1090                            abs_path: worktree.abs_path.clone(),
1091                        })
1092                        .collect(),
1093                    collaborators: rejoined_project
1094                        .collaborators
1095                        .iter()
1096                        .map(|collaborator| collaborator.to_proto())
1097                        .collect(),
1098                    language_servers: rejoined_project.language_servers.clone(),
1099                })
1100                .collect(),
1101        })?;
1102        room_updated(&rejoined_room.room, &session.peer);
1103
1104        for project in &rejoined_room.reshared_projects {
1105            for collaborator in &project.collaborators {
1106                session
1107                    .peer
1108                    .send(
1109                        collaborator.connection_id,
1110                        proto::UpdateProjectCollaborator {
1111                            project_id: project.id.to_proto(),
1112                            old_peer_id: Some(project.old_connection_id.into()),
1113                            new_peer_id: Some(session.connection_id.into()),
1114                        },
1115                    )
1116                    .trace_err();
1117            }
1118
1119            broadcast(
1120                Some(session.connection_id),
1121                project
1122                    .collaborators
1123                    .iter()
1124                    .map(|collaborator| collaborator.connection_id),
1125                |connection_id| {
1126                    session.peer.forward_send(
1127                        session.connection_id,
1128                        connection_id,
1129                        proto::UpdateProject {
1130                            project_id: project.id.to_proto(),
1131                            worktrees: project.worktrees.clone(),
1132                        },
1133                    )
1134                },
1135            );
1136        }
1137
1138        for project in &rejoined_room.rejoined_projects {
1139            for collaborator in &project.collaborators {
1140                session
1141                    .peer
1142                    .send(
1143                        collaborator.connection_id,
1144                        proto::UpdateProjectCollaborator {
1145                            project_id: project.id.to_proto(),
1146                            old_peer_id: Some(project.old_connection_id.into()),
1147                            new_peer_id: Some(session.connection_id.into()),
1148                        },
1149                    )
1150                    .trace_err();
1151            }
1152        }
1153
1154        for project in &mut rejoined_room.rejoined_projects {
1155            for worktree in mem::take(&mut project.worktrees) {
1156                #[cfg(any(test, feature = "test-support"))]
1157                const MAX_CHUNK_SIZE: usize = 2;
1158                #[cfg(not(any(test, feature = "test-support")))]
1159                const MAX_CHUNK_SIZE: usize = 256;
1160
1161                // Stream this worktree's entries.
1162                let message = proto::UpdateWorktree {
1163                    project_id: project.id.to_proto(),
1164                    worktree_id: worktree.id,
1165                    abs_path: worktree.abs_path.clone(),
1166                    root_name: worktree.root_name,
1167                    updated_entries: worktree.updated_entries,
1168                    removed_entries: worktree.removed_entries,
1169                    scan_id: worktree.scan_id,
1170                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1171                    updated_repositories: worktree.updated_repositories,
1172                    removed_repositories: worktree.removed_repositories,
1173                };
1174                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1175                    session.peer.send(session.connection_id, update.clone())?;
1176                }
1177
1178                // Stream this worktree's diagnostics.
1179                for summary in worktree.diagnostic_summaries {
1180                    session.peer.send(
1181                        session.connection_id,
1182                        proto::UpdateDiagnosticSummary {
1183                            project_id: project.id.to_proto(),
1184                            worktree_id: worktree.id,
1185                            summary: Some(summary),
1186                        },
1187                    )?;
1188                }
1189
1190                for settings_file in worktree.settings_files {
1191                    session.peer.send(
1192                        session.connection_id,
1193                        proto::UpdateWorktreeSettings {
1194                            project_id: project.id.to_proto(),
1195                            worktree_id: worktree.id,
1196                            path: settings_file.path,
1197                            content: Some(settings_file.content),
1198                        },
1199                    )?;
1200                }
1201            }
1202
1203            for language_server in &project.language_servers {
1204                session.peer.send(
1205                    session.connection_id,
1206                    proto::UpdateLanguageServer {
1207                        project_id: project.id.to_proto(),
1208                        language_server_id: language_server.id,
1209                        variant: Some(
1210                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1211                                proto::LspDiskBasedDiagnosticsUpdated {},
1212                            ),
1213                        ),
1214                    },
1215                )?;
1216            }
1217        }
1218
1219        let rejoined_room = rejoined_room.into_inner();
1220
1221        room = rejoined_room.room;
1222        channel_id = rejoined_room.channel_id;
1223        channel_members = rejoined_room.channel_members;
1224    }
1225
1226    if let Some(channel_id) = channel_id {
1227        channel_updated(
1228            channel_id,
1229            &room,
1230            &channel_members,
1231            &session.peer,
1232            &*session.connection_pool().await,
1233        );
1234    }
1235
1236    update_user_contacts(session.user_id, &session).await?;
1237    Ok(())
1238}
1239
1240async fn leave_room(
1241    _: proto::LeaveRoom,
1242    response: Response<proto::LeaveRoom>,
1243    session: Session,
1244) -> Result<()> {
1245    leave_room_for_session(&session).await?;
1246    response.send(proto::Ack {})?;
1247    Ok(())
1248}
1249
1250async fn call(
1251    request: proto::Call,
1252    response: Response<proto::Call>,
1253    session: Session,
1254) -> Result<()> {
1255    let room_id = RoomId::from_proto(request.room_id);
1256    let calling_user_id = session.user_id;
1257    let calling_connection_id = session.connection_id;
1258    let called_user_id = UserId::from_proto(request.called_user_id);
1259    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1260    if !session
1261        .db()
1262        .await
1263        .has_contact(calling_user_id, called_user_id)
1264        .await?
1265    {
1266        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1267    }
1268
1269    let incoming_call = {
1270        let (room, incoming_call) = &mut *session
1271            .db()
1272            .await
1273            .call(
1274                room_id,
1275                calling_user_id,
1276                calling_connection_id,
1277                called_user_id,
1278                initial_project_id,
1279            )
1280            .await?;
1281        room_updated(&room, &session.peer);
1282        mem::take(incoming_call)
1283    };
1284    update_user_contacts(called_user_id, &session).await?;
1285
1286    let mut calls = session
1287        .connection_pool()
1288        .await
1289        .user_connection_ids(called_user_id)
1290        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1291        .collect::<FuturesUnordered<_>>();
1292
1293    while let Some(call_response) = calls.next().await {
1294        match call_response.as_ref() {
1295            Ok(_) => {
1296                response.send(proto::Ack {})?;
1297                return Ok(());
1298            }
1299            Err(_) => {
1300                call_response.trace_err();
1301            }
1302        }
1303    }
1304
1305    {
1306        let room = session
1307            .db()
1308            .await
1309            .call_failed(room_id, called_user_id)
1310            .await?;
1311        room_updated(&room, &session.peer);
1312    }
1313    update_user_contacts(called_user_id, &session).await?;
1314
1315    Err(anyhow!("failed to ring user"))?
1316}
1317
1318async fn cancel_call(
1319    request: proto::CancelCall,
1320    response: Response<proto::CancelCall>,
1321    session: Session,
1322) -> Result<()> {
1323    let called_user_id = UserId::from_proto(request.called_user_id);
1324    let room_id = RoomId::from_proto(request.room_id);
1325    {
1326        let room = session
1327            .db()
1328            .await
1329            .cancel_call(room_id, session.connection_id, called_user_id)
1330            .await?;
1331        room_updated(&room, &session.peer);
1332    }
1333
1334    for connection_id in session
1335        .connection_pool()
1336        .await
1337        .user_connection_ids(called_user_id)
1338    {
1339        session
1340            .peer
1341            .send(
1342                connection_id,
1343                proto::CallCanceled {
1344                    room_id: room_id.to_proto(),
1345                },
1346            )
1347            .trace_err();
1348    }
1349    response.send(proto::Ack {})?;
1350
1351    update_user_contacts(called_user_id, &session).await?;
1352    Ok(())
1353}
1354
1355async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1356    let room_id = RoomId::from_proto(message.room_id);
1357    {
1358        let room = session
1359            .db()
1360            .await
1361            .decline_call(Some(room_id), session.user_id)
1362            .await?
1363            .ok_or_else(|| anyhow!("failed to decline call"))?;
1364        room_updated(&room, &session.peer);
1365    }
1366
1367    for connection_id in session
1368        .connection_pool()
1369        .await
1370        .user_connection_ids(session.user_id)
1371    {
1372        session
1373            .peer
1374            .send(
1375                connection_id,
1376                proto::CallCanceled {
1377                    room_id: room_id.to_proto(),
1378                },
1379            )
1380            .trace_err();
1381    }
1382    update_user_contacts(session.user_id, &session).await?;
1383    Ok(())
1384}
1385
1386async fn update_participant_location(
1387    request: proto::UpdateParticipantLocation,
1388    response: Response<proto::UpdateParticipantLocation>,
1389    session: Session,
1390) -> Result<()> {
1391    let room_id = RoomId::from_proto(request.room_id);
1392    let location = request
1393        .location
1394        .ok_or_else(|| anyhow!("invalid location"))?;
1395
1396    let db = session.db().await;
1397    let room = db
1398        .update_room_participant_location(room_id, session.connection_id, location)
1399        .await?;
1400
1401    room_updated(&room, &session.peer);
1402    response.send(proto::Ack {})?;
1403    Ok(())
1404}
1405
1406async fn share_project(
1407    request: proto::ShareProject,
1408    response: Response<proto::ShareProject>,
1409    session: Session,
1410) -> Result<()> {
1411    let (project_id, room) = &*session
1412        .db()
1413        .await
1414        .share_project(
1415            RoomId::from_proto(request.room_id),
1416            session.connection_id,
1417            &request.worktrees,
1418        )
1419        .await?;
1420    response.send(proto::ShareProjectResponse {
1421        project_id: project_id.to_proto(),
1422    })?;
1423    room_updated(&room, &session.peer);
1424
1425    Ok(())
1426}
1427
1428async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1429    let project_id = ProjectId::from_proto(message.project_id);
1430
1431    let (room, guest_connection_ids) = &*session
1432        .db()
1433        .await
1434        .unshare_project(project_id, session.connection_id)
1435        .await?;
1436
1437    broadcast(
1438        Some(session.connection_id),
1439        guest_connection_ids.iter().copied(),
1440        |conn_id| session.peer.send(conn_id, message.clone()),
1441    );
1442    room_updated(&room, &session.peer);
1443
1444    Ok(())
1445}
1446
1447async fn join_project(
1448    request: proto::JoinProject,
1449    response: Response<proto::JoinProject>,
1450    session: Session,
1451) -> Result<()> {
1452    let project_id = ProjectId::from_proto(request.project_id);
1453    let guest_user_id = session.user_id;
1454
1455    tracing::info!(%project_id, "join project");
1456
1457    let (project, replica_id) = &mut *session
1458        .db()
1459        .await
1460        .join_project(project_id, session.connection_id)
1461        .await?;
1462
1463    let collaborators = project
1464        .collaborators
1465        .iter()
1466        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1467        .map(|collaborator| collaborator.to_proto())
1468        .collect::<Vec<_>>();
1469
1470    let worktrees = project
1471        .worktrees
1472        .iter()
1473        .map(|(id, worktree)| proto::WorktreeMetadata {
1474            id: *id,
1475            root_name: worktree.root_name.clone(),
1476            visible: worktree.visible,
1477            abs_path: worktree.abs_path.clone(),
1478        })
1479        .collect::<Vec<_>>();
1480
1481    for collaborator in &collaborators {
1482        session
1483            .peer
1484            .send(
1485                collaborator.peer_id.unwrap().into(),
1486                proto::AddProjectCollaborator {
1487                    project_id: project_id.to_proto(),
1488                    collaborator: Some(proto::Collaborator {
1489                        peer_id: Some(session.connection_id.into()),
1490                        replica_id: replica_id.0 as u32,
1491                        user_id: guest_user_id.to_proto(),
1492                    }),
1493                },
1494            )
1495            .trace_err();
1496    }
1497
1498    // First, we send the metadata associated with each worktree.
1499    response.send(proto::JoinProjectResponse {
1500        worktrees: worktrees.clone(),
1501        replica_id: replica_id.0 as u32,
1502        collaborators: collaborators.clone(),
1503        language_servers: project.language_servers.clone(),
1504    })?;
1505
1506    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1507        #[cfg(any(test, feature = "test-support"))]
1508        const MAX_CHUNK_SIZE: usize = 2;
1509        #[cfg(not(any(test, feature = "test-support")))]
1510        const MAX_CHUNK_SIZE: usize = 256;
1511
1512        // Stream this worktree's entries.
1513        let message = proto::UpdateWorktree {
1514            project_id: project_id.to_proto(),
1515            worktree_id,
1516            abs_path: worktree.abs_path.clone(),
1517            root_name: worktree.root_name,
1518            updated_entries: worktree.entries,
1519            removed_entries: Default::default(),
1520            scan_id: worktree.scan_id,
1521            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1522            updated_repositories: worktree.repository_entries.into_values().collect(),
1523            removed_repositories: Default::default(),
1524        };
1525        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1526            session.peer.send(session.connection_id, update.clone())?;
1527        }
1528
1529        // Stream this worktree's diagnostics.
1530        for summary in worktree.diagnostic_summaries {
1531            session.peer.send(
1532                session.connection_id,
1533                proto::UpdateDiagnosticSummary {
1534                    project_id: project_id.to_proto(),
1535                    worktree_id: worktree.id,
1536                    summary: Some(summary),
1537                },
1538            )?;
1539        }
1540
1541        for settings_file in worktree.settings_files {
1542            session.peer.send(
1543                session.connection_id,
1544                proto::UpdateWorktreeSettings {
1545                    project_id: project_id.to_proto(),
1546                    worktree_id: worktree.id,
1547                    path: settings_file.path,
1548                    content: Some(settings_file.content),
1549                },
1550            )?;
1551        }
1552    }
1553
1554    for language_server in &project.language_servers {
1555        session.peer.send(
1556            session.connection_id,
1557            proto::UpdateLanguageServer {
1558                project_id: project_id.to_proto(),
1559                language_server_id: language_server.id,
1560                variant: Some(
1561                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1562                        proto::LspDiskBasedDiagnosticsUpdated {},
1563                    ),
1564                ),
1565            },
1566        )?;
1567    }
1568
1569    Ok(())
1570}
1571
1572async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1573    let sender_id = session.connection_id;
1574    let project_id = ProjectId::from_proto(request.project_id);
1575
1576    let (room, project) = &*session
1577        .db()
1578        .await
1579        .leave_project(project_id, sender_id)
1580        .await?;
1581    tracing::info!(
1582        %project_id,
1583        host_user_id = %project.host_user_id,
1584        host_connection_id = %project.host_connection_id,
1585        "leave project"
1586    );
1587
1588    project_left(&project, &session);
1589    room_updated(&room, &session.peer);
1590
1591    Ok(())
1592}
1593
1594async fn update_project(
1595    request: proto::UpdateProject,
1596    response: Response<proto::UpdateProject>,
1597    session: Session,
1598) -> Result<()> {
1599    let project_id = ProjectId::from_proto(request.project_id);
1600    let (room, guest_connection_ids) = &*session
1601        .db()
1602        .await
1603        .update_project(project_id, session.connection_id, &request.worktrees)
1604        .await?;
1605    broadcast(
1606        Some(session.connection_id),
1607        guest_connection_ids.iter().copied(),
1608        |connection_id| {
1609            session
1610                .peer
1611                .forward_send(session.connection_id, connection_id, request.clone())
1612        },
1613    );
1614    room_updated(&room, &session.peer);
1615    response.send(proto::Ack {})?;
1616
1617    Ok(())
1618}
1619
1620async fn update_worktree(
1621    request: proto::UpdateWorktree,
1622    response: Response<proto::UpdateWorktree>,
1623    session: Session,
1624) -> Result<()> {
1625    let guest_connection_ids = session
1626        .db()
1627        .await
1628        .update_worktree(&request, session.connection_id)
1629        .await?;
1630
1631    broadcast(
1632        Some(session.connection_id),
1633        guest_connection_ids.iter().copied(),
1634        |connection_id| {
1635            session
1636                .peer
1637                .forward_send(session.connection_id, connection_id, request.clone())
1638        },
1639    );
1640    response.send(proto::Ack {})?;
1641    Ok(())
1642}
1643
1644async fn update_diagnostic_summary(
1645    message: proto::UpdateDiagnosticSummary,
1646    session: Session,
1647) -> Result<()> {
1648    let guest_connection_ids = session
1649        .db()
1650        .await
1651        .update_diagnostic_summary(&message, session.connection_id)
1652        .await?;
1653
1654    broadcast(
1655        Some(session.connection_id),
1656        guest_connection_ids.iter().copied(),
1657        |connection_id| {
1658            session
1659                .peer
1660                .forward_send(session.connection_id, connection_id, message.clone())
1661        },
1662    );
1663
1664    Ok(())
1665}
1666
1667async fn update_worktree_settings(
1668    message: proto::UpdateWorktreeSettings,
1669    session: Session,
1670) -> Result<()> {
1671    let guest_connection_ids = session
1672        .db()
1673        .await
1674        .update_worktree_settings(&message, session.connection_id)
1675        .await?;
1676
1677    broadcast(
1678        Some(session.connection_id),
1679        guest_connection_ids.iter().copied(),
1680        |connection_id| {
1681            session
1682                .peer
1683                .forward_send(session.connection_id, connection_id, message.clone())
1684        },
1685    );
1686
1687    Ok(())
1688}
1689
1690async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1691    broadcast_project_message(request.project_id, request, session).await
1692}
1693
1694async fn start_language_server(
1695    request: proto::StartLanguageServer,
1696    session: Session,
1697) -> Result<()> {
1698    let guest_connection_ids = session
1699        .db()
1700        .await
1701        .start_language_server(&request, session.connection_id)
1702        .await?;
1703
1704    broadcast(
1705        Some(session.connection_id),
1706        guest_connection_ids.iter().copied(),
1707        |connection_id| {
1708            session
1709                .peer
1710                .forward_send(session.connection_id, connection_id, request.clone())
1711        },
1712    );
1713    Ok(())
1714}
1715
1716async fn update_language_server(
1717    request: proto::UpdateLanguageServer,
1718    session: Session,
1719) -> Result<()> {
1720    session.executor.record_backtrace();
1721    let project_id = ProjectId::from_proto(request.project_id);
1722    let project_connection_ids = session
1723        .db()
1724        .await
1725        .project_connection_ids(project_id, session.connection_id)
1726        .await?;
1727    broadcast(
1728        Some(session.connection_id),
1729        project_connection_ids.iter().copied(),
1730        |connection_id| {
1731            session
1732                .peer
1733                .forward_send(session.connection_id, connection_id, request.clone())
1734        },
1735    );
1736    Ok(())
1737}
1738
1739async fn forward_project_request<T>(
1740    request: T,
1741    response: Response<T>,
1742    session: Session,
1743) -> Result<()>
1744where
1745    T: EntityMessage + RequestMessage,
1746{
1747    session.executor.record_backtrace();
1748    let project_id = ProjectId::from_proto(request.remote_entity_id());
1749    let host_connection_id = {
1750        let collaborators = session
1751            .db()
1752            .await
1753            .project_collaborators(project_id, session.connection_id)
1754            .await?;
1755        collaborators
1756            .iter()
1757            .find(|collaborator| collaborator.is_host)
1758            .ok_or_else(|| anyhow!("host not found"))?
1759            .connection_id
1760    };
1761
1762    let payload = session
1763        .peer
1764        .forward_request(session.connection_id, host_connection_id, request)
1765        .await?;
1766
1767    response.send(payload)?;
1768    Ok(())
1769}
1770
1771async fn create_buffer_for_peer(
1772    request: proto::CreateBufferForPeer,
1773    session: Session,
1774) -> Result<()> {
1775    session.executor.record_backtrace();
1776    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1777    session
1778        .peer
1779        .forward_send(session.connection_id, peer_id.into(), request)?;
1780    Ok(())
1781}
1782
1783async fn update_buffer(
1784    request: proto::UpdateBuffer,
1785    response: Response<proto::UpdateBuffer>,
1786    session: Session,
1787) -> Result<()> {
1788    session.executor.record_backtrace();
1789    let project_id = ProjectId::from_proto(request.project_id);
1790    let mut guest_connection_ids;
1791    let mut host_connection_id = None;
1792    {
1793        let collaborators = session
1794            .db()
1795            .await
1796            .project_collaborators(project_id, session.connection_id)
1797            .await?;
1798        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1799        for collaborator in collaborators.iter() {
1800            if collaborator.is_host {
1801                host_connection_id = Some(collaborator.connection_id);
1802            } else {
1803                guest_connection_ids.push(collaborator.connection_id);
1804            }
1805        }
1806    }
1807    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1808
1809    session.executor.record_backtrace();
1810    broadcast(
1811        Some(session.connection_id),
1812        guest_connection_ids,
1813        |connection_id| {
1814            session
1815                .peer
1816                .forward_send(session.connection_id, connection_id, request.clone())
1817        },
1818    );
1819    if host_connection_id != session.connection_id {
1820        session
1821            .peer
1822            .forward_request(session.connection_id, host_connection_id, request.clone())
1823            .await?;
1824    }
1825
1826    response.send(proto::Ack {})?;
1827    Ok(())
1828}
1829
1830async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1831    let project_id = ProjectId::from_proto(request.project_id);
1832    let project_connection_ids = session
1833        .db()
1834        .await
1835        .project_connection_ids(project_id, session.connection_id)
1836        .await?;
1837
1838    broadcast(
1839        Some(session.connection_id),
1840        project_connection_ids.iter().copied(),
1841        |connection_id| {
1842            session
1843                .peer
1844                .forward_send(session.connection_id, connection_id, request.clone())
1845        },
1846    );
1847    Ok(())
1848}
1849
1850async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1851    let project_id = ProjectId::from_proto(request.project_id);
1852    let project_connection_ids = session
1853        .db()
1854        .await
1855        .project_connection_ids(project_id, session.connection_id)
1856        .await?;
1857    broadcast(
1858        Some(session.connection_id),
1859        project_connection_ids.iter().copied(),
1860        |connection_id| {
1861            session
1862                .peer
1863                .forward_send(session.connection_id, connection_id, request.clone())
1864        },
1865    );
1866    Ok(())
1867}
1868
1869async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1870    broadcast_project_message(request.project_id, request, session).await
1871}
1872
1873async fn broadcast_project_message<T: EnvelopedMessage>(
1874    project_id: u64,
1875    request: T,
1876    session: Session,
1877) -> Result<()> {
1878    let project_id = ProjectId::from_proto(project_id);
1879    let project_connection_ids = session
1880        .db()
1881        .await
1882        .project_connection_ids(project_id, session.connection_id)
1883        .await?;
1884    broadcast(
1885        Some(session.connection_id),
1886        project_connection_ids.iter().copied(),
1887        |connection_id| {
1888            session
1889                .peer
1890                .forward_send(session.connection_id, connection_id, request.clone())
1891        },
1892    );
1893    Ok(())
1894}
1895
1896async fn follow(
1897    request: proto::Follow,
1898    response: Response<proto::Follow>,
1899    session: Session,
1900) -> Result<()> {
1901    let room_id = RoomId::from_proto(request.room_id);
1902    let project_id = request.project_id.map(ProjectId::from_proto);
1903    let leader_id = request
1904        .leader_id
1905        .ok_or_else(|| anyhow!("invalid leader id"))?
1906        .into();
1907    let follower_id = session.connection_id;
1908
1909    session
1910        .db()
1911        .await
1912        .check_room_participants(room_id, leader_id, session.connection_id)
1913        .await?;
1914
1915    let response_payload = session
1916        .peer
1917        .forward_request(session.connection_id, leader_id, request)
1918        .await?;
1919    response.send(response_payload)?;
1920
1921    if let Some(project_id) = project_id {
1922        let room = session
1923            .db()
1924            .await
1925            .follow(room_id, project_id, leader_id, follower_id)
1926            .await?;
1927        room_updated(&room, &session.peer);
1928    }
1929
1930    Ok(())
1931}
1932
1933async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1934    let room_id = RoomId::from_proto(request.room_id);
1935    let project_id = request.project_id.map(ProjectId::from_proto);
1936    let leader_id = request
1937        .leader_id
1938        .ok_or_else(|| anyhow!("invalid leader id"))?
1939        .into();
1940    let follower_id = session.connection_id;
1941
1942    session
1943        .db()
1944        .await
1945        .check_room_participants(room_id, leader_id, session.connection_id)
1946        .await?;
1947
1948    session
1949        .peer
1950        .forward_send(session.connection_id, leader_id, request)?;
1951
1952    if let Some(project_id) = project_id {
1953        let room = session
1954            .db()
1955            .await
1956            .unfollow(room_id, project_id, leader_id, follower_id)
1957            .await?;
1958        room_updated(&room, &session.peer);
1959    }
1960
1961    Ok(())
1962}
1963
1964async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1965    let room_id = RoomId::from_proto(request.room_id);
1966    let database = session.db.lock().await;
1967
1968    let connection_ids = if let Some(project_id) = request.project_id {
1969        let project_id = ProjectId::from_proto(project_id);
1970        database
1971            .project_connection_ids(project_id, session.connection_id)
1972            .await?
1973    } else {
1974        database
1975            .room_connection_ids(room_id, session.connection_id)
1976            .await?
1977    };
1978
1979    // For now, don't send view update messages back to that view's current leader.
1980    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1981        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1982        _ => None,
1983    });
1984
1985    for follower_peer_id in request.follower_ids.iter().copied() {
1986        let follower_connection_id = follower_peer_id.into();
1987        if Some(follower_peer_id) != connection_id_to_omit
1988            && connection_ids.contains(&follower_connection_id)
1989        {
1990            session.peer.forward_send(
1991                session.connection_id,
1992                follower_connection_id,
1993                request.clone(),
1994            )?;
1995        }
1996    }
1997    Ok(())
1998}
1999
2000async fn get_users(
2001    request: proto::GetUsers,
2002    response: Response<proto::GetUsers>,
2003    session: Session,
2004) -> Result<()> {
2005    let user_ids = request
2006        .user_ids
2007        .into_iter()
2008        .map(UserId::from_proto)
2009        .collect();
2010    let users = session
2011        .db()
2012        .await
2013        .get_users_by_ids(user_ids)
2014        .await?
2015        .into_iter()
2016        .map(|user| proto::User {
2017            id: user.id.to_proto(),
2018            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2019            github_login: user.github_login,
2020        })
2021        .collect();
2022    response.send(proto::UsersResponse { users })?;
2023    Ok(())
2024}
2025
2026async fn fuzzy_search_users(
2027    request: proto::FuzzySearchUsers,
2028    response: Response<proto::FuzzySearchUsers>,
2029    session: Session,
2030) -> Result<()> {
2031    let query = request.query;
2032    let users = match query.len() {
2033        0 => vec![],
2034        1 | 2 => session
2035            .db()
2036            .await
2037            .get_user_by_github_login(&query)
2038            .await?
2039            .into_iter()
2040            .collect(),
2041        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2042    };
2043    let users = users
2044        .into_iter()
2045        .filter(|user| user.id != session.user_id)
2046        .map(|user| proto::User {
2047            id: user.id.to_proto(),
2048            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2049            github_login: user.github_login,
2050        })
2051        .collect();
2052    response.send(proto::UsersResponse { users })?;
2053    Ok(())
2054}
2055
2056async fn request_contact(
2057    request: proto::RequestContact,
2058    response: Response<proto::RequestContact>,
2059    session: Session,
2060) -> Result<()> {
2061    let requester_id = session.user_id;
2062    let responder_id = UserId::from_proto(request.responder_id);
2063    if requester_id == responder_id {
2064        return Err(anyhow!("cannot add yourself as a contact"))?;
2065    }
2066
2067    session
2068        .db()
2069        .await
2070        .send_contact_request(requester_id, responder_id)
2071        .await?;
2072
2073    // Update outgoing contact requests of requester
2074    let mut update = proto::UpdateContacts::default();
2075    update.outgoing_requests.push(responder_id.to_proto());
2076    for connection_id in session
2077        .connection_pool()
2078        .await
2079        .user_connection_ids(requester_id)
2080    {
2081        session.peer.send(connection_id, update.clone())?;
2082    }
2083
2084    // Update incoming contact requests of responder
2085    let mut update = proto::UpdateContacts::default();
2086    update
2087        .incoming_requests
2088        .push(proto::IncomingContactRequest {
2089            requester_id: requester_id.to_proto(),
2090            should_notify: true,
2091        });
2092    for connection_id in session
2093        .connection_pool()
2094        .await
2095        .user_connection_ids(responder_id)
2096    {
2097        session.peer.send(connection_id, update.clone())?;
2098    }
2099
2100    response.send(proto::Ack {})?;
2101    Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105    request: proto::RespondToContactRequest,
2106    response: Response<proto::RespondToContactRequest>,
2107    session: Session,
2108) -> Result<()> {
2109    let responder_id = session.user_id;
2110    let requester_id = UserId::from_proto(request.requester_id);
2111    let db = session.db().await;
2112    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113        db.dismiss_contact_notification(responder_id, requester_id)
2114            .await?;
2115    } else {
2116        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118        db.respond_to_contact_request(responder_id, requester_id, accept)
2119            .await?;
2120        let requester_busy = db.is_user_busy(requester_id).await?;
2121        let responder_busy = db.is_user_busy(responder_id).await?;
2122
2123        let pool = session.connection_pool().await;
2124        // Update responder with new contact
2125        let mut update = proto::UpdateContacts::default();
2126        if accept {
2127            update
2128                .contacts
2129                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2130        }
2131        update
2132            .remove_incoming_requests
2133            .push(requester_id.to_proto());
2134        for connection_id in pool.user_connection_ids(responder_id) {
2135            session.peer.send(connection_id, update.clone())?;
2136        }
2137
2138        // Update requester with new contact
2139        let mut update = proto::UpdateContacts::default();
2140        if accept {
2141            update
2142                .contacts
2143                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2144        }
2145        update
2146            .remove_outgoing_requests
2147            .push(responder_id.to_proto());
2148        for connection_id in pool.user_connection_ids(requester_id) {
2149            session.peer.send(connection_id, update.clone())?;
2150        }
2151    }
2152
2153    response.send(proto::Ack {})?;
2154    Ok(())
2155}
2156
2157async fn remove_contact(
2158    request: proto::RemoveContact,
2159    response: Response<proto::RemoveContact>,
2160    session: Session,
2161) -> Result<()> {
2162    let requester_id = session.user_id;
2163    let responder_id = UserId::from_proto(request.user_id);
2164    let db = session.db().await;
2165    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2166
2167    let pool = session.connection_pool().await;
2168    // Update outgoing contact requests of requester
2169    let mut update = proto::UpdateContacts::default();
2170    if contact_accepted {
2171        update.remove_contacts.push(responder_id.to_proto());
2172    } else {
2173        update
2174            .remove_outgoing_requests
2175            .push(responder_id.to_proto());
2176    }
2177    for connection_id in pool.user_connection_ids(requester_id) {
2178        session.peer.send(connection_id, update.clone())?;
2179    }
2180
2181    // Update incoming contact requests of responder
2182    let mut update = proto::UpdateContacts::default();
2183    if contact_accepted {
2184        update.remove_contacts.push(requester_id.to_proto());
2185    } else {
2186        update
2187            .remove_incoming_requests
2188            .push(requester_id.to_proto());
2189    }
2190    for connection_id in pool.user_connection_ids(responder_id) {
2191        session.peer.send(connection_id, update.clone())?;
2192    }
2193
2194    response.send(proto::Ack {})?;
2195    Ok(())
2196}
2197
2198async fn create_channel(
2199    request: proto::CreateChannel,
2200    response: Response<proto::CreateChannel>,
2201    session: Session,
2202) -> Result<()> {
2203    let db = session.db().await;
2204
2205    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2206    let id = db
2207        .create_channel(&request.name, parent_id, session.user_id)
2208        .await?;
2209
2210    let channel = proto::Channel {
2211        id: id.to_proto(),
2212        name: request.name,
2213    };
2214
2215    response.send(proto::CreateChannelResponse {
2216        channel: Some(channel.clone()),
2217        parent_id: request.parent_id,
2218    })?;
2219
2220    let Some(parent_id) = parent_id else {
2221        return Ok(());
2222    };
2223
2224    let update = proto::UpdateChannels {
2225        channels: vec![channel],
2226        insert_edge: vec![ChannelEdge {
2227            parent_id: parent_id.to_proto(),
2228            channel_id: id.to_proto(),
2229        }],
2230        ..Default::default()
2231    };
2232
2233    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2234
2235    let connection_pool = session.connection_pool().await;
2236    for user_id in user_ids_to_notify {
2237        for connection_id in connection_pool.user_connection_ids(user_id) {
2238            if user_id == session.user_id {
2239                continue;
2240            }
2241            session.peer.send(connection_id, update.clone())?;
2242        }
2243    }
2244
2245    Ok(())
2246}
2247
2248async fn delete_channel(
2249    request: proto::DeleteChannel,
2250    response: Response<proto::DeleteChannel>,
2251    session: Session,
2252) -> Result<()> {
2253    let db = session.db().await;
2254
2255    let channel_id = request.channel_id;
2256    let (removed_channels, member_ids) = db
2257        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2258        .await?;
2259    response.send(proto::Ack {})?;
2260
2261    // Notify members of removed channels
2262    let mut update = proto::UpdateChannels::default();
2263    update
2264        .delete_channels
2265        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2266
2267    let connection_pool = session.connection_pool().await;
2268    for member_id in member_ids {
2269        for connection_id in connection_pool.user_connection_ids(member_id) {
2270            session.peer.send(connection_id, update.clone())?;
2271        }
2272    }
2273
2274    Ok(())
2275}
2276
2277async fn invite_channel_member(
2278    request: proto::InviteChannelMember,
2279    response: Response<proto::InviteChannelMember>,
2280    session: Session,
2281) -> Result<()> {
2282    let db = session.db().await;
2283    let channel_id = ChannelId::from_proto(request.channel_id);
2284    let invitee_id = UserId::from_proto(request.user_id);
2285    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2286        .await?;
2287
2288    let (channel, _) = db
2289        .get_channel(channel_id, session.user_id)
2290        .await?
2291        .ok_or_else(|| anyhow!("channel not found"))?;
2292
2293    let mut update = proto::UpdateChannels::default();
2294    update.channel_invitations.push(proto::Channel {
2295        id: channel.id.to_proto(),
2296        name: channel.name,
2297    });
2298    for connection_id in session
2299        .connection_pool()
2300        .await
2301        .user_connection_ids(invitee_id)
2302    {
2303        session.peer.send(connection_id, update.clone())?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn remove_channel_member(
2311    request: proto::RemoveChannelMember,
2312    response: Response<proto::RemoveChannelMember>,
2313    session: Session,
2314) -> Result<()> {
2315    let db = session.db().await;
2316    let channel_id = ChannelId::from_proto(request.channel_id);
2317    let member_id = UserId::from_proto(request.user_id);
2318
2319    db.remove_channel_member(channel_id, member_id, session.user_id)
2320        .await?;
2321
2322    let mut update = proto::UpdateChannels::default();
2323    update.delete_channels.push(channel_id.to_proto());
2324
2325    for connection_id in session
2326        .connection_pool()
2327        .await
2328        .user_connection_ids(member_id)
2329    {
2330        session.peer.send(connection_id, update.clone())?;
2331    }
2332
2333    response.send(proto::Ack {})?;
2334    Ok(())
2335}
2336
2337async fn set_channel_member_admin(
2338    request: proto::SetChannelMemberAdmin,
2339    response: Response<proto::SetChannelMemberAdmin>,
2340    session: Session,
2341) -> Result<()> {
2342    let db = session.db().await;
2343    let channel_id = ChannelId::from_proto(request.channel_id);
2344    let member_id = UserId::from_proto(request.user_id);
2345    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2346        .await?;
2347
2348    let (channel, has_accepted) = db
2349        .get_channel(channel_id, member_id)
2350        .await?
2351        .ok_or_else(|| anyhow!("channel not found"))?;
2352
2353    let mut update = proto::UpdateChannels::default();
2354    if has_accepted {
2355        update.channel_permissions.push(proto::ChannelPermission {
2356            channel_id: channel.id.to_proto(),
2357            is_admin: request.admin,
2358        });
2359    }
2360
2361    for connection_id in session
2362        .connection_pool()
2363        .await
2364        .user_connection_ids(member_id)
2365    {
2366        session.peer.send(connection_id, update.clone())?;
2367    }
2368
2369    response.send(proto::Ack {})?;
2370    Ok(())
2371}
2372
2373async fn rename_channel(
2374    request: proto::RenameChannel,
2375    response: Response<proto::RenameChannel>,
2376    session: Session,
2377) -> Result<()> {
2378    let db = session.db().await;
2379    let channel_id = ChannelId::from_proto(request.channel_id);
2380    let new_name = db
2381        .rename_channel(channel_id, session.user_id, &request.name)
2382        .await?;
2383
2384    let channel = proto::Channel {
2385        id: request.channel_id,
2386        name: new_name,
2387    };
2388    response.send(proto::RenameChannelResponse {
2389        channel: Some(channel.clone()),
2390    })?;
2391    let mut update = proto::UpdateChannels::default();
2392    update.channels.push(channel);
2393
2394    let member_ids = db.get_channel_members(channel_id).await?;
2395
2396    let connection_pool = session.connection_pool().await;
2397    for member_id in member_ids {
2398        for connection_id in connection_pool.user_connection_ids(member_id) {
2399            session.peer.send(connection_id, update.clone())?;
2400        }
2401    }
2402
2403    Ok(())
2404}
2405
2406async fn link_channel(
2407    request: proto::LinkChannel,
2408    response: Response<proto::LinkChannel>,
2409    session: Session,
2410) -> Result<()> {
2411    let db = session.db().await;
2412    let channel_id = ChannelId::from_proto(request.channel_id);
2413    let to = ChannelId::from_proto(request.to);
2414    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2415
2416    let members = db.get_channel_members(to).await?;
2417    let connection_pool = session.connection_pool().await;
2418    let update = proto::UpdateChannels {
2419        channels: channels_to_send
2420            .channels
2421            .into_iter()
2422            .map(|channel| proto::Channel {
2423                id: channel.id.to_proto(),
2424                name: channel.name,
2425            })
2426            .collect(),
2427        insert_edge: channels_to_send.edges,
2428        ..Default::default()
2429    };
2430    for member_id in members {
2431        for connection_id in connection_pool.user_connection_ids(member_id) {
2432            session.peer.send(connection_id, update.clone())?;
2433        }
2434    }
2435
2436    response.send(Ack {})?;
2437
2438    Ok(())
2439}
2440
2441async fn unlink_channel(
2442    request: proto::UnlinkChannel,
2443    response: Response<proto::UnlinkChannel>,
2444    session: Session,
2445) -> Result<()> {
2446    let db = session.db().await;
2447    let channel_id = ChannelId::from_proto(request.channel_id);
2448    let from = ChannelId::from_proto(request.from);
2449
2450    db.unlink_channel(session.user_id, channel_id, from).await?;
2451
2452    let members = db.get_channel_members(from).await?;
2453
2454    let update = proto::UpdateChannels {
2455        delete_edge: vec![proto::ChannelEdge {
2456            channel_id: channel_id.to_proto(),
2457            parent_id: from.to_proto(),
2458        }],
2459        ..Default::default()
2460    };
2461    let connection_pool = session.connection_pool().await;
2462    for member_id in members {
2463        for connection_id in connection_pool.user_connection_ids(member_id) {
2464            session.peer.send(connection_id, update.clone())?;
2465        }
2466    }
2467
2468    response.send(Ack {})?;
2469
2470    Ok(())
2471}
2472
2473async fn move_channel(
2474    request: proto::MoveChannel,
2475    response: Response<proto::MoveChannel>,
2476    session: Session,
2477) -> Result<()> {
2478    let db = session.db().await;
2479    let channel_id = ChannelId::from_proto(request.channel_id);
2480    let from_parent = ChannelId::from_proto(request.from);
2481    let to = ChannelId::from_proto(request.to);
2482
2483    let channels_to_send = db
2484        .move_channel(session.user_id, channel_id, from_parent, to)
2485        .await?;
2486
2487    if channels_to_send.is_empty() {
2488        response.send(Ack {})?;
2489        return Ok(());
2490    }
2491
2492    let members_from = db.get_channel_members(from_parent).await?;
2493    let members_to = db.get_channel_members(to).await?;
2494
2495    let update = proto::UpdateChannels {
2496        delete_edge: vec![proto::ChannelEdge {
2497            channel_id: channel_id.to_proto(),
2498            parent_id: from_parent.to_proto(),
2499        }],
2500        ..Default::default()
2501    };
2502    let connection_pool = session.connection_pool().await;
2503    for member_id in members_from {
2504        for connection_id in connection_pool.user_connection_ids(member_id) {
2505            session.peer.send(connection_id, update.clone())?;
2506        }
2507    }
2508
2509    let update = proto::UpdateChannels {
2510        channels: channels_to_send
2511            .channels
2512            .into_iter()
2513            .map(|channel| proto::Channel {
2514                id: channel.id.to_proto(),
2515                name: channel.name,
2516            })
2517            .collect(),
2518        insert_edge: channels_to_send.edges,
2519        ..Default::default()
2520    };
2521    for member_id in members_to {
2522        for connection_id in connection_pool.user_connection_ids(member_id) {
2523            session.peer.send(connection_id, update.clone())?;
2524        }
2525    }
2526
2527    response.send(Ack {})?;
2528
2529    Ok(())
2530}
2531
2532async fn get_channel_members(
2533    request: proto::GetChannelMembers,
2534    response: Response<proto::GetChannelMembers>,
2535    session: Session,
2536) -> Result<()> {
2537    let db = session.db().await;
2538    let channel_id = ChannelId::from_proto(request.channel_id);
2539    let members = db
2540        .get_channel_member_details(channel_id, session.user_id)
2541        .await?;
2542    response.send(proto::GetChannelMembersResponse { members })?;
2543    Ok(())
2544}
2545
2546async fn respond_to_channel_invite(
2547    request: proto::RespondToChannelInvite,
2548    response: Response<proto::RespondToChannelInvite>,
2549    session: Session,
2550) -> Result<()> {
2551    let db = session.db().await;
2552    let channel_id = ChannelId::from_proto(request.channel_id);
2553    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2554        .await?;
2555
2556    let mut update = proto::UpdateChannels::default();
2557    update
2558        .remove_channel_invitations
2559        .push(channel_id.to_proto());
2560    if request.accept {
2561        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2562        update
2563            .channels
2564            .extend(
2565                result
2566                    .channels
2567                    .channels
2568                    .into_iter()
2569                    .map(|channel| proto::Channel {
2570                        id: channel.id.to_proto(),
2571                        name: channel.name,
2572                    }),
2573            );
2574        update.unseen_channel_messages = result.channel_messages;
2575        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2576        update.insert_edge = result.channels.edges;
2577        update
2578            .channel_participants
2579            .extend(
2580                result
2581                    .channel_participants
2582                    .into_iter()
2583                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2584                        channel_id: channel_id.to_proto(),
2585                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2586                    }),
2587            );
2588        update
2589            .channel_permissions
2590            .extend(
2591                result
2592                    .channels_with_admin_privileges
2593                    .into_iter()
2594                    .map(|channel_id| proto::ChannelPermission {
2595                        channel_id: channel_id.to_proto(),
2596                        is_admin: true,
2597                    }),
2598            );
2599    }
2600    session.peer.send(session.connection_id, update)?;
2601    response.send(proto::Ack {})?;
2602
2603    Ok(())
2604}
2605
/// Handles `JoinChannel`: takes the caller out of any room it is currently in, then joins the
/// room backing the channel (created on demand via `get_or_create_channel_room`) and replies with
/// a `JoinRoomResponse`. Afterwards it notifies room participants, channel members, and the
/// caller's contacts of the change.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Candidate LiveKit room name handed to `get_or_create_channel_room`; presumably only used
    // when the channel has no room yet — TODO confirm in the db layer.
    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));

    // This block bounds the lifetime of the `db` handle: it is dropped before the notifications
    // below, which take the connection pool and (inside `update_user_contacts`) call
    // `session.db()` again.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db
            .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
            .await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // LiveKit is optional: with no client configured, or if minting the token fails
        // (logged by `trace_err`), the response simply carries no connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // `join_room` appears to return a guard-like wrapper; `into_inner` extracts the value so
        // it can outlive this block — NOTE(review): confirm what the wrapper guards.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2668
2669async fn join_channel_buffer(
2670    request: proto::JoinChannelBuffer,
2671    response: Response<proto::JoinChannelBuffer>,
2672    session: Session,
2673) -> Result<()> {
2674    let db = session.db().await;
2675    let channel_id = ChannelId::from_proto(request.channel_id);
2676
2677    let open_response = db
2678        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2679        .await?;
2680
2681    let collaborators = open_response.collaborators.clone();
2682    response.send(open_response)?;
2683
2684    let update = UpdateChannelBufferCollaborators {
2685        channel_id: channel_id.to_proto(),
2686        collaborators: collaborators.clone(),
2687    };
2688    channel_buffer_updated(
2689        session.connection_id,
2690        collaborators
2691            .iter()
2692            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2693        &update,
2694        &session.peer,
2695    );
2696
2697    Ok(())
2698}
2699
2700async fn update_channel_buffer(
2701    request: proto::UpdateChannelBuffer,
2702    session: Session,
2703) -> Result<()> {
2704    let db = session.db().await;
2705    let channel_id = ChannelId::from_proto(request.channel_id);
2706
2707    let (collaborators, non_collaborators, epoch, version) = db
2708        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2709        .await?;
2710
2711    channel_buffer_updated(
2712        session.connection_id,
2713        collaborators,
2714        &proto::UpdateChannelBuffer {
2715            channel_id: channel_id.to_proto(),
2716            operations: request.operations,
2717        },
2718        &session.peer,
2719    );
2720
2721    let pool = &*session.connection_pool().await;
2722
2723    broadcast(
2724        None,
2725        non_collaborators
2726            .iter()
2727            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2728        |peer_id| {
2729            session.peer.send(
2730                peer_id.into(),
2731                proto::UpdateChannels {
2732                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2733                        channel_id: channel_id.to_proto(),
2734                        epoch: epoch as u64,
2735                        version: version.clone(),
2736                    }],
2737                    ..Default::default()
2738                },
2739            )
2740        },
2741    );
2742
2743    Ok(())
2744}
2745
2746async fn rejoin_channel_buffers(
2747    request: proto::RejoinChannelBuffers,
2748    response: Response<proto::RejoinChannelBuffers>,
2749    session: Session,
2750) -> Result<()> {
2751    let db = session.db().await;
2752    let buffers = db
2753        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2754        .await?;
2755
2756    for rejoined_buffer in &buffers {
2757        let collaborators_to_notify = rejoined_buffer
2758            .buffer
2759            .collaborators
2760            .iter()
2761            .filter_map(|c| Some(c.peer_id?.into()));
2762        channel_buffer_updated(
2763            session.connection_id,
2764            collaborators_to_notify,
2765            &proto::UpdateChannelBufferCollaborators {
2766                channel_id: rejoined_buffer.buffer.channel_id,
2767                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2768            },
2769            &session.peer,
2770        );
2771    }
2772
2773    response.send(proto::RejoinChannelBuffersResponse {
2774        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2775    })?;
2776
2777    Ok(())
2778}
2779
2780async fn leave_channel_buffer(
2781    request: proto::LeaveChannelBuffer,
2782    response: Response<proto::LeaveChannelBuffer>,
2783    session: Session,
2784) -> Result<()> {
2785    let db = session.db().await;
2786    let channel_id = ChannelId::from_proto(request.channel_id);
2787
2788    let left_buffer = db
2789        .leave_channel_buffer(channel_id, session.connection_id)
2790        .await?;
2791
2792    response.send(Ack {})?;
2793
2794    channel_buffer_updated(
2795        session.connection_id,
2796        left_buffer.connections,
2797        &proto::UpdateChannelBufferCollaborators {
2798            channel_id: channel_id.to_proto(),
2799            collaborators: left_buffer.collaborators,
2800        },
2801        &session.peer,
2802    );
2803
2804    Ok(())
2805}
2806
2807fn channel_buffer_updated<T: EnvelopedMessage>(
2808    sender_id: ConnectionId,
2809    collaborators: impl IntoIterator<Item = ConnectionId>,
2810    message: &T,
2811    peer: &Peer,
2812) {
2813    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2814        peer.send(peer_id.into(), message.clone())
2815    });
2816}
2817
2818async fn send_channel_message(
2819    request: proto::SendChannelMessage,
2820    response: Response<proto::SendChannelMessage>,
2821    session: Session,
2822) -> Result<()> {
2823    // Validate the message body.
2824    let body = request.body.trim().to_string();
2825    if body.len() > MAX_MESSAGE_LEN {
2826        return Err(anyhow!("message is too long"))?;
2827    }
2828    if body.is_empty() {
2829        return Err(anyhow!("message can't be blank"))?;
2830    }
2831
2832    let timestamp = OffsetDateTime::now_utc();
2833    let nonce = request
2834        .nonce
2835        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2836
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838    let (message_id, connection_ids, non_participants) = session
2839        .db()
2840        .await
2841        .create_channel_message(
2842            channel_id,
2843            session.user_id,
2844            &body,
2845            timestamp,
2846            nonce.clone().into(),
2847        )
2848        .await?;
2849    let message = proto::ChannelMessage {
2850        sender_id: session.user_id.to_proto(),
2851        id: message_id.to_proto(),
2852        body,
2853        timestamp: timestamp.unix_timestamp() as u64,
2854        nonce: Some(nonce),
2855    };
2856    broadcast(Some(session.connection_id), connection_ids, |connection| {
2857        session.peer.send(
2858            connection,
2859            proto::ChannelMessageSent {
2860                channel_id: channel_id.to_proto(),
2861                message: Some(message.clone()),
2862            },
2863        )
2864    });
2865    response.send(proto::SendChannelMessageResponse {
2866        message: Some(message),
2867    })?;
2868
2869    let pool = &*session.connection_pool().await;
2870    broadcast(
2871        None,
2872        non_participants
2873            .iter()
2874            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2875        |peer_id| {
2876            session.peer.send(
2877                peer_id.into(),
2878                proto::UpdateChannels {
2879                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2880                        channel_id: channel_id.to_proto(),
2881                        message_id: message_id.to_proto(),
2882                    }],
2883                    ..Default::default()
2884                },
2885            )
2886        },
2887    );
2888
2889    Ok(())
2890}
2891
2892async fn remove_channel_message(
2893    request: proto::RemoveChannelMessage,
2894    response: Response<proto::RemoveChannelMessage>,
2895    session: Session,
2896) -> Result<()> {
2897    let channel_id = ChannelId::from_proto(request.channel_id);
2898    let message_id = MessageId::from_proto(request.message_id);
2899    let connection_ids = session
2900        .db()
2901        .await
2902        .remove_channel_message(channel_id, message_id, session.user_id)
2903        .await?;
2904    broadcast(Some(session.connection_id), connection_ids, |connection| {
2905        session.peer.send(connection, request.clone())
2906    });
2907    response.send(proto::Ack {})?;
2908    Ok(())
2909}
2910
2911async fn acknowledge_channel_message(
2912    request: proto::AckChannelMessage,
2913    session: Session,
2914) -> Result<()> {
2915    let channel_id = ChannelId::from_proto(request.channel_id);
2916    let message_id = MessageId::from_proto(request.message_id);
2917    session
2918        .db()
2919        .await
2920        .observe_channel_message(channel_id, session.user_id, message_id)
2921        .await?;
2922    Ok(())
2923}
2924
2925async fn acknowledge_buffer_version(
2926    request: proto::AckBufferOperation,
2927    session: Session,
2928) -> Result<()> {
2929    let buffer_id = BufferId::from_proto(request.buffer_id);
2930    session
2931        .db()
2932        .await
2933        .observe_buffer_version(
2934            buffer_id,
2935            session.user_id,
2936            request.epoch as i32,
2937            &request.version,
2938        )
2939        .await?;
2940    Ok(())
2941}
2942
2943async fn join_channel_chat(
2944    request: proto::JoinChannelChat,
2945    response: Response<proto::JoinChannelChat>,
2946    session: Session,
2947) -> Result<()> {
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949
2950    let db = session.db().await;
2951    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2952        .await?;
2953    let messages = db
2954        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2955        .await?;
2956    response.send(proto::JoinChannelChatResponse {
2957        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2958        messages,
2959    })?;
2960    Ok(())
2961}
2962
2963async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    session
2966        .db()
2967        .await
2968        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2969        .await?;
2970    Ok(())
2971}
2972
2973async fn get_channel_messages(
2974    request: proto::GetChannelMessages,
2975    response: Response<proto::GetChannelMessages>,
2976    session: Session,
2977) -> Result<()> {
2978    let channel_id = ChannelId::from_proto(request.channel_id);
2979    let messages = session
2980        .db()
2981        .await
2982        .get_channel_messages(
2983            channel_id,
2984            session.user_id,
2985            MESSAGE_COUNT_PER_PAGE,
2986            Some(MessageId::from_proto(request.before_message_id)),
2987        )
2988        .await?;
2989    response.send(proto::GetChannelMessagesResponse {
2990        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2991        messages,
2992    })?;
2993    Ok(())
2994}
2995
2996async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2997    let project_id = ProjectId::from_proto(request.project_id);
2998    let project_connection_ids = session
2999        .db()
3000        .await
3001        .project_connection_ids(project_id, session.connection_id)
3002        .await?;
3003    broadcast(
3004        Some(session.connection_id),
3005        project_connection_ids.iter().copied(),
3006        |connection_id| {
3007            session
3008                .peer
3009                .forward_send(session.connection_id, connection_id, request.clone())
3010        },
3011    );
3012    Ok(())
3013}
3014
3015async fn get_private_user_info(
3016    _request: proto::GetPrivateUserInfo,
3017    response: Response<proto::GetPrivateUserInfo>,
3018    session: Session,
3019) -> Result<()> {
3020    let db = session.db().await;
3021
3022    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3023    let user = db
3024        .get_user_by_id(session.user_id)
3025        .await?
3026        .ok_or_else(|| anyhow!("user not found"))?;
3027    let flags = db.get_user_flags(session.user_id).await?;
3028
3029    response.send(proto::GetPrivateUserInfoResponse {
3030        metrics_id,
3031        staff: user.admin,
3032        flags,
3033    })?;
3034    Ok(())
3035}
3036
3037fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3038    match message {
3039        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3040        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3041        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3042        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3043        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3044            code: frame.code.into(),
3045            reason: frame.reason,
3046        })),
3047    }
3048}
3049
3050fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3051    match message {
3052        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3053        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3054        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3055        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3056        AxumMessage::Close(frame) => {
3057            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3058                code: frame.code.into(),
3059                reason: frame.reason,
3060            }))
3061        }
3062    }
3063}
3064
3065fn build_initial_channels_update(
3066    channels: ChannelsForUser,
3067    channel_invites: Vec<db::Channel>,
3068) -> proto::UpdateChannels {
3069    let mut update = proto::UpdateChannels::default();
3070
3071    for channel in channels.channels.channels {
3072        update.channels.push(proto::Channel {
3073            id: channel.id.to_proto(),
3074            name: channel.name,
3075        });
3076    }
3077
3078    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3079    update.unseen_channel_messages = channels.channel_messages;
3080    update.insert_edge = channels.channels.edges;
3081
3082    for (channel_id, participants) in channels.channel_participants {
3083        update
3084            .channel_participants
3085            .push(proto::ChannelParticipants {
3086                channel_id: channel_id.to_proto(),
3087                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3088            });
3089    }
3090
3091    update
3092        .channel_permissions
3093        .extend(
3094            channels
3095                .channels_with_admin_privileges
3096                .into_iter()
3097                .map(|id| proto::ChannelPermission {
3098                    channel_id: id.to_proto(),
3099                    is_admin: true,
3100                }),
3101        );
3102
3103    for channel in channel_invites {
3104        update.channel_invitations.push(proto::Channel {
3105            id: channel.id.to_proto(),
3106            name: channel.name,
3107        });
3108    }
3109
3110    update
3111}
3112
3113fn build_initial_contacts_update(
3114    contacts: Vec<db::Contact>,
3115    pool: &ConnectionPool,
3116) -> proto::UpdateContacts {
3117    let mut update = proto::UpdateContacts::default();
3118
3119    for contact in contacts {
3120        match contact {
3121            db::Contact::Accepted {
3122                user_id,
3123                should_notify,
3124                busy,
3125            } => {
3126                update
3127                    .contacts
3128                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3129            }
3130            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3131            db::Contact::Incoming {
3132                user_id,
3133                should_notify,
3134            } => update
3135                .incoming_requests
3136                .push(proto::IncomingContactRequest {
3137                    requester_id: user_id.to_proto(),
3138                    should_notify,
3139                }),
3140        }
3141    }
3142
3143    update
3144}
3145
3146fn contact_for_user(
3147    user_id: UserId,
3148    should_notify: bool,
3149    busy: bool,
3150    pool: &ConnectionPool,
3151) -> proto::Contact {
3152    proto::Contact {
3153        user_id: user_id.to_proto(),
3154        online: pool.is_user_online(user_id),
3155        busy,
3156        should_notify,
3157    }
3158}
3159
3160fn room_updated(room: &proto::Room, peer: &Peer) {
3161    broadcast(
3162        None,
3163        room.participants
3164            .iter()
3165            .filter_map(|participant| Some(participant.peer_id?.into())),
3166        |peer_id| {
3167            peer.send(
3168                peer_id.into(),
3169                proto::RoomUpdated {
3170                    room: Some(room.clone()),
3171                },
3172            )
3173        },
3174    );
3175}
3176
3177fn channel_updated(
3178    channel_id: ChannelId,
3179    room: &proto::Room,
3180    channel_members: &[UserId],
3181    peer: &Peer,
3182    pool: &ConnectionPool,
3183) {
3184    let participants = room
3185        .participants
3186        .iter()
3187        .map(|p| p.user_id)
3188        .collect::<Vec<_>>();
3189
3190    broadcast(
3191        None,
3192        channel_members
3193            .iter()
3194            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3195        |peer_id| {
3196            peer.send(
3197                peer_id.into(),
3198                proto::UpdateChannels {
3199                    channel_participants: vec![proto::ChannelParticipants {
3200                        channel_id: channel_id.to_proto(),
3201                        participant_user_ids: participants.clone(),
3202                    }],
3203                    ..Default::default()
3204                },
3205            )
3206        },
3207    );
3208}
3209
3210async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3211    let db = session.db().await;
3212
3213    let contacts = db.get_contacts(user_id).await?;
3214    let busy = db.is_user_busy(user_id).await?;
3215
3216    let pool = session.connection_pool().await;
3217    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3218    for contact in contacts {
3219        if let db::Contact::Accepted {
3220            user_id: contact_user_id,
3221            ..
3222        } = contact
3223        {
3224            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3225                session
3226                    .peer
3227                    .send(
3228                        contact_conn_id,
3229                        proto::UpdateContacts {
3230                            contacts: vec![updated_contact.clone()],
3231                            remove_contacts: Default::default(),
3232                            incoming_requests: Default::default(),
3233                            remove_incoming_requests: Default::default(),
3234                            outgoing_requests: Default::default(),
3235                            remove_outgoing_requests: Default::default(),
3236                        },
3237                    )
3238                    .trace_err();
3239            }
3240        }
3241    }
3242    Ok(())
3243}
3244
/// Removes the session's connection from the room it is in; returns `Ok(())` immediately if the
/// database reports no room. Otherwise it:
/// - notifies projects the session left (`project_left`), the remaining room participants, and,
///   for channel rooms, the channel members;
/// - sends `CallCanceled` to every user in `canceled_calls_to_user_ids`;
/// - refreshes the contact entries of the leaving user and of each of those users;
/// - removes the user from the LiveKit room, and deletes that room when `left_room.deleted` is set.
///
/// Failures sending individual messages or talking to LiveKit are logged via `trace_err` and
/// ignored. Database and `update_user_contacts` errors are propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and initialized inside the `if let` below so the values outlive it;
    // the `else` branch returns early, so every path that continues has initialized them all.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before `room` itself is
        // taken, so the room broadcast by `room_updated` below carries an empty `live_kit_room`.
        // Confirm receivers do not rely on that field.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is dropped before `update_user_contacts`, which
    // acquires the database and connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged and do not fail the leave.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3321
3322async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3323    let left_channel_buffers = session
3324        .db()
3325        .await
3326        .leave_channel_buffers(session.connection_id)
3327        .await?;
3328
3329    for left_buffer in left_channel_buffers {
3330        channel_buffer_updated(
3331            session.connection_id,
3332            left_buffer.connections,
3333            &proto::UpdateChannelBufferCollaborators {
3334                channel_id: left_buffer.channel_id.to_proto(),
3335                collaborators: left_buffer.collaborators,
3336            },
3337            &session.peer,
3338        );
3339    }
3340
3341    Ok(())
3342}
3343
3344fn project_left(project: &db::LeftProject, session: &Session) {
3345    for connection_id in &project.connection_ids {
3346        if project.host_user_id == session.user_id {
3347            session
3348                .peer
3349                .send(
3350                    *connection_id,
3351                    proto::UnshareProject {
3352                        project_id: project.id.to_proto(),
3353                    },
3354                )
3355                .trace_err();
3356        } else {
3357            session
3358                .peer
3359                .send(
3360                    *connection_id,
3361                    proto::RemoveProjectCollaborator {
3362                        project_id: project.id.to_proto(),
3363                        peer_id: Some(session.connection_id.into()),
3364                    },
3365                )
3366                .trace_err();
3367        }
3368    }
3369}
3370
/// Extension for best-effort code paths that want to report a failure and carry on.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Returns the success value as `Some`. On failure, implementations report the error (the
    /// `Result` implementation in this file logs it with `tracing::error!`) and return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3376
3377impl<T, E> ResultExt for Result<T, E>
3378where
3379    E: std::fmt::Debug,
3380{
3381    type Ok = T;
3382
3383    fn trace_err(self) -> Option<T> {
3384        match self {
3385            Ok(value) => Some(value),
3386            Err(error) => {
3387                tracing::error!("{:?}", error);
3388                None
3389            }
3390        }
3391    }
3392}