rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
   7        ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73
  74lazy_static! {
  75    static ref METRIC_CONNECTIONS: IntGauge =
  76        register_int_gauge!("connections", "number of connections").unwrap();
  77    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  78        "shared_projects",
  79        "number of open projects with one or more guests"
  80    )
  81    .unwrap();
  82}
  83
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
 101#[derive(Clone)]
 102struct Session {
 103    user_id: UserId,
 104    connection_id: ConnectionId,
 105    db: Arc<tokio::sync::Mutex<DbHandle>>,
 106    peer: Arc<Peer>,
 107    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 108    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 109    executor: Executor,
 110}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
/// The collaboration RPC server.
///
/// Owns the peer used to talk to clients, the pool of live connections, and the table
/// of per-message-type handlers populated in [`Server::new`].
pub struct Server {
    /// Current server id. Kept behind a lock because the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; shared with every `Session` and the startup cleanup task.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`. Each `handle_connection` loop subscribes and returns
    /// when it changes.
    teardown: watch::Sender<()>,
}
 161
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Any future holding it
/// across an `.await` is then `!Send` too, which keeps the synchronous `parking_lot`
/// lock from leaking into sendable async code.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 166
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it is alive. The pool
/// field is serialized via [`serialize_deref`], i.e. as the value the guard
/// dereferences to rather than the guard itself.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
    /// Creates a server identified by `id` and registers every RPC handler it serves.
    ///
    /// `add_handler` panics if the same message type is registered twice, so each
    /// message type may appear at most once in the chain below. Request handlers
    /// must reply through their `Response`; message handlers get no reply handle.
    ///
    /// NOTE(review): the peer is created with `id.0 as u32`. If `ServerId`'s inner
    /// integer is wider than `u32`, the cast silently truncates — confirm ids stay in
    /// range.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Liveness, rooms, and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing, worktrees, and language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests routed through the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 284
    /// Spawns the post-startup cleanup task and returns immediately with `Ok(())`.
    ///
    /// The detached task:
    /// 1. Sleeps for [`CLEANUP_TIMEOUT`].
    /// 2. Fetches the room and channel ids that `stale_server_resource_ids` reports for
    ///    this environment and `server_id`. These are presumably resources still owned
    ///    by server instances that are no longer running — confirm against the db layer.
    /// 3. For each one, clears stale collaborators/participants and notifies the
    ///    affected clients.
    /// 4. Calls `delete_stale_servers`.
    ///
    /// Database and send failures inside the task are turned into `Option`s via
    /// `trace_err` instead of aborting the task. They are never reported to the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators, then send the refreshed
                    // collaborator list to every remaining connection of that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults make every step below a no-op if refreshing the
                        // room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                // The pool guard is a temporary dropped at the end of this
                                // statement, so the lock is released before the next `.await`.
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were canceled
                            // need their contacts told about the change.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Synchronous scope: the parking_lot guard must not live across an
                        // `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries are awaited before the pool is locked below.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                // Only accepted contacts receive the update.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room has no
                        // participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 436
 437    pub fn teardown(&self) {
 438        self.peer.teardown();
 439        self.connection_pool.lock().reset();
 440        let _ = self.teardown.send(());
 441    }
 442
 443    #[cfg(test)]
 444    pub fn reset(&self, id: ServerId) {
 445        self.teardown();
 446        *self.id.lock() = id;
 447        self.peer.reset(id.0 as u32);
 448    }
 449
 450    #[cfg(test)]
 451    pub fn id(&self) -> ServerId {
 452        *self.id.lock()
 453    }
 454
 455    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 458        Fut: 'static + Send + Future<Output = Result<()>>,
 459        M: EnvelopedMessage,
 460    {
 461        let prev_handler = self.handlers.insert(
 462            TypeId::of::<M>(),
 463            Box::new(move |envelope, session| {
 464                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 465                let span = info_span!(
 466                    "handle message",
 467                    payload_type = envelope.payload_type_name()
 468                );
 469                span.in_scope(|| {
 470                    tracing::info!(
 471                        payload_type = envelope.payload_type_name(),
 472                        "message received"
 473                    );
 474                });
 475                let start_time = Instant::now();
 476                let future = (handler)(*envelope, session);
 477                async move {
 478                    let result = future.await;
 479                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 480                    match result {
 481                        Err(error) => {
 482                            tracing::error!(%error, ?duration_ms, "error handling message")
 483                        }
 484                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 485                    }
 486                }
 487                .instrument(span)
 488                .boxed()
 489            }),
 490        );
 491        if prev_handler.is_some() {
 492            panic!("registered a handler for the same message twice");
 493        }
 494        self
 495    }
 496
 497    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 498    where
 499        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 500        Fut: 'static + Send + Future<Output = Result<()>>,
 501        M: EnvelopedMessage,
 502    {
 503        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 504        self
 505    }
 506
 507    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 510        Fut: Send + Future<Output = Result<()>>,
 511        M: RequestMessage,
 512    {
 513        let handler = Arc::new(handler);
 514        self.add_handler(move |envelope, session| {
 515            let receipt = envelope.receipt();
 516            let handler = handler.clone();
 517            async move {
 518                let peer = session.peer.clone();
 519                let responded = Arc::new(AtomicBool::default());
 520                let response = Response {
 521                    peer: peer.clone(),
 522                    responded: responded.clone(),
 523                    receipt,
 524                };
 525                match (handler)(envelope.payload, response, session).await {
 526                    Ok(()) => {
 527                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 528                            Ok(())
 529                        } else {
 530                            Err(anyhow!("handler did not send a response"))?
 531                        }
 532                    }
 533                    Err(error) => {
 534                        peer.respond_with_error(
 535                            receipt,
 536                            proto::Error {
 537                                message: error.to_string(),
 538                            },
 539                        )?;
 540                        Err(error)
 541                    }
 542                }
 543            }
 544        })
 545    }
 546
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Setup steps:
    /// - Registers `connection` with the peer and sends `Hello`.
    /// - Sends `ShowContacts` on the user's first-ever connection.
    /// - Adds the connection to the pool and pushes the initial contacts and channels.
    /// - Sends any pending incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers. This continues
    /// until the client closes the stream, I/O fails, or the server is torn down. On
    /// close or I/O failure it awaits `connection_lost`. On teardown it returns `Ok(())`
    /// immediately without calling it.
    ///
    /// `send_connection_id`, if provided, receives the assigned `ConnectionId` right after
    /// `Hello` is sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future runs, so a later `teardown()` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetched concurrently, before the pool lock is taken below.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // NOTE(review): once `pool.add_connection` runs, any `?` early return below
            // (send failures, `incoming_call_for_user`, `update_user_contacts`) exits
            // before `connection_lost` at the end of this function. Nothing in this
            // function then removes `connection_id` from the pool or the peer. Confirm
            // that cleanup happens elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // Each connection gets its own DB mutex, which every handler clone of this
            // session shares.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective requests completed, neither would get a chance to
            // respond to the other's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing later messages in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection may hold a permit. A permit is taken
            // before the next message is read and released when its handler finishes, so
            // reading pauses while the limit is reached.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: teardown, then connection I/O, then progress on
                // pending foreground handlers, then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit travels with the handler and is released only
                                // after it completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers that are still pending.
            // Background handlers were handed to `executor.spawn_detached` and are not
            // affected by this drop.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 681
 682    pub async fn invite_code_redeemed(
 683        self: &Arc<Self>,
 684        inviter_id: UserId,
 685        invitee_id: UserId,
 686    ) -> Result<()> {
 687        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 688            if let Some(code) = &user.invite_code {
 689                let pool = self.connection_pool.lock();
 690                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 691                for connection_id in pool.user_connection_ids(inviter_id) {
 692                    self.peer.send(
 693                        connection_id,
 694                        proto::UpdateContacts {
 695                            contacts: vec![invitee_contact.clone()],
 696                            ..Default::default()
 697                        },
 698                    )?;
 699                    self.peer.send(
 700                        connection_id,
 701                        proto::UpdateInviteInfo {
 702                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 703                            count: user.invite_count as u32,
 704                        },
 705                    )?;
 706                }
 707            }
 708        }
 709        Ok(())
 710    }
 711
 712    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 713        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 714            if let Some(invite_code) = &user.invite_code {
 715                let pool = self.connection_pool.lock();
 716                for connection_id in pool.user_connection_ids(user_id) {
 717                    self.peer.send(
 718                        connection_id,
 719                        proto::UpdateInviteInfo {
 720                            url: format!(
 721                                "{}{}",
 722                                self.app_state.config.invite_link_prefix, invite_code
 723                            ),
 724                            count: user.invite_count as u32,
 725                        },
 726                    )?;
 727                }
 728            }
 729        }
 730        Ok(())
 731    }
 732
 733    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 734        ServerSnapshot {
 735            connection_pool: ConnectionPoolGuard {
 736                guard: self.connection_pool.lock(),
 737                _not_send: PhantomData,
 738            },
 739            peer: &self.peer,
 740        }
 741    }
 742}
 743
 744impl<'a> Deref for ConnectionPoolGuard<'a> {
 745    type Target = ConnectionPool;
 746
 747    fn deref(&self) -> &Self::Target {
 748        &*self.guard
 749    }
 750}
 751
 752impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 753    fn deref_mut(&mut self) -> &mut Self::Target {
 754        &mut *self.guard
 755    }
 756}
 757
 758impl<'a> Drop for ConnectionPoolGuard<'a> {
 759    fn drop(&mut self) {
 760        #[cfg(test)]
 761        self.check_invariants();
 762    }
 763}
 764
 765fn broadcast<F>(
 766    sender_id: Option<ConnectionId>,
 767    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 768    mut f: F,
 769) where
 770    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 771{
 772    for receiver_id in receiver_ids {
 773        if Some(receiver_id) != sender_id {
 774            if let Err(error) = f(receiver_id) {
 775                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 776            }
 777        }
 778    }
 779}
 780
lazy_static! {
    /// Name of the request header in which clients report their RPC protocol
    /// version; returned by `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 784
 785pub struct ProtocolVersion(u32);
 786
 787impl Header for ProtocolVersion {
 788    fn name() -> &'static HeaderName {
 789        &ZED_PROTOCOL_VERSION
 790    }
 791
 792    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 793    where
 794        Self: Sized,
 795        I: Iterator<Item = &'i axum::http::HeaderValue>,
 796    {
 797        let version = values
 798            .next()
 799            .ok_or_else(axum::headers::Error::invalid)?
 800            .to_str()
 801            .map_err(|_| axum::headers::Error::invalid())?
 802            .parse()
 803            .map_err(|_| axum::headers::Error::invalid())?;
 804        Ok(Self(version))
 805    }
 806
 807    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 808        values.extend([self.0.to_string().parse().unwrap()]);
 809    }
 810}
 811
 812pub fn routes(server: Arc<Server>) -> Router<Body> {
 813    Router::new()
 814        .route("/rpc", get(handle_websocket_request))
 815        .layer(
 816            ServiceBuilder::new()
 817                .layer(Extension(server.app_state.clone()))
 818                .layer(middleware::from_fn(auth::validate_header)),
 819        )
 820        .route("/metrics", get(handle_metrics))
 821        .layer(Extension(server))
 822}
 823
 824pub async fn handle_websocket_request(
 825    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 826    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 827    Extension(server): Extension<Arc<Server>>,
 828    Extension(user): Extension<User>,
 829    ws: WebSocketUpgrade,
 830) -> axum::response::Response {
 831    if protocol_version != rpc::PROTOCOL_VERSION {
 832        return (
 833            StatusCode::UPGRADE_REQUIRED,
 834            "client must be upgraded".to_string(),
 835        )
 836            .into_response();
 837    }
 838    let socket_address = socket_address.to_string();
 839    ws.on_upgrade(move |socket| {
 840        use util::ResultExt;
 841        let socket = socket
 842            .map_ok(to_tungstenite_message)
 843            .err_into()
 844            .with(|message| async move { Ok(to_axum_message(message)) });
 845        let connection = Connection::new(Box::pin(socket));
 846        async move {
 847            server
 848                .handle_connection(connection, socket_address, user, None, Executor::Production)
 849                .await
 850                .log_err();
 851        }
 852    })
 853}
 854
 855pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 856    let connections = server
 857        .connection_pool
 858        .lock()
 859        .connections()
 860        .filter(|connection| !connection.admin)
 861        .count();
 862
 863    METRIC_CONNECTIONS.set(connections as _);
 864
 865    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 866    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 867
 868    let encoder = prometheus::TextEncoder::new();
 869    let metric_families = prometheus::gather();
 870    let encoded_metrics = encoder
 871        .encode_to_string(&metric_families)
 872        .map_err(|err| anyhow!("{}", err))?;
 873    Ok(encoded_metrics)
 874}
 875
/// Cleans up after a client connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool and
/// records the loss in the database (a database failure there is only traced).
/// It then waits for whichever happens first:
///
/// - `RECONNECT_TIMEOUT` elapses on `executor`. The user's room and channel
///   buffers are left; if the user has no other online connection, their call
///   is declined via `decline_call(None, ..)`; then `update_user_contacts` runs.
/// - `teardown` changes. Nothing further is done (presumably a server-wide
///   shutdown signal — TODO confirm at the sender).
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails or if
/// `update_user_contacts` fails. Other cleanup steps only trace their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch is checked before `teardown`. The timeout presumably gives the
    // client a window to reconnect before its resources are released.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // NOTE(review): this uses `log::info!` while the rest of this file logs
            // through `tracing`; consider `tracing::info!` for consistency.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the user's call when no other connection of theirs
            // is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 921
 922async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 923    response.send(proto::Ack {})?;
 924    Ok(())
 925}
 926
 927async fn create_room(
 928    _request: proto::CreateRoom,
 929    response: Response<proto::CreateRoom>,
 930    session: Session,
 931) -> Result<()> {
 932    let live_kit_room = nanoid::nanoid!(30);
 933
 934    let live_kit_connection_info = {
 935        let live_kit_room = live_kit_room.clone();
 936        let live_kit = session.live_kit_client.as_ref();
 937
 938        util::async_iife!({
 939            let live_kit = live_kit?;
 940
 941            live_kit
 942                .create_room(live_kit_room.clone())
 943                .await
 944                .trace_err()?;
 945
 946            let token = live_kit
 947                .room_token(&live_kit_room, &session.user_id.to_string())
 948                .trace_err()?;
 949
 950            Some(proto::LiveKitConnectionInfo {
 951                server_url: live_kit.url().into(),
 952                token,
 953            })
 954        })
 955    }
 956    .await;
 957
 958    let room = session
 959        .db()
 960        .await
 961        .create_room(
 962            session.user_id,
 963            session.connection_id,
 964            &live_kit_room,
 965            RELEASE_CHANNEL_NAME.as_str(),
 966        )
 967        .await?;
 968
 969    response.send(proto::CreateRoomResponse {
 970        room: Some(room.clone()),
 971        live_kit_connection_info,
 972    })?;
 973
 974    update_user_contacts(session.user_id, &session).await?;
 975    Ok(())
 976}
 977
 978async fn join_room(
 979    request: proto::JoinRoom,
 980    response: Response<proto::JoinRoom>,
 981    session: Session,
 982) -> Result<()> {
 983    let room_id = RoomId::from_proto(request.id);
 984    let joined_room = {
 985        let room = session
 986            .db()
 987            .await
 988            .join_room(
 989                room_id,
 990                session.user_id,
 991                session.connection_id,
 992                RELEASE_CHANNEL_NAME.as_str(),
 993            )
 994            .await?;
 995        room_updated(&room.room, &session.peer);
 996        room.into_inner()
 997    };
 998
 999    if let Some(channel_id) = joined_room.channel_id {
1000        channel_updated(
1001            channel_id,
1002            &joined_room.room,
1003            &joined_room.channel_members,
1004            &session.peer,
1005            &*session.connection_pool().await,
1006        )
1007    }
1008
1009    for connection_id in session
1010        .connection_pool()
1011        .await
1012        .user_connection_ids(session.user_id)
1013    {
1014        session
1015            .peer
1016            .send(
1017                connection_id,
1018                proto::CallCanceled {
1019                    room_id: room_id.to_proto(),
1020                },
1021            )
1022            .trace_err();
1023    }
1024
1025    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1026        if let Some(token) = live_kit
1027            .room_token(
1028                &joined_room.room.live_kit_room,
1029                &session.user_id.to_string(),
1030            )
1031            .trace_err()
1032        {
1033            Some(proto::LiveKitConnectionInfo {
1034                server_url: live_kit.url().into(),
1035                token,
1036            })
1037        } else {
1038            None
1039        }
1040    } else {
1041        None
1042    };
1043
1044    response.send(proto::JoinRoomResponse {
1045        room: Some(joined_room.room),
1046        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1047        live_kit_connection_info,
1048    })?;
1049
1050    update_user_contacts(session.user_id, &session).await?;
1051    Ok(())
1052}
1053
1054async fn rejoin_room(
1055    request: proto::RejoinRoom,
1056    response: Response<proto::RejoinRoom>,
1057    session: Session,
1058) -> Result<()> {
1059    let room;
1060    let channel_id;
1061    let channel_members;
1062    {
1063        let mut rejoined_room = session
1064            .db()
1065            .await
1066            .rejoin_room(request, session.user_id, session.connection_id)
1067            .await?;
1068
1069        response.send(proto::RejoinRoomResponse {
1070            room: Some(rejoined_room.room.clone()),
1071            reshared_projects: rejoined_room
1072                .reshared_projects
1073                .iter()
1074                .map(|project| proto::ResharedProject {
1075                    id: project.id.to_proto(),
1076                    collaborators: project
1077                        .collaborators
1078                        .iter()
1079                        .map(|collaborator| collaborator.to_proto())
1080                        .collect(),
1081                })
1082                .collect(),
1083            rejoined_projects: rejoined_room
1084                .rejoined_projects
1085                .iter()
1086                .map(|rejoined_project| proto::RejoinedProject {
1087                    id: rejoined_project.id.to_proto(),
1088                    worktrees: rejoined_project
1089                        .worktrees
1090                        .iter()
1091                        .map(|worktree| proto::WorktreeMetadata {
1092                            id: worktree.id,
1093                            root_name: worktree.root_name.clone(),
1094                            visible: worktree.visible,
1095                            abs_path: worktree.abs_path.clone(),
1096                        })
1097                        .collect(),
1098                    collaborators: rejoined_project
1099                        .collaborators
1100                        .iter()
1101                        .map(|collaborator| collaborator.to_proto())
1102                        .collect(),
1103                    language_servers: rejoined_project.language_servers.clone(),
1104                })
1105                .collect(),
1106        })?;
1107        room_updated(&rejoined_room.room, &session.peer);
1108
1109        for project in &rejoined_room.reshared_projects {
1110            for collaborator in &project.collaborators {
1111                session
1112                    .peer
1113                    .send(
1114                        collaborator.connection_id,
1115                        proto::UpdateProjectCollaborator {
1116                            project_id: project.id.to_proto(),
1117                            old_peer_id: Some(project.old_connection_id.into()),
1118                            new_peer_id: Some(session.connection_id.into()),
1119                        },
1120                    )
1121                    .trace_err();
1122            }
1123
1124            broadcast(
1125                Some(session.connection_id),
1126                project
1127                    .collaborators
1128                    .iter()
1129                    .map(|collaborator| collaborator.connection_id),
1130                |connection_id| {
1131                    session.peer.forward_send(
1132                        session.connection_id,
1133                        connection_id,
1134                        proto::UpdateProject {
1135                            project_id: project.id.to_proto(),
1136                            worktrees: project.worktrees.clone(),
1137                        },
1138                    )
1139                },
1140            );
1141        }
1142
1143        for project in &rejoined_room.rejoined_projects {
1144            for collaborator in &project.collaborators {
1145                session
1146                    .peer
1147                    .send(
1148                        collaborator.connection_id,
1149                        proto::UpdateProjectCollaborator {
1150                            project_id: project.id.to_proto(),
1151                            old_peer_id: Some(project.old_connection_id.into()),
1152                            new_peer_id: Some(session.connection_id.into()),
1153                        },
1154                    )
1155                    .trace_err();
1156            }
1157        }
1158
1159        for project in &mut rejoined_room.rejoined_projects {
1160            for worktree in mem::take(&mut project.worktrees) {
1161                #[cfg(any(test, feature = "test-support"))]
1162                const MAX_CHUNK_SIZE: usize = 2;
1163                #[cfg(not(any(test, feature = "test-support")))]
1164                const MAX_CHUNK_SIZE: usize = 256;
1165
1166                // Stream this worktree's entries.
1167                let message = proto::UpdateWorktree {
1168                    project_id: project.id.to_proto(),
1169                    worktree_id: worktree.id,
1170                    abs_path: worktree.abs_path.clone(),
1171                    root_name: worktree.root_name,
1172                    updated_entries: worktree.updated_entries,
1173                    removed_entries: worktree.removed_entries,
1174                    scan_id: worktree.scan_id,
1175                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1176                    updated_repositories: worktree.updated_repositories,
1177                    removed_repositories: worktree.removed_repositories,
1178                };
1179                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1180                    session.peer.send(session.connection_id, update.clone())?;
1181                }
1182
1183                // Stream this worktree's diagnostics.
1184                for summary in worktree.diagnostic_summaries {
1185                    session.peer.send(
1186                        session.connection_id,
1187                        proto::UpdateDiagnosticSummary {
1188                            project_id: project.id.to_proto(),
1189                            worktree_id: worktree.id,
1190                            summary: Some(summary),
1191                        },
1192                    )?;
1193                }
1194
1195                for settings_file in worktree.settings_files {
1196                    session.peer.send(
1197                        session.connection_id,
1198                        proto::UpdateWorktreeSettings {
1199                            project_id: project.id.to_proto(),
1200                            worktree_id: worktree.id,
1201                            path: settings_file.path,
1202                            content: Some(settings_file.content),
1203                        },
1204                    )?;
1205                }
1206            }
1207
1208            for language_server in &project.language_servers {
1209                session.peer.send(
1210                    session.connection_id,
1211                    proto::UpdateLanguageServer {
1212                        project_id: project.id.to_proto(),
1213                        language_server_id: language_server.id,
1214                        variant: Some(
1215                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1216                                proto::LspDiskBasedDiagnosticsUpdated {},
1217                            ),
1218                        ),
1219                    },
1220                )?;
1221            }
1222        }
1223
1224        let rejoined_room = rejoined_room.into_inner();
1225
1226        room = rejoined_room.room;
1227        channel_id = rejoined_room.channel_id;
1228        channel_members = rejoined_room.channel_members;
1229    }
1230
1231    if let Some(channel_id) = channel_id {
1232        channel_updated(
1233            channel_id,
1234            &room,
1235            &channel_members,
1236            &session.peer,
1237            &*session.connection_pool().await,
1238        );
1239    }
1240
1241    update_user_contacts(session.user_id, &session).await?;
1242    Ok(())
1243}
1244
1245async fn leave_room(
1246    _: proto::LeaveRoom,
1247    response: Response<proto::LeaveRoom>,
1248    session: Session,
1249) -> Result<()> {
1250    leave_room_for_session(&session).await?;
1251    response.send(proto::Ack {})?;
1252    Ok(())
1253}
1254
1255async fn call(
1256    request: proto::Call,
1257    response: Response<proto::Call>,
1258    session: Session,
1259) -> Result<()> {
1260    let room_id = RoomId::from_proto(request.room_id);
1261    let calling_user_id = session.user_id;
1262    let calling_connection_id = session.connection_id;
1263    let called_user_id = UserId::from_proto(request.called_user_id);
1264    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1265    if !session
1266        .db()
1267        .await
1268        .has_contact(calling_user_id, called_user_id)
1269        .await?
1270    {
1271        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1272    }
1273
1274    let incoming_call = {
1275        let (room, incoming_call) = &mut *session
1276            .db()
1277            .await
1278            .call(
1279                room_id,
1280                calling_user_id,
1281                calling_connection_id,
1282                called_user_id,
1283                initial_project_id,
1284            )
1285            .await?;
1286        room_updated(&room, &session.peer);
1287        mem::take(incoming_call)
1288    };
1289    update_user_contacts(called_user_id, &session).await?;
1290
1291    let mut calls = session
1292        .connection_pool()
1293        .await
1294        .user_connection_ids(called_user_id)
1295        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1296        .collect::<FuturesUnordered<_>>();
1297
1298    while let Some(call_response) = calls.next().await {
1299        match call_response.as_ref() {
1300            Ok(_) => {
1301                response.send(proto::Ack {})?;
1302                return Ok(());
1303            }
1304            Err(_) => {
1305                call_response.trace_err();
1306            }
1307        }
1308    }
1309
1310    {
1311        let room = session
1312            .db()
1313            .await
1314            .call_failed(room_id, called_user_id)
1315            .await?;
1316        room_updated(&room, &session.peer);
1317    }
1318    update_user_contacts(called_user_id, &session).await?;
1319
1320    Err(anyhow!("failed to ring user"))?
1321}
1322
1323async fn cancel_call(
1324    request: proto::CancelCall,
1325    response: Response<proto::CancelCall>,
1326    session: Session,
1327) -> Result<()> {
1328    let called_user_id = UserId::from_proto(request.called_user_id);
1329    let room_id = RoomId::from_proto(request.room_id);
1330    {
1331        let room = session
1332            .db()
1333            .await
1334            .cancel_call(room_id, session.connection_id, called_user_id)
1335            .await?;
1336        room_updated(&room, &session.peer);
1337    }
1338
1339    for connection_id in session
1340        .connection_pool()
1341        .await
1342        .user_connection_ids(called_user_id)
1343    {
1344        session
1345            .peer
1346            .send(
1347                connection_id,
1348                proto::CallCanceled {
1349                    room_id: room_id.to_proto(),
1350                },
1351            )
1352            .trace_err();
1353    }
1354    response.send(proto::Ack {})?;
1355
1356    update_user_contacts(called_user_id, &session).await?;
1357    Ok(())
1358}
1359
1360async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1361    let room_id = RoomId::from_proto(message.room_id);
1362    {
1363        let room = session
1364            .db()
1365            .await
1366            .decline_call(Some(room_id), session.user_id)
1367            .await?
1368            .ok_or_else(|| anyhow!("failed to decline call"))?;
1369        room_updated(&room, &session.peer);
1370    }
1371
1372    for connection_id in session
1373        .connection_pool()
1374        .await
1375        .user_connection_ids(session.user_id)
1376    {
1377        session
1378            .peer
1379            .send(
1380                connection_id,
1381                proto::CallCanceled {
1382                    room_id: room_id.to_proto(),
1383                },
1384            )
1385            .trace_err();
1386    }
1387    update_user_contacts(session.user_id, &session).await?;
1388    Ok(())
1389}
1390
1391async fn update_participant_location(
1392    request: proto::UpdateParticipantLocation,
1393    response: Response<proto::UpdateParticipantLocation>,
1394    session: Session,
1395) -> Result<()> {
1396    let room_id = RoomId::from_proto(request.room_id);
1397    let location = request
1398        .location
1399        .ok_or_else(|| anyhow!("invalid location"))?;
1400
1401    let db = session.db().await;
1402    let room = db
1403        .update_room_participant_location(room_id, session.connection_id, location)
1404        .await?;
1405
1406    room_updated(&room, &session.peer);
1407    response.send(proto::Ack {})?;
1408    Ok(())
1409}
1410
1411async fn share_project(
1412    request: proto::ShareProject,
1413    response: Response<proto::ShareProject>,
1414    session: Session,
1415) -> Result<()> {
1416    let (project_id, room) = &*session
1417        .db()
1418        .await
1419        .share_project(
1420            RoomId::from_proto(request.room_id),
1421            session.connection_id,
1422            &request.worktrees,
1423        )
1424        .await?;
1425    response.send(proto::ShareProjectResponse {
1426        project_id: project_id.to_proto(),
1427    })?;
1428    room_updated(&room, &session.peer);
1429
1430    Ok(())
1431}
1432
1433async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1434    let project_id = ProjectId::from_proto(message.project_id);
1435
1436    let (room, guest_connection_ids) = &*session
1437        .db()
1438        .await
1439        .unshare_project(project_id, session.connection_id)
1440        .await?;
1441
1442    broadcast(
1443        Some(session.connection_id),
1444        guest_connection_ids.iter().copied(),
1445        |conn_id| session.peer.send(conn_id, message.clone()),
1446    );
1447    room_updated(&room, &session.peer);
1448
1449    Ok(())
1450}
1451
1452async fn join_project(
1453    request: proto::JoinProject,
1454    response: Response<proto::JoinProject>,
1455    session: Session,
1456) -> Result<()> {
1457    let project_id = ProjectId::from_proto(request.project_id);
1458    let guest_user_id = session.user_id;
1459
1460    tracing::info!(%project_id, "join project");
1461
1462    let (project, replica_id) = &mut *session
1463        .db()
1464        .await
1465        .join_project(project_id, session.connection_id)
1466        .await?;
1467
1468    let collaborators = project
1469        .collaborators
1470        .iter()
1471        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1472        .map(|collaborator| collaborator.to_proto())
1473        .collect::<Vec<_>>();
1474
1475    let worktrees = project
1476        .worktrees
1477        .iter()
1478        .map(|(id, worktree)| proto::WorktreeMetadata {
1479            id: *id,
1480            root_name: worktree.root_name.clone(),
1481            visible: worktree.visible,
1482            abs_path: worktree.abs_path.clone(),
1483        })
1484        .collect::<Vec<_>>();
1485
1486    for collaborator in &collaborators {
1487        session
1488            .peer
1489            .send(
1490                collaborator.peer_id.unwrap().into(),
1491                proto::AddProjectCollaborator {
1492                    project_id: project_id.to_proto(),
1493                    collaborator: Some(proto::Collaborator {
1494                        peer_id: Some(session.connection_id.into()),
1495                        replica_id: replica_id.0 as u32,
1496                        user_id: guest_user_id.to_proto(),
1497                    }),
1498                },
1499            )
1500            .trace_err();
1501    }
1502
1503    // First, we send the metadata associated with each worktree.
1504    response.send(proto::JoinProjectResponse {
1505        worktrees: worktrees.clone(),
1506        replica_id: replica_id.0 as u32,
1507        collaborators: collaborators.clone(),
1508        language_servers: project.language_servers.clone(),
1509    })?;
1510
1511    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1512        #[cfg(any(test, feature = "test-support"))]
1513        const MAX_CHUNK_SIZE: usize = 2;
1514        #[cfg(not(any(test, feature = "test-support")))]
1515        const MAX_CHUNK_SIZE: usize = 256;
1516
1517        // Stream this worktree's entries.
1518        let message = proto::UpdateWorktree {
1519            project_id: project_id.to_proto(),
1520            worktree_id,
1521            abs_path: worktree.abs_path.clone(),
1522            root_name: worktree.root_name,
1523            updated_entries: worktree.entries,
1524            removed_entries: Default::default(),
1525            scan_id: worktree.scan_id,
1526            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1527            updated_repositories: worktree.repository_entries.into_values().collect(),
1528            removed_repositories: Default::default(),
1529        };
1530        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1531            session.peer.send(session.connection_id, update.clone())?;
1532        }
1533
1534        // Stream this worktree's diagnostics.
1535        for summary in worktree.diagnostic_summaries {
1536            session.peer.send(
1537                session.connection_id,
1538                proto::UpdateDiagnosticSummary {
1539                    project_id: project_id.to_proto(),
1540                    worktree_id: worktree.id,
1541                    summary: Some(summary),
1542                },
1543            )?;
1544        }
1545
1546        for settings_file in worktree.settings_files {
1547            session.peer.send(
1548                session.connection_id,
1549                proto::UpdateWorktreeSettings {
1550                    project_id: project_id.to_proto(),
1551                    worktree_id: worktree.id,
1552                    path: settings_file.path,
1553                    content: Some(settings_file.content),
1554                },
1555            )?;
1556        }
1557    }
1558
1559    for language_server in &project.language_servers {
1560        session.peer.send(
1561            session.connection_id,
1562            proto::UpdateLanguageServer {
1563                project_id: project_id.to_proto(),
1564                language_server_id: language_server.id,
1565                variant: Some(
1566                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1567                        proto::LspDiskBasedDiagnosticsUpdated {},
1568                    ),
1569                ),
1570            },
1571        )?;
1572    }
1573
1574    Ok(())
1575}
1576
1577async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1578    let sender_id = session.connection_id;
1579    let project_id = ProjectId::from_proto(request.project_id);
1580
1581    let (room, project) = &*session
1582        .db()
1583        .await
1584        .leave_project(project_id, sender_id)
1585        .await?;
1586    tracing::info!(
1587        %project_id,
1588        host_user_id = %project.host_user_id,
1589        host_connection_id = %project.host_connection_id,
1590        "leave project"
1591    );
1592
1593    project_left(&project, &session);
1594    room_updated(&room, &session.peer);
1595
1596    Ok(())
1597}
1598
1599async fn update_project(
1600    request: proto::UpdateProject,
1601    response: Response<proto::UpdateProject>,
1602    session: Session,
1603) -> Result<()> {
1604    let project_id = ProjectId::from_proto(request.project_id);
1605    let (room, guest_connection_ids) = &*session
1606        .db()
1607        .await
1608        .update_project(project_id, session.connection_id, &request.worktrees)
1609        .await?;
1610    broadcast(
1611        Some(session.connection_id),
1612        guest_connection_ids.iter().copied(),
1613        |connection_id| {
1614            session
1615                .peer
1616                .forward_send(session.connection_id, connection_id, request.clone())
1617        },
1618    );
1619    room_updated(&room, &session.peer);
1620    response.send(proto::Ack {})?;
1621
1622    Ok(())
1623}
1624
1625async fn update_worktree(
1626    request: proto::UpdateWorktree,
1627    response: Response<proto::UpdateWorktree>,
1628    session: Session,
1629) -> Result<()> {
1630    let guest_connection_ids = session
1631        .db()
1632        .await
1633        .update_worktree(&request, session.connection_id)
1634        .await?;
1635
1636    broadcast(
1637        Some(session.connection_id),
1638        guest_connection_ids.iter().copied(),
1639        |connection_id| {
1640            session
1641                .peer
1642                .forward_send(session.connection_id, connection_id, request.clone())
1643        },
1644    );
1645    response.send(proto::Ack {})?;
1646    Ok(())
1647}
1648
1649async fn update_diagnostic_summary(
1650    message: proto::UpdateDiagnosticSummary,
1651    session: Session,
1652) -> Result<()> {
1653    let guest_connection_ids = session
1654        .db()
1655        .await
1656        .update_diagnostic_summary(&message, session.connection_id)
1657        .await?;
1658
1659    broadcast(
1660        Some(session.connection_id),
1661        guest_connection_ids.iter().copied(),
1662        |connection_id| {
1663            session
1664                .peer
1665                .forward_send(session.connection_id, connection_id, message.clone())
1666        },
1667    );
1668
1669    Ok(())
1670}
1671
1672async fn update_worktree_settings(
1673    message: proto::UpdateWorktreeSettings,
1674    session: Session,
1675) -> Result<()> {
1676    let guest_connection_ids = session
1677        .db()
1678        .await
1679        .update_worktree_settings(&message, session.connection_id)
1680        .await?;
1681
1682    broadcast(
1683        Some(session.connection_id),
1684        guest_connection_ids.iter().copied(),
1685        |connection_id| {
1686            session
1687                .peer
1688                .forward_send(session.connection_id, connection_id, message.clone())
1689        },
1690    );
1691
1692    Ok(())
1693}
1694
1695async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1696    broadcast_project_message(request.project_id, request, session).await
1697}
1698
1699async fn start_language_server(
1700    request: proto::StartLanguageServer,
1701    session: Session,
1702) -> Result<()> {
1703    let guest_connection_ids = session
1704        .db()
1705        .await
1706        .start_language_server(&request, session.connection_id)
1707        .await?;
1708
1709    broadcast(
1710        Some(session.connection_id),
1711        guest_connection_ids.iter().copied(),
1712        |connection_id| {
1713            session
1714                .peer
1715                .forward_send(session.connection_id, connection_id, request.clone())
1716        },
1717    );
1718    Ok(())
1719}
1720
1721async fn update_language_server(
1722    request: proto::UpdateLanguageServer,
1723    session: Session,
1724) -> Result<()> {
1725    session.executor.record_backtrace();
1726    let project_id = ProjectId::from_proto(request.project_id);
1727    let project_connection_ids = session
1728        .db()
1729        .await
1730        .project_connection_ids(project_id, session.connection_id)
1731        .await?;
1732    broadcast(
1733        Some(session.connection_id),
1734        project_connection_ids.iter().copied(),
1735        |connection_id| {
1736            session
1737                .peer
1738                .forward_send(session.connection_id, connection_id, request.clone())
1739        },
1740    );
1741    Ok(())
1742}
1743
1744async fn forward_project_request<T>(
1745    request: T,
1746    response: Response<T>,
1747    session: Session,
1748) -> Result<()>
1749where
1750    T: EntityMessage + RequestMessage,
1751{
1752    session.executor.record_backtrace();
1753    let project_id = ProjectId::from_proto(request.remote_entity_id());
1754    let host_connection_id = {
1755        let collaborators = session
1756            .db()
1757            .await
1758            .project_collaborators(project_id, session.connection_id)
1759            .await?;
1760        collaborators
1761            .iter()
1762            .find(|collaborator| collaborator.is_host)
1763            .ok_or_else(|| anyhow!("host not found"))?
1764            .connection_id
1765    };
1766
1767    let payload = session
1768        .peer
1769        .forward_request(session.connection_id, host_connection_id, request)
1770        .await?;
1771
1772    response.send(payload)?;
1773    Ok(())
1774}
1775
1776async fn create_buffer_for_peer(
1777    request: proto::CreateBufferForPeer,
1778    session: Session,
1779) -> Result<()> {
1780    session.executor.record_backtrace();
1781    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1782    session
1783        .peer
1784        .forward_send(session.connection_id, peer_id.into(), request)?;
1785    Ok(())
1786}
1787
1788async fn update_buffer(
1789    request: proto::UpdateBuffer,
1790    response: Response<proto::UpdateBuffer>,
1791    session: Session,
1792) -> Result<()> {
1793    session.executor.record_backtrace();
1794    let project_id = ProjectId::from_proto(request.project_id);
1795    let mut guest_connection_ids;
1796    let mut host_connection_id = None;
1797    {
1798        let collaborators = session
1799            .db()
1800            .await
1801            .project_collaborators(project_id, session.connection_id)
1802            .await?;
1803        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1804        for collaborator in collaborators.iter() {
1805            if collaborator.is_host {
1806                host_connection_id = Some(collaborator.connection_id);
1807            } else {
1808                guest_connection_ids.push(collaborator.connection_id);
1809            }
1810        }
1811    }
1812    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1813
1814    session.executor.record_backtrace();
1815    broadcast(
1816        Some(session.connection_id),
1817        guest_connection_ids,
1818        |connection_id| {
1819            session
1820                .peer
1821                .forward_send(session.connection_id, connection_id, request.clone())
1822        },
1823    );
1824    if host_connection_id != session.connection_id {
1825        session
1826            .peer
1827            .forward_request(session.connection_id, host_connection_id, request.clone())
1828            .await?;
1829    }
1830
1831    response.send(proto::Ack {})?;
1832    Ok(())
1833}
1834
1835async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1836    let project_id = ProjectId::from_proto(request.project_id);
1837    let project_connection_ids = session
1838        .db()
1839        .await
1840        .project_connection_ids(project_id, session.connection_id)
1841        .await?;
1842
1843    broadcast(
1844        Some(session.connection_id),
1845        project_connection_ids.iter().copied(),
1846        |connection_id| {
1847            session
1848                .peer
1849                .forward_send(session.connection_id, connection_id, request.clone())
1850        },
1851    );
1852    Ok(())
1853}
1854
1855async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1856    let project_id = ProjectId::from_proto(request.project_id);
1857    let project_connection_ids = session
1858        .db()
1859        .await
1860        .project_connection_ids(project_id, session.connection_id)
1861        .await?;
1862    broadcast(
1863        Some(session.connection_id),
1864        project_connection_ids.iter().copied(),
1865        |connection_id| {
1866            session
1867                .peer
1868                .forward_send(session.connection_id, connection_id, request.clone())
1869        },
1870    );
1871    Ok(())
1872}
1873
1874async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1875    broadcast_project_message(request.project_id, request, session).await
1876}
1877
1878async fn broadcast_project_message<T: EnvelopedMessage>(
1879    project_id: u64,
1880    request: T,
1881    session: Session,
1882) -> Result<()> {
1883    let project_id = ProjectId::from_proto(project_id);
1884    let project_connection_ids = session
1885        .db()
1886        .await
1887        .project_connection_ids(project_id, session.connection_id)
1888        .await?;
1889    broadcast(
1890        Some(session.connection_id),
1891        project_connection_ids.iter().copied(),
1892        |connection_id| {
1893            session
1894                .peer
1895                .forward_send(session.connection_id, connection_id, request.clone())
1896        },
1897    );
1898    Ok(())
1899}
1900
1901async fn follow(
1902    request: proto::Follow,
1903    response: Response<proto::Follow>,
1904    session: Session,
1905) -> Result<()> {
1906    let room_id = RoomId::from_proto(request.room_id);
1907    let project_id = request.project_id.map(ProjectId::from_proto);
1908    let leader_id = request
1909        .leader_id
1910        .ok_or_else(|| anyhow!("invalid leader id"))?
1911        .into();
1912    let follower_id = session.connection_id;
1913
1914    session
1915        .db()
1916        .await
1917        .check_room_participants(room_id, leader_id, session.connection_id)
1918        .await?;
1919
1920    let response_payload = session
1921        .peer
1922        .forward_request(session.connection_id, leader_id, request)
1923        .await?;
1924    response.send(response_payload)?;
1925
1926    if let Some(project_id) = project_id {
1927        let room = session
1928            .db()
1929            .await
1930            .follow(room_id, project_id, leader_id, follower_id)
1931            .await?;
1932        room_updated(&room, &session.peer);
1933    }
1934
1935    Ok(())
1936}
1937
1938async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1939    let room_id = RoomId::from_proto(request.room_id);
1940    let project_id = request.project_id.map(ProjectId::from_proto);
1941    let leader_id = request
1942        .leader_id
1943        .ok_or_else(|| anyhow!("invalid leader id"))?
1944        .into();
1945    let follower_id = session.connection_id;
1946
1947    session
1948        .db()
1949        .await
1950        .check_room_participants(room_id, leader_id, session.connection_id)
1951        .await?;
1952
1953    session
1954        .peer
1955        .forward_send(session.connection_id, leader_id, request)?;
1956
1957    if let Some(project_id) = project_id {
1958        let room = session
1959            .db()
1960            .await
1961            .unfollow(room_id, project_id, leader_id, follower_id)
1962            .await?;
1963        room_updated(&room, &session.peer);
1964    }
1965
1966    Ok(())
1967}
1968
1969async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1970    let room_id = RoomId::from_proto(request.room_id);
1971    let database = session.db.lock().await;
1972
1973    let connection_ids = if let Some(project_id) = request.project_id {
1974        let project_id = ProjectId::from_proto(project_id);
1975        database
1976            .project_connection_ids(project_id, session.connection_id)
1977            .await?
1978    } else {
1979        database
1980            .room_connection_ids(room_id, session.connection_id)
1981            .await?
1982    };
1983
1984    // For now, don't send view update messages back to that view's current leader.
1985    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1986        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1987        _ => None,
1988    });
1989
1990    for follower_peer_id in request.follower_ids.iter().copied() {
1991        let follower_connection_id = follower_peer_id.into();
1992        if Some(follower_peer_id) != connection_id_to_omit
1993            && connection_ids.contains(&follower_connection_id)
1994        {
1995            session.peer.forward_send(
1996                session.connection_id,
1997                follower_connection_id,
1998                request.clone(),
1999            )?;
2000        }
2001    }
2002    Ok(())
2003}
2004
2005async fn get_users(
2006    request: proto::GetUsers,
2007    response: Response<proto::GetUsers>,
2008    session: Session,
2009) -> Result<()> {
2010    let user_ids = request
2011        .user_ids
2012        .into_iter()
2013        .map(UserId::from_proto)
2014        .collect();
2015    let users = session
2016        .db()
2017        .await
2018        .get_users_by_ids(user_ids)
2019        .await?
2020        .into_iter()
2021        .map(|user| proto::User {
2022            id: user.id.to_proto(),
2023            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2024            github_login: user.github_login,
2025        })
2026        .collect();
2027    response.send(proto::UsersResponse { users })?;
2028    Ok(())
2029}
2030
2031async fn fuzzy_search_users(
2032    request: proto::FuzzySearchUsers,
2033    response: Response<proto::FuzzySearchUsers>,
2034    session: Session,
2035) -> Result<()> {
2036    let query = request.query;
2037    let users = match query.len() {
2038        0 => vec![],
2039        1 | 2 => session
2040            .db()
2041            .await
2042            .get_user_by_github_login(&query)
2043            .await?
2044            .into_iter()
2045            .collect(),
2046        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2047    };
2048    let users = users
2049        .into_iter()
2050        .filter(|user| user.id != session.user_id)
2051        .map(|user| proto::User {
2052            id: user.id.to_proto(),
2053            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2054            github_login: user.github_login,
2055        })
2056        .collect();
2057    response.send(proto::UsersResponse { users })?;
2058    Ok(())
2059}
2060
2061async fn request_contact(
2062    request: proto::RequestContact,
2063    response: Response<proto::RequestContact>,
2064    session: Session,
2065) -> Result<()> {
2066    let requester_id = session.user_id;
2067    let responder_id = UserId::from_proto(request.responder_id);
2068    if requester_id == responder_id {
2069        return Err(anyhow!("cannot add yourself as a contact"))?;
2070    }
2071
2072    session
2073        .db()
2074        .await
2075        .send_contact_request(requester_id, responder_id)
2076        .await?;
2077
2078    // Update outgoing contact requests of requester
2079    let mut update = proto::UpdateContacts::default();
2080    update.outgoing_requests.push(responder_id.to_proto());
2081    for connection_id in session
2082        .connection_pool()
2083        .await
2084        .user_connection_ids(requester_id)
2085    {
2086        session.peer.send(connection_id, update.clone())?;
2087    }
2088
2089    // Update incoming contact requests of responder
2090    let mut update = proto::UpdateContacts::default();
2091    update
2092        .incoming_requests
2093        .push(proto::IncomingContactRequest {
2094            requester_id: requester_id.to_proto(),
2095            should_notify: true,
2096        });
2097    for connection_id in session
2098        .connection_pool()
2099        .await
2100        .user_connection_ids(responder_id)
2101    {
2102        session.peer.send(connection_id, update.clone())?;
2103    }
2104
2105    response.send(proto::Ack {})?;
2106    Ok(())
2107}
2108
2109async fn respond_to_contact_request(
2110    request: proto::RespondToContactRequest,
2111    response: Response<proto::RespondToContactRequest>,
2112    session: Session,
2113) -> Result<()> {
2114    let responder_id = session.user_id;
2115    let requester_id = UserId::from_proto(request.requester_id);
2116    let db = session.db().await;
2117    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2118        db.dismiss_contact_notification(responder_id, requester_id)
2119            .await?;
2120    } else {
2121        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2122
2123        db.respond_to_contact_request(responder_id, requester_id, accept)
2124            .await?;
2125        let requester_busy = db.is_user_busy(requester_id).await?;
2126        let responder_busy = db.is_user_busy(responder_id).await?;
2127
2128        let pool = session.connection_pool().await;
2129        // Update responder with new contact
2130        let mut update = proto::UpdateContacts::default();
2131        if accept {
2132            update
2133                .contacts
2134                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2135        }
2136        update
2137            .remove_incoming_requests
2138            .push(requester_id.to_proto());
2139        for connection_id in pool.user_connection_ids(responder_id) {
2140            session.peer.send(connection_id, update.clone())?;
2141        }
2142
2143        // Update requester with new contact
2144        let mut update = proto::UpdateContacts::default();
2145        if accept {
2146            update
2147                .contacts
2148                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2149        }
2150        update
2151            .remove_outgoing_requests
2152            .push(responder_id.to_proto());
2153        for connection_id in pool.user_connection_ids(requester_id) {
2154            session.peer.send(connection_id, update.clone())?;
2155        }
2156    }
2157
2158    response.send(proto::Ack {})?;
2159    Ok(())
2160}
2161
2162async fn remove_contact(
2163    request: proto::RemoveContact,
2164    response: Response<proto::RemoveContact>,
2165    session: Session,
2166) -> Result<()> {
2167    let requester_id = session.user_id;
2168    let responder_id = UserId::from_proto(request.user_id);
2169    let db = session.db().await;
2170    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2171
2172    let pool = session.connection_pool().await;
2173    // Update outgoing contact requests of requester
2174    let mut update = proto::UpdateContacts::default();
2175    if contact_accepted {
2176        update.remove_contacts.push(responder_id.to_proto());
2177    } else {
2178        update
2179            .remove_outgoing_requests
2180            .push(responder_id.to_proto());
2181    }
2182    for connection_id in pool.user_connection_ids(requester_id) {
2183        session.peer.send(connection_id, update.clone())?;
2184    }
2185
2186    // Update incoming contact requests of responder
2187    let mut update = proto::UpdateContacts::default();
2188    if contact_accepted {
2189        update.remove_contacts.push(requester_id.to_proto());
2190    } else {
2191        update
2192            .remove_incoming_requests
2193            .push(requester_id.to_proto());
2194    }
2195    for connection_id in pool.user_connection_ids(responder_id) {
2196        session.peer.send(connection_id, update.clone())?;
2197    }
2198
2199    response.send(proto::Ack {})?;
2200    Ok(())
2201}
2202
2203async fn create_channel(
2204    request: proto::CreateChannel,
2205    response: Response<proto::CreateChannel>,
2206    session: Session,
2207) -> Result<()> {
2208    let db = session.db().await;
2209    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2210
2211    if let Some(live_kit) = session.live_kit_client.as_ref() {
2212        live_kit.create_room(live_kit_room.clone()).await?;
2213    }
2214
2215    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2216    let id = db
2217        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2218        .await?;
2219
2220    let channel = proto::Channel {
2221        id: id.to_proto(),
2222        name: request.name,
2223    };
2224
2225    response.send(proto::CreateChannelResponse {
2226        channel: Some(channel.clone()),
2227        parent_id: request.parent_id,
2228    })?;
2229
2230    let Some(parent_id) = parent_id else {
2231        return Ok(());
2232    };
2233
2234    let update = proto::UpdateChannels {
2235        channels: vec![channel],
2236        insert_edge: vec![ChannelEdge {
2237            parent_id: parent_id.to_proto(),
2238            channel_id: id.to_proto(),
2239        }],
2240        ..Default::default()
2241    };
2242
2243    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2244
2245    let connection_pool = session.connection_pool().await;
2246    for user_id in user_ids_to_notify {
2247        for connection_id in connection_pool.user_connection_ids(user_id) {
2248            if user_id == session.user_id {
2249                continue;
2250            }
2251            session.peer.send(connection_id, update.clone())?;
2252        }
2253    }
2254
2255    Ok(())
2256}
2257
2258async fn delete_channel(
2259    request: proto::DeleteChannel,
2260    response: Response<proto::DeleteChannel>,
2261    session: Session,
2262) -> Result<()> {
2263    let db = session.db().await;
2264
2265    let channel_id = request.channel_id;
2266    let (removed_channels, member_ids) = db
2267        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2268        .await?;
2269    response.send(proto::Ack {})?;
2270
2271    // Notify members of removed channels
2272    let mut update = proto::UpdateChannels::default();
2273    update
2274        .delete_channels
2275        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2276
2277    let connection_pool = session.connection_pool().await;
2278    for member_id in member_ids {
2279        for connection_id in connection_pool.user_connection_ids(member_id) {
2280            session.peer.send(connection_id, update.clone())?;
2281        }
2282    }
2283
2284    Ok(())
2285}
2286
2287async fn invite_channel_member(
2288    request: proto::InviteChannelMember,
2289    response: Response<proto::InviteChannelMember>,
2290    session: Session,
2291) -> Result<()> {
2292    let db = session.db().await;
2293    let channel_id = ChannelId::from_proto(request.channel_id);
2294    let invitee_id = UserId::from_proto(request.user_id);
2295    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2296        .await?;
2297
2298    let (channel, _) = db
2299        .get_channel(channel_id, session.user_id)
2300        .await?
2301        .ok_or_else(|| anyhow!("channel not found"))?;
2302
2303    let mut update = proto::UpdateChannels::default();
2304    update.channel_invitations.push(proto::Channel {
2305        id: channel.id.to_proto(),
2306        name: channel.name,
2307    });
2308    for connection_id in session
2309        .connection_pool()
2310        .await
2311        .user_connection_ids(invitee_id)
2312    {
2313        session.peer.send(connection_id, update.clone())?;
2314    }
2315
2316    response.send(proto::Ack {})?;
2317    Ok(())
2318}
2319
2320async fn remove_channel_member(
2321    request: proto::RemoveChannelMember,
2322    response: Response<proto::RemoveChannelMember>,
2323    session: Session,
2324) -> Result<()> {
2325    let db = session.db().await;
2326    let channel_id = ChannelId::from_proto(request.channel_id);
2327    let member_id = UserId::from_proto(request.user_id);
2328
2329    db.remove_channel_member(channel_id, member_id, session.user_id)
2330        .await?;
2331
2332    let mut update = proto::UpdateChannels::default();
2333    update.delete_channels.push(channel_id.to_proto());
2334
2335    for connection_id in session
2336        .connection_pool()
2337        .await
2338        .user_connection_ids(member_id)
2339    {
2340        session.peer.send(connection_id, update.clone())?;
2341    }
2342
2343    response.send(proto::Ack {})?;
2344    Ok(())
2345}
2346
2347async fn set_channel_member_admin(
2348    request: proto::SetChannelMemberAdmin,
2349    response: Response<proto::SetChannelMemberAdmin>,
2350    session: Session,
2351) -> Result<()> {
2352    let db = session.db().await;
2353    let channel_id = ChannelId::from_proto(request.channel_id);
2354    let member_id = UserId::from_proto(request.user_id);
2355    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2356        .await?;
2357
2358    let (channel, has_accepted) = db
2359        .get_channel(channel_id, member_id)
2360        .await?
2361        .ok_or_else(|| anyhow!("channel not found"))?;
2362
2363    let mut update = proto::UpdateChannels::default();
2364    if has_accepted {
2365        update.channel_permissions.push(proto::ChannelPermission {
2366            channel_id: channel.id.to_proto(),
2367            is_admin: request.admin,
2368        });
2369    }
2370
2371    for connection_id in session
2372        .connection_pool()
2373        .await
2374        .user_connection_ids(member_id)
2375    {
2376        session.peer.send(connection_id, update.clone())?;
2377    }
2378
2379    response.send(proto::Ack {})?;
2380    Ok(())
2381}
2382
2383async fn rename_channel(
2384    request: proto::RenameChannel,
2385    response: Response<proto::RenameChannel>,
2386    session: Session,
2387) -> Result<()> {
2388    let db = session.db().await;
2389    let channel_id = ChannelId::from_proto(request.channel_id);
2390    let new_name = db
2391        .rename_channel(channel_id, session.user_id, &request.name)
2392        .await?;
2393
2394    let channel = proto::Channel {
2395        id: request.channel_id,
2396        name: new_name,
2397    };
2398    response.send(proto::RenameChannelResponse {
2399        channel: Some(channel.clone()),
2400    })?;
2401    let mut update = proto::UpdateChannels::default();
2402    update.channels.push(channel);
2403
2404    let member_ids = db.get_channel_members(channel_id).await?;
2405
2406    let connection_pool = session.connection_pool().await;
2407    for member_id in member_ids {
2408        for connection_id in connection_pool.user_connection_ids(member_id) {
2409            session.peer.send(connection_id, update.clone())?;
2410        }
2411    }
2412
2413    Ok(())
2414}
2415
2416async fn link_channel(
2417    request: proto::LinkChannel,
2418    response: Response<proto::LinkChannel>,
2419    session: Session,
2420) -> Result<()> {
2421    let db = session.db().await;
2422    let channel_id = ChannelId::from_proto(request.channel_id);
2423    let to = ChannelId::from_proto(request.to);
2424    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2425
2426    let members = db.get_channel_members(to).await?;
2427    let connection_pool = session.connection_pool().await;
2428    let update = proto::UpdateChannels {
2429        channels: channels_to_send
2430            .channels
2431            .into_iter()
2432            .map(|channel| proto::Channel {
2433                id: channel.id.to_proto(),
2434                name: channel.name,
2435            })
2436            .collect(),
2437        insert_edge: channels_to_send.edges,
2438        ..Default::default()
2439    };
2440    for member_id in members {
2441        for connection_id in connection_pool.user_connection_ids(member_id) {
2442            session.peer.send(connection_id, update.clone())?;
2443        }
2444    }
2445
2446    response.send(Ack {})?;
2447
2448    Ok(())
2449}
2450
2451async fn unlink_channel(
2452    request: proto::UnlinkChannel,
2453    response: Response<proto::UnlinkChannel>,
2454    session: Session,
2455) -> Result<()> {
2456    let db = session.db().await;
2457    let channel_id = ChannelId::from_proto(request.channel_id);
2458    let from = ChannelId::from_proto(request.from);
2459
2460    db.unlink_channel(session.user_id, channel_id, from).await?;
2461
2462    let members = db.get_channel_members(from).await?;
2463
2464    let update = proto::UpdateChannels {
2465        delete_edge: vec![proto::ChannelEdge {
2466            channel_id: channel_id.to_proto(),
2467            parent_id: from.to_proto(),
2468        }],
2469        ..Default::default()
2470    };
2471    let connection_pool = session.connection_pool().await;
2472    for member_id in members {
2473        for connection_id in connection_pool.user_connection_ids(member_id) {
2474            session.peer.send(connection_id, update.clone())?;
2475        }
2476    }
2477
2478    response.send(Ack {})?;
2479
2480    Ok(())
2481}
2482
2483async fn move_channel(
2484    request: proto::MoveChannel,
2485    response: Response<proto::MoveChannel>,
2486    session: Session,
2487) -> Result<()> {
2488    let db = session.db().await;
2489    let channel_id = ChannelId::from_proto(request.channel_id);
2490    let from_parent = ChannelId::from_proto(request.from);
2491    let to = ChannelId::from_proto(request.to);
2492
2493    let channels_to_send = db
2494        .move_channel(session.user_id, channel_id, from_parent, to)
2495        .await?;
2496
2497    if channels_to_send.is_empty() {
2498        response.send(Ack {})?;
2499        return Ok(());
2500    }
2501
2502    let members_from = db.get_channel_members(from_parent).await?;
2503    let members_to = db.get_channel_members(to).await?;
2504
2505    let update = proto::UpdateChannels {
2506        delete_edge: vec![proto::ChannelEdge {
2507            channel_id: channel_id.to_proto(),
2508            parent_id: from_parent.to_proto(),
2509        }],
2510        ..Default::default()
2511    };
2512    let connection_pool = session.connection_pool().await;
2513    for member_id in members_from {
2514        for connection_id in connection_pool.user_connection_ids(member_id) {
2515            session.peer.send(connection_id, update.clone())?;
2516        }
2517    }
2518
2519    let update = proto::UpdateChannels {
2520        channels: channels_to_send
2521            .channels
2522            .into_iter()
2523            .map(|channel| proto::Channel {
2524                id: channel.id.to_proto(),
2525                name: channel.name,
2526            })
2527            .collect(),
2528        insert_edge: channels_to_send.edges,
2529        ..Default::default()
2530    };
2531    for member_id in members_to {
2532        for connection_id in connection_pool.user_connection_ids(member_id) {
2533            session.peer.send(connection_id, update.clone())?;
2534        }
2535    }
2536
2537    response.send(Ack {})?;
2538
2539    Ok(())
2540}
2541
2542async fn get_channel_members(
2543    request: proto::GetChannelMembers,
2544    response: Response<proto::GetChannelMembers>,
2545    session: Session,
2546) -> Result<()> {
2547    let db = session.db().await;
2548    let channel_id = ChannelId::from_proto(request.channel_id);
2549    let members = db
2550        .get_channel_member_details(channel_id, session.user_id)
2551        .await?;
2552    response.send(proto::GetChannelMembersResponse { members })?;
2553    Ok(())
2554}
2555
2556async fn respond_to_channel_invite(
2557    request: proto::RespondToChannelInvite,
2558    response: Response<proto::RespondToChannelInvite>,
2559    session: Session,
2560) -> Result<()> {
2561    let db = session.db().await;
2562    let channel_id = ChannelId::from_proto(request.channel_id);
2563    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2564        .await?;
2565
2566    let mut update = proto::UpdateChannels::default();
2567    update
2568        .remove_channel_invitations
2569        .push(channel_id.to_proto());
2570    if request.accept {
2571        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2572        update
2573            .channels
2574            .extend(
2575                result
2576                    .channels
2577                    .channels
2578                    .into_iter()
2579                    .map(|channel| proto::Channel {
2580                        id: channel.id.to_proto(),
2581                        name: channel.name,
2582                    }),
2583            );
2584        update.unseen_channel_messages = result.channel_messages;
2585        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2586        update.insert_edge = result.channels.edges;
2587        update
2588            .channel_participants
2589            .extend(
2590                result
2591                    .channel_participants
2592                    .into_iter()
2593                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2594                        channel_id: channel_id.to_proto(),
2595                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2596                    }),
2597            );
2598        update
2599            .channel_permissions
2600            .extend(
2601                result
2602                    .channels_with_admin_privileges
2603                    .into_iter()
2604                    .map(|channel_id| proto::ChannelPermission {
2605                        channel_id: channel_id.to_proto(),
2606                        is_admin: true,
2607                    }),
2608            );
2609    }
2610    session.peer.send(session.connection_id, update)?;
2611    response.send(proto::Ack {})?;
2612
2613    Ok(())
2614}
2615
2616async fn join_channel(
2617    request: proto::JoinChannel,
2618    response: Response<proto::JoinChannel>,
2619    session: Session,
2620) -> Result<()> {
2621    let channel_id = ChannelId::from_proto(request.channel_id);
2622
2623    let joined_room = {
2624        leave_room_for_session(&session).await?;
2625        let db = session.db().await;
2626
2627        let room_id = db.room_id_for_channel(channel_id).await?;
2628
2629        let joined_room = db
2630            .join_room(
2631                room_id,
2632                session.user_id,
2633                session.connection_id,
2634                RELEASE_CHANNEL_NAME.as_str(),
2635            )
2636            .await?;
2637
2638        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2639            let token = live_kit
2640                .room_token(
2641                    &joined_room.room.live_kit_room,
2642                    &session.user_id.to_string(),
2643                )
2644                .trace_err()?;
2645
2646            Some(LiveKitConnectionInfo {
2647                server_url: live_kit.url().into(),
2648                token,
2649            })
2650        });
2651
2652        response.send(proto::JoinRoomResponse {
2653            room: Some(joined_room.room.clone()),
2654            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2655            live_kit_connection_info,
2656        })?;
2657
2658        room_updated(&joined_room.room, &session.peer);
2659
2660        joined_room.into_inner()
2661    };
2662
2663    channel_updated(
2664        channel_id,
2665        &joined_room.room,
2666        &joined_room.channel_members,
2667        &session.peer,
2668        &*session.connection_pool().await,
2669    );
2670
2671    update_user_contacts(session.user_id, &session).await?;
2672
2673    Ok(())
2674}
2675
2676async fn join_channel_buffer(
2677    request: proto::JoinChannelBuffer,
2678    response: Response<proto::JoinChannelBuffer>,
2679    session: Session,
2680) -> Result<()> {
2681    let db = session.db().await;
2682    let channel_id = ChannelId::from_proto(request.channel_id);
2683
2684    let open_response = db
2685        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2686        .await?;
2687
2688    let collaborators = open_response.collaborators.clone();
2689    response.send(open_response)?;
2690
2691    let update = UpdateChannelBufferCollaborators {
2692        channel_id: channel_id.to_proto(),
2693        collaborators: collaborators.clone(),
2694    };
2695    channel_buffer_updated(
2696        session.connection_id,
2697        collaborators
2698            .iter()
2699            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2700        &update,
2701        &session.peer,
2702    );
2703
2704    Ok(())
2705}
2706
2707async fn update_channel_buffer(
2708    request: proto::UpdateChannelBuffer,
2709    session: Session,
2710) -> Result<()> {
2711    let db = session.db().await;
2712    let channel_id = ChannelId::from_proto(request.channel_id);
2713
2714    let (collaborators, non_collaborators, epoch, version) = db
2715        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2716        .await?;
2717
2718    channel_buffer_updated(
2719        session.connection_id,
2720        collaborators,
2721        &proto::UpdateChannelBuffer {
2722            channel_id: channel_id.to_proto(),
2723            operations: request.operations,
2724        },
2725        &session.peer,
2726    );
2727
2728    let pool = &*session.connection_pool().await;
2729
2730    broadcast(
2731        None,
2732        non_collaborators
2733            .iter()
2734            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2735        |peer_id| {
2736            session.peer.send(
2737                peer_id.into(),
2738                proto::UpdateChannels {
2739                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2740                        channel_id: channel_id.to_proto(),
2741                        epoch: epoch as u64,
2742                        version: version.clone(),
2743                    }],
2744                    ..Default::default()
2745                },
2746            )
2747        },
2748    );
2749
2750    Ok(())
2751}
2752
2753async fn rejoin_channel_buffers(
2754    request: proto::RejoinChannelBuffers,
2755    response: Response<proto::RejoinChannelBuffers>,
2756    session: Session,
2757) -> Result<()> {
2758    let db = session.db().await;
2759    let buffers = db
2760        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2761        .await?;
2762
2763    for rejoined_buffer in &buffers {
2764        let collaborators_to_notify = rejoined_buffer
2765            .buffer
2766            .collaborators
2767            .iter()
2768            .filter_map(|c| Some(c.peer_id?.into()));
2769        channel_buffer_updated(
2770            session.connection_id,
2771            collaborators_to_notify,
2772            &proto::UpdateChannelBufferCollaborators {
2773                channel_id: rejoined_buffer.buffer.channel_id,
2774                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2775            },
2776            &session.peer,
2777        );
2778    }
2779
2780    response.send(proto::RejoinChannelBuffersResponse {
2781        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2782    })?;
2783
2784    Ok(())
2785}
2786
2787async fn leave_channel_buffer(
2788    request: proto::LeaveChannelBuffer,
2789    response: Response<proto::LeaveChannelBuffer>,
2790    session: Session,
2791) -> Result<()> {
2792    let db = session.db().await;
2793    let channel_id = ChannelId::from_proto(request.channel_id);
2794
2795    let left_buffer = db
2796        .leave_channel_buffer(channel_id, session.connection_id)
2797        .await?;
2798
2799    response.send(Ack {})?;
2800
2801    channel_buffer_updated(
2802        session.connection_id,
2803        left_buffer.connections,
2804        &proto::UpdateChannelBufferCollaborators {
2805            channel_id: channel_id.to_proto(),
2806            collaborators: left_buffer.collaborators,
2807        },
2808        &session.peer,
2809    );
2810
2811    Ok(())
2812}
2813
2814fn channel_buffer_updated<T: EnvelopedMessage>(
2815    sender_id: ConnectionId,
2816    collaborators: impl IntoIterator<Item = ConnectionId>,
2817    message: &T,
2818    peer: &Peer,
2819) {
2820    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2821        peer.send(peer_id.into(), message.clone())
2822    });
2823}
2824
2825async fn send_channel_message(
2826    request: proto::SendChannelMessage,
2827    response: Response<proto::SendChannelMessage>,
2828    session: Session,
2829) -> Result<()> {
2830    // Validate the message body.
2831    let body = request.body.trim().to_string();
2832    if body.len() > MAX_MESSAGE_LEN {
2833        return Err(anyhow!("message is too long"))?;
2834    }
2835    if body.is_empty() {
2836        return Err(anyhow!("message can't be blank"))?;
2837    }
2838
2839    let timestamp = OffsetDateTime::now_utc();
2840    let nonce = request
2841        .nonce
2842        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2843
2844    let channel_id = ChannelId::from_proto(request.channel_id);
2845    let (message_id, connection_ids, non_participants) = session
2846        .db()
2847        .await
2848        .create_channel_message(
2849            channel_id,
2850            session.user_id,
2851            &body,
2852            timestamp,
2853            nonce.clone().into(),
2854        )
2855        .await?;
2856    let message = proto::ChannelMessage {
2857        sender_id: session.user_id.to_proto(),
2858        id: message_id.to_proto(),
2859        body,
2860        timestamp: timestamp.unix_timestamp() as u64,
2861        nonce: Some(nonce),
2862    };
2863    broadcast(Some(session.connection_id), connection_ids, |connection| {
2864        session.peer.send(
2865            connection,
2866            proto::ChannelMessageSent {
2867                channel_id: channel_id.to_proto(),
2868                message: Some(message.clone()),
2869            },
2870        )
2871    });
2872    response.send(proto::SendChannelMessageResponse {
2873        message: Some(message),
2874    })?;
2875
2876    let pool = &*session.connection_pool().await;
2877    broadcast(
2878        None,
2879        non_participants
2880            .iter()
2881            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2882        |peer_id| {
2883            session.peer.send(
2884                peer_id.into(),
2885                proto::UpdateChannels {
2886                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2887                        channel_id: channel_id.to_proto(),
2888                        message_id: message_id.to_proto(),
2889                    }],
2890                    ..Default::default()
2891                },
2892            )
2893        },
2894    );
2895
2896    Ok(())
2897}
2898
2899async fn remove_channel_message(
2900    request: proto::RemoveChannelMessage,
2901    response: Response<proto::RemoveChannelMessage>,
2902    session: Session,
2903) -> Result<()> {
2904    let channel_id = ChannelId::from_proto(request.channel_id);
2905    let message_id = MessageId::from_proto(request.message_id);
2906    let connection_ids = session
2907        .db()
2908        .await
2909        .remove_channel_message(channel_id, message_id, session.user_id)
2910        .await?;
2911    broadcast(Some(session.connection_id), connection_ids, |connection| {
2912        session.peer.send(connection, request.clone())
2913    });
2914    response.send(proto::Ack {})?;
2915    Ok(())
2916}
2917
2918async fn acknowledge_channel_message(
2919    request: proto::AckChannelMessage,
2920    session: Session,
2921) -> Result<()> {
2922    let channel_id = ChannelId::from_proto(request.channel_id);
2923    let message_id = MessageId::from_proto(request.message_id);
2924    session
2925        .db()
2926        .await
2927        .observe_channel_message(channel_id, session.user_id, message_id)
2928        .await?;
2929    Ok(())
2930}
2931
2932async fn acknowledge_buffer_version(
2933    request: proto::AckBufferOperation,
2934    session: Session,
2935) -> Result<()> {
2936    let buffer_id = BufferId::from_proto(request.buffer_id);
2937    session
2938        .db()
2939        .await
2940        .observe_buffer_version(
2941            buffer_id,
2942            session.user_id,
2943            request.epoch as i32,
2944            &request.version,
2945        )
2946        .await?;
2947    Ok(())
2948}
2949
2950async fn join_channel_chat(
2951    request: proto::JoinChannelChat,
2952    response: Response<proto::JoinChannelChat>,
2953    session: Session,
2954) -> Result<()> {
2955    let channel_id = ChannelId::from_proto(request.channel_id);
2956
2957    let db = session.db().await;
2958    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2959        .await?;
2960    let messages = db
2961        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2962        .await?;
2963    response.send(proto::JoinChannelChatResponse {
2964        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2965        messages,
2966    })?;
2967    Ok(())
2968}
2969
2970async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2971    let channel_id = ChannelId::from_proto(request.channel_id);
2972    session
2973        .db()
2974        .await
2975        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2976        .await?;
2977    Ok(())
2978}
2979
2980async fn get_channel_messages(
2981    request: proto::GetChannelMessages,
2982    response: Response<proto::GetChannelMessages>,
2983    session: Session,
2984) -> Result<()> {
2985    let channel_id = ChannelId::from_proto(request.channel_id);
2986    let messages = session
2987        .db()
2988        .await
2989        .get_channel_messages(
2990            channel_id,
2991            session.user_id,
2992            MESSAGE_COUNT_PER_PAGE,
2993            Some(MessageId::from_proto(request.before_message_id)),
2994        )
2995        .await?;
2996    response.send(proto::GetChannelMessagesResponse {
2997        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2998        messages,
2999    })?;
3000    Ok(())
3001}
3002
3003async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3004    let project_id = ProjectId::from_proto(request.project_id);
3005    let project_connection_ids = session
3006        .db()
3007        .await
3008        .project_connection_ids(project_id, session.connection_id)
3009        .await?;
3010    broadcast(
3011        Some(session.connection_id),
3012        project_connection_ids.iter().copied(),
3013        |connection_id| {
3014            session
3015                .peer
3016                .forward_send(session.connection_id, connection_id, request.clone())
3017        },
3018    );
3019    Ok(())
3020}
3021
3022async fn get_private_user_info(
3023    _request: proto::GetPrivateUserInfo,
3024    response: Response<proto::GetPrivateUserInfo>,
3025    session: Session,
3026) -> Result<()> {
3027    let db = session.db().await;
3028
3029    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3030    let user = db
3031        .get_user_by_id(session.user_id)
3032        .await?
3033        .ok_or_else(|| anyhow!("user not found"))?;
3034    let flags = db.get_user_flags(session.user_id).await?;
3035
3036    response.send(proto::GetPrivateUserInfoResponse {
3037        metrics_id,
3038        staff: user.admin,
3039        flags,
3040    })?;
3041    Ok(())
3042}
3043
3044fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3045    match message {
3046        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3047        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3048        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3049        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3050        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3051            code: frame.code.into(),
3052            reason: frame.reason,
3053        })),
3054    }
3055}
3056
3057fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3058    match message {
3059        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3060        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3061        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3062        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3063        AxumMessage::Close(frame) => {
3064            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3065                code: frame.code.into(),
3066                reason: frame.reason,
3067            }))
3068        }
3069    }
3070}
3071
3072fn build_initial_channels_update(
3073    channels: ChannelsForUser,
3074    channel_invites: Vec<db::Channel>,
3075) -> proto::UpdateChannels {
3076    let mut update = proto::UpdateChannels::default();
3077
3078    for channel in channels.channels.channels {
3079        update.channels.push(proto::Channel {
3080            id: channel.id.to_proto(),
3081            name: channel.name,
3082        });
3083    }
3084
3085    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3086    update.unseen_channel_messages = channels.channel_messages;
3087    update.insert_edge = channels.channels.edges;
3088
3089    for (channel_id, participants) in channels.channel_participants {
3090        update
3091            .channel_participants
3092            .push(proto::ChannelParticipants {
3093                channel_id: channel_id.to_proto(),
3094                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3095            });
3096    }
3097
3098    update
3099        .channel_permissions
3100        .extend(
3101            channels
3102                .channels_with_admin_privileges
3103                .into_iter()
3104                .map(|id| proto::ChannelPermission {
3105                    channel_id: id.to_proto(),
3106                    is_admin: true,
3107                }),
3108        );
3109
3110    for channel in channel_invites {
3111        update.channel_invitations.push(proto::Channel {
3112            id: channel.id.to_proto(),
3113            name: channel.name,
3114        });
3115    }
3116
3117    update
3118}
3119
3120fn build_initial_contacts_update(
3121    contacts: Vec<db::Contact>,
3122    pool: &ConnectionPool,
3123) -> proto::UpdateContacts {
3124    let mut update = proto::UpdateContacts::default();
3125
3126    for contact in contacts {
3127        match contact {
3128            db::Contact::Accepted {
3129                user_id,
3130                should_notify,
3131                busy,
3132            } => {
3133                update
3134                    .contacts
3135                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3136            }
3137            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3138            db::Contact::Incoming {
3139                user_id,
3140                should_notify,
3141            } => update
3142                .incoming_requests
3143                .push(proto::IncomingContactRequest {
3144                    requester_id: user_id.to_proto(),
3145                    should_notify,
3146                }),
3147        }
3148    }
3149
3150    update
3151}
3152
3153fn contact_for_user(
3154    user_id: UserId,
3155    should_notify: bool,
3156    busy: bool,
3157    pool: &ConnectionPool,
3158) -> proto::Contact {
3159    proto::Contact {
3160        user_id: user_id.to_proto(),
3161        online: pool.is_user_online(user_id),
3162        busy,
3163        should_notify,
3164    }
3165}
3166
3167fn room_updated(room: &proto::Room, peer: &Peer) {
3168    broadcast(
3169        None,
3170        room.participants
3171            .iter()
3172            .filter_map(|participant| Some(participant.peer_id?.into())),
3173        |peer_id| {
3174            peer.send(
3175                peer_id.into(),
3176                proto::RoomUpdated {
3177                    room: Some(room.clone()),
3178                },
3179            )
3180        },
3181    );
3182}
3183
3184fn channel_updated(
3185    channel_id: ChannelId,
3186    room: &proto::Room,
3187    channel_members: &[UserId],
3188    peer: &Peer,
3189    pool: &ConnectionPool,
3190) {
3191    let participants = room
3192        .participants
3193        .iter()
3194        .map(|p| p.user_id)
3195        .collect::<Vec<_>>();
3196
3197    broadcast(
3198        None,
3199        channel_members
3200            .iter()
3201            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3202        |peer_id| {
3203            peer.send(
3204                peer_id.into(),
3205                proto::UpdateChannels {
3206                    channel_participants: vec![proto::ChannelParticipants {
3207                        channel_id: channel_id.to_proto(),
3208                        participant_user_ids: participants.clone(),
3209                    }],
3210                    ..Default::default()
3211                },
3212            )
3213        },
3214    );
3215}
3216
3217async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3218    let db = session.db().await;
3219
3220    let contacts = db.get_contacts(user_id).await?;
3221    let busy = db.is_user_busy(user_id).await?;
3222
3223    let pool = session.connection_pool().await;
3224    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3225    for contact in contacts {
3226        if let db::Contact::Accepted {
3227            user_id: contact_user_id,
3228            ..
3229        } = contact
3230        {
3231            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3232                session
3233                    .peer
3234                    .send(
3235                        contact_conn_id,
3236                        proto::UpdateContacts {
3237                            contacts: vec![updated_contact.clone()],
3238                            remove_contacts: Default::default(),
3239                            incoming_requests: Default::default(),
3240                            remove_incoming_requests: Default::default(),
3241                            outgoing_requests: Default::default(),
3242                            remove_outgoing_requests: Default::default(),
3243                        },
3244                    )
3245                    .trace_err();
3246            }
3247        }
3248    }
3249    Ok(())
3250}
3251
/// Removes the session's connection from the room it is in, if any, and
/// notifies everyone affected.
///
/// Returns `Ok(())` immediately when the database reports that the
/// connection was not in a room. Otherwise, in order:
/// - notifies the other connections of each project the user left (via
///   `project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - if the room belongs to a channel, sends the new participant list to the
///   channel's members,
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`,
/// - refreshes the contact status of the leaving user and of those users,
/// - removes the user from the LiveKit room and, when the database marked the
///   room as `deleted`, deletes the LiveKit room. LiveKit failures are only
///   logged.
///
/// # Errors
/// Returns an error if leaving the room in the database fails, or if
/// refreshing a contact's status fails. In the latter case the LiveKit
/// cleanup at the end is skipped.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned inside the `if let` below and used after it. This
    // way `left_room` and the temporaries in the `if let` scrutinee go out of
    // scope when the `if let` ends, before the notifications and `.await`s
    // that follow.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` before the room itself is taken, so
        // the `room` broadcast below carries the field's default value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool handle is dropped before
    // `update_user_contacts` below, which acquires it again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3328
3329async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3330    let left_channel_buffers = session
3331        .db()
3332        .await
3333        .leave_channel_buffers(session.connection_id)
3334        .await?;
3335
3336    for left_buffer in left_channel_buffers {
3337        channel_buffer_updated(
3338            session.connection_id,
3339            left_buffer.connections,
3340            &proto::UpdateChannelBufferCollaborators {
3341                channel_id: left_buffer.channel_id.to_proto(),
3342                collaborators: left_buffer.collaborators,
3343            },
3344            &session.peer,
3345        );
3346    }
3347
3348    Ok(())
3349}
3350
3351fn project_left(project: &db::LeftProject, session: &Session) {
3352    for connection_id in &project.connection_ids {
3353        if project.host_user_id == session.user_id {
3354            session
3355                .peer
3356                .send(
3357                    *connection_id,
3358                    proto::UnshareProject {
3359                        project_id: project.id.to_proto(),
3360                    },
3361                )
3362                .trace_err();
3363        } else {
3364            session
3365                .peer
3366                .send(
3367                    *connection_id,
3368                    proto::RemoveProjectCollaborator {
3369                        project_id: project.id.to_proto(),
3370                        peer_id: Some(session.connection_id.into()),
3371                    },
3372                )
3373                .trace_err();
3374        }
3375    }
3376}
3377
/// Extension for fallible values whose failure should be reported rather
/// than propagated.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Converts `self` into an `Option`, yielding `None` on failure.
    ///
    /// The implementation for `Result` in this module logs the error with
    /// `tracing::error!` before discarding it.
    fn trace_err(self) -> Option<Self::Ok>;
}
3383
3384impl<T, E> ResultExt for Result<T, E>
3385where
3386    E: std::fmt::Debug,
3387{
3388    type Ok = T;
3389
3390    fn trace_err(self) -> Option<T> {
3391        match self {
3392            Ok(value) => Some(value),
3393            Err(error) => {
3394                tracing::error!("{:?}", error);
3395                None
3396            }
3397        }
3398    }
3399}