rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
   7        ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66
// NOTE(review): not referenced in the visible part of this file; presumably the grace period a
// disconnected client has to reconnect before its session state is released — confirm at use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before cleaning up rooms and channel-buffer collaborators left
/// behind by earlier server instances and deleting those stale server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): page size, presumably for channel-message history requests — confirm at use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// NOTE(review): presumably the maximum accepted channel-message length; whether it counts bytes
// or chars is not visible here — confirm at use sites.
const MAX_MESSAGE_LEN: usize = 1024;

lazy_static! {
    // Prometheus gauge registered under the name "connections". Registration happens on first
    // access and panics (`unwrap`) if it fails, e.g. on a duplicate metric name.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge registered under the name "shared_projects"; same registration caveat.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  82
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives a boxed incoming envelope plus the sender's `Session` and returns a boxed
/// `'static` future that processes the message. Concrete handlers are wrapped into this shape
/// by `Server::add_handler`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
/// Per-connection context handed to every message handler.
///
/// One is built per connection in `Server::handle_connection`. Every field is either `Copy`
/// or a shared handle, so clones refer to the same underlying resources.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // A mutex created fresh for each connection around the shared database. Database work done
    // through `Session::db` is therefore serialized per connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // The same pool the `Server` holds (cloned from `Server::connection_pool`).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // Copied from `AppState::live_kit_client`; `None` means no LiveKit client is available.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
 110
 111impl Session {
 112    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        let guard = self.db.lock().await;
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        guard
 119    }
 120
 121    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 122        #[cfg(test)]
 123        tokio::task::yield_now().await;
 124        let guard = self.connection_pool.lock();
 125        ConnectionPoolGuard {
 126            guard,
 127            _not_send: PhantomData,
 128        }
 129    }
 130}
 131
 132impl fmt::Debug for Session {
 133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 134        f.debug_struct("Session")
 135            .field("user_id", &self.user_id)
 136            .field("connection_id", &self.connection_id)
 137            .finish()
 138    }
 139}
 140
 141struct DbHandle(Arc<Database>);
 142
 143impl Deref for DbHandle {
 144    type Target = Database;
 145
 146    fn deref(&self) -> &Self::Target {
 147        self.0.as_ref()
 148    }
 149}
 150
/// The collaboration RPC server.
///
/// It owns the peer that exchanges messages with clients, the pool of live connections, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Kept behind a mutex because the test-only `reset` replaces it at runtime.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of the message type each handler accepts (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown`. Each connection task subscribes in `handle_connection` and
    // returns when it fires.
    teardown: watch::Sender<()>,
}
 160
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A `Send` future, such as the
/// message-handler futures, therefore cannot keep it alive across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 165
/// Serializable view of the server state, produced by `Server::snapshot`.
///
/// The snapshot holds the connection-pool lock for as long as it is alive. The pool is
/// serialized as its dereferenced value via `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 172
 173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 174where
 175    S: Serializer,
 176    T: Deref<Target = U>,
 177    U: Serialize,
 178{
 179    Serialize::serialize(value.deref(), serializer)
 180}
 181
 182impl Server {
    /// Creates a server with the given id and registers every RPC handler it serves.
    ///
    /// Handlers added with `add_request_handler` must send a response. Handlers added with
    /// `add_message_handler` handle one-way messages. Registering two handlers for the same
    /// message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates silently if `id.0` is wider than u32 — confirm
            // the inner type of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connection tasks subscribe to it in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects, worktrees and language servers.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests handled by the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 283
    /// Starts the server's background maintenance and returns immediately.
    ///
    /// It spawns a detached task that waits for `CLEANUP_TIMEOUT`. The task then asks the
    /// database which rooms and channel buffers are stale for this environment and server id,
    /// and cleans them up:
    /// - For each stale channel buffer, it clears stale collaborators and sends the refreshed
    ///   collaborator list (`UpdateChannelBufferCollaborators`) to the remaining connections.
    /// - For each stale room, it:
    ///   - clears stale participants;
    ///   - calls `room_updated`, plus `channel_updated` when the room belongs to a channel;
    ///   - sends `CallCanceled` to the recipients of cancelled calls;
    ///   - pushes refreshed contact entries for every affected user to their accepted contacts;
    ///   - deletes the LiveKit room if no participants remain.
    /// - Finally, it calls `delete_stale_servers`.
    ///
    /// Every fallible step goes through `trace_err()`. A failure only skips the work that
    /// depends on that step; the rest of the cleanup still runs. The returned `Result` is
    /// always `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false if clearing the room fails, which makes the
                        // follow-up steps below no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the parking_lot guard is released before the `.await`s
                        // below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all connections
                        // of their accepted contacts. A user is skipped if either lookup fails.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // room ended up with no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if retrieving the stale resources failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 435
 436    pub fn teardown(&self) {
 437        self.peer.teardown();
 438        self.connection_pool.lock().reset();
 439        let _ = self.teardown.send(());
 440    }
 441
 442    #[cfg(test)]
 443    pub fn reset(&self, id: ServerId) {
 444        self.teardown();
 445        *self.id.lock() = id;
 446        self.peer.reset(id.0 as u32);
 447    }
 448
 449    #[cfg(test)]
 450    pub fn id(&self) -> ServerId {
 451        *self.id.lock()
 452    }
 453
 454    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 455    where
 456        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 457        Fut: 'static + Send + Future<Output = Result<()>>,
 458        M: EnvelopedMessage,
 459    {
 460        let prev_handler = self.handlers.insert(
 461            TypeId::of::<M>(),
 462            Box::new(move |envelope, session| {
 463                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 464                let span = info_span!(
 465                    "handle message",
 466                    payload_type = envelope.payload_type_name()
 467                );
 468                span.in_scope(|| {
 469                    tracing::info!(
 470                        payload_type = envelope.payload_type_name(),
 471                        "message received"
 472                    );
 473                });
 474                let start_time = Instant::now();
 475                let future = (handler)(*envelope, session);
 476                async move {
 477                    let result = future.await;
 478                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 479                    match result {
 480                        Err(error) => {
 481                            tracing::error!(%error, ?duration_ms, "error handling message")
 482                        }
 483                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 484                    }
 485                }
 486                .instrument(span)
 487                .boxed()
 488            }),
 489        );
 490        if prev_handler.is_some() {
 491            panic!("registered a handler for the same message twice");
 492        }
 493        self
 494    }
 495
 496    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 497    where
 498        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 499        Fut: 'static + Send + Future<Output = Result<()>>,
 500        M: EnvelopedMessage,
 501    {
 502        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 503        self
 504    }
 505
 506    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 507    where
 508        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 509        Fut: Send + Future<Output = Result<()>>,
 510        M: RequestMessage,
 511    {
 512        let handler = Arc::new(handler);
 513        self.add_handler(move |envelope, session| {
 514            let receipt = envelope.receipt();
 515            let handler = handler.clone();
 516            async move {
 517                let peer = session.peer.clone();
 518                let responded = Arc::new(AtomicBool::default());
 519                let response = Response {
 520                    peer: peer.clone(),
 521                    responded: responded.clone(),
 522                    receipt,
 523                };
 524                match (handler)(envelope.payload, response, session).await {
 525                    Ok(()) => {
 526                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 527                            Ok(())
 528                        } else {
 529                            Err(anyhow!("handler did not send a response"))?
 530                        }
 531                    }
 532                    Err(error) => {
 533                        peer.respond_with_error(
 534                            receipt,
 535                            proto::Error {
 536                                message: error.to_string(),
 537                            },
 538                        )?;
 539                        Err(error)
 540                    }
 541                }
 542            }
 543        })
 544    }
 545
    /// Serves a single authenticated client connection.
    ///
    /// The returned future, instrumented with a "handle connection" span, runs these steps:
    /// 1. Registers `connection` with the peer and sends `Hello` with the new connection id.
    /// 2. Reports that id through `send_connection_id`, if one was supplied.
    /// 3. The first time this user ever connects (`!user.connected_once`), sends
    ///    `ShowContacts` and records `connected_once`.
    /// 4. Adds the connection to the pool and sends the initial contacts and channels updates,
    ///    followed by any pending incoming call.
    /// 5. Builds the connection's `Session` and runs `update_user_contacts`. It then dispatches
    ///    incoming messages to the registered handlers until the I/O future ends or the
    ///    incoming stream closes.
    /// 6. Drops unfinished foreground handlers and runs `connection_lost`.
    ///
    /// If the server is torn down (see `teardown`), the future returns `Ok(())` immediately and
    /// skips step 6. Errors from steps 1–5 are propagated with `?`.
    ///
    /// NOTE(review): a `?` failure after `pool.add_connection`, e.g. in
    /// `incoming_call_for_user` or `update_user_contacts`, returns before `connection_lost`.
    /// This looks like it leaves the connection registered in the pool and the peer. Confirm
    /// whether something else cleans this up.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribing here means any `teardown` sent after this call wakes the loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error for this connection.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three lookups run concurrently; the first error aborts all of them.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is held while the contacts update is built from it and is released
            // before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                // A fresh mutex per connection, so this session's database work through
                // `Session::db` is serialized.
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is acquired before
            // each message is read and released when that message's handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is never closed here, so acquiring cannot fail.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks the arms in order: server teardown first, then
                // connection I/O, then progress on foreground handlers, and only then the next
                // incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span before moving it into the handler future below.
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that have not finished yet.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 680
 681    pub async fn invite_code_redeemed(
 682        self: &Arc<Self>,
 683        inviter_id: UserId,
 684        invitee_id: UserId,
 685    ) -> Result<()> {
 686        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 687            if let Some(code) = &user.invite_code {
 688                let pool = self.connection_pool.lock();
 689                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 690                for connection_id in pool.user_connection_ids(inviter_id) {
 691                    self.peer.send(
 692                        connection_id,
 693                        proto::UpdateContacts {
 694                            contacts: vec![invitee_contact.clone()],
 695                            ..Default::default()
 696                        },
 697                    )?;
 698                    self.peer.send(
 699                        connection_id,
 700                        proto::UpdateInviteInfo {
 701                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 702                            count: user.invite_count as u32,
 703                        },
 704                    )?;
 705                }
 706            }
 707        }
 708        Ok(())
 709    }
 710
 711    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 712        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 713            if let Some(invite_code) = &user.invite_code {
 714                let pool = self.connection_pool.lock();
 715                for connection_id in pool.user_connection_ids(user_id) {
 716                    self.peer.send(
 717                        connection_id,
 718                        proto::UpdateInviteInfo {
 719                            url: format!(
 720                                "{}{}",
 721                                self.app_state.config.invite_link_prefix, invite_code
 722                            ),
 723                            count: user.invite_count as u32,
 724                        },
 725                    )?;
 726                }
 727            }
 728        }
 729        Ok(())
 730    }
 731
 732    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 733        ServerSnapshot {
 734            connection_pool: ConnectionPoolGuard {
 735                guard: self.connection_pool.lock(),
 736                _not_send: PhantomData,
 737            },
 738            peer: &self.peer,
 739        }
 740    }
 741}
 742
 743impl<'a> Deref for ConnectionPoolGuard<'a> {
 744    type Target = ConnectionPool;
 745
 746    fn deref(&self) -> &Self::Target {
 747        &*self.guard
 748    }
 749}
 750
 751impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 752    fn deref_mut(&mut self) -> &mut Self::Target {
 753        &mut *self.guard
 754    }
 755}
 756
/// Test-only invariant checking when a connection-pool guard goes away.
///
/// `drop` runs before the guard's fields are dropped, so in test builds
/// `check_invariants` is called while the pool lock is still held. In non-test
/// builds this impl does nothing beyond the default field drops.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled into `cfg(test)` builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
 763
 764fn broadcast<F>(
 765    sender_id: Option<ConnectionId>,
 766    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 767    mut f: F,
 768) where
 769    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 770{
 771    for receiver_id in receiver_ids {
 772        if Some(receiver_id) != sender_id {
 773            if let Err(error) = f(receiver_id) {
 774                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 775            }
 776        }
 777    }
 778}
 779
lazy_static! {
    /// Name of the request header in which clients report their RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed header carrying the client's RPC protocol version, read from the
/// `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
 785
 786impl Header for ProtocolVersion {
 787    fn name() -> &'static HeaderName {
 788        &ZED_PROTOCOL_VERSION
 789    }
 790
 791    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 792    where
 793        Self: Sized,
 794        I: Iterator<Item = &'i axum::http::HeaderValue>,
 795    {
 796        let version = values
 797            .next()
 798            .ok_or_else(axum::headers::Error::invalid)?
 799            .to_str()
 800            .map_err(|_| axum::headers::Error::invalid())?
 801            .parse()
 802            .map_err(|_| axum::headers::Error::invalid())?;
 803        Ok(Self(version))
 804    }
 805
 806    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 807        values.extend([self.0.to_string().parse().unwrap()]);
 808    }
 809}
 810
/// Builds the collab server's HTTP router.
///
/// * `GET /rpc` upgrades to the RPC websocket (`handle_websocket_request`).
/// * `GET /metrics` renders Prometheus metrics (`handle_metrics`).
///
/// Axum's `Router::layer` only wraps routes registered before the call.
/// * The `AppState` extension and the `auth::validate_header` middleware therefore
///   apply to `/rpc` only.
/// * The `Arc<Server>` extension, added last, applies to both routes.
///
/// Inside the `ServiceBuilder`, layers run top-down, so the `AppState` extension is
/// inserted before `auth::validate_header` runs.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Wraps only the routes registered above (i.e. `/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer, so `/metrics` is not behind it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 822
/// Axum handler for `GET /rpc`: upgrades the request to the RPC websocket.
///
/// If the client's `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION`, it responds with `426 Upgrade Required` and the body
/// "client must be upgraded".
///
/// Otherwise it:
/// * upgrades the socket;
/// * converts incoming axum messages with `to_tungstenite_message` and outgoing
///   messages with `to_axum_message`;
/// * wraps the result in a `Connection` and serves it with
///   `Server::handle_connection` on `Executor::Production`.
///
/// An error returned by `handle_connection` is logged (`log_err`) rather than
/// propagated.
///
/// NOTE(review): the `Extension<User>` is presumably inserted by
/// `auth::validate_header` (see `routes`); confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients speaking a different protocol version before upgrading.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    // `handle_connection` receives the peer address as a string.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Incoming frames become tungstenite messages; outgoing tungstenite messages
        // are converted back to axum messages before being written to the socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 853
 854pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 855    let connections = server
 856        .connection_pool
 857        .lock()
 858        .connections()
 859        .filter(|connection| !connection.admin)
 860        .count();
 861
 862    METRIC_CONNECTIONS.set(connections as _);
 863
 864    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 865    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 866
 867    let encoder = prometheus::TextEncoder::new();
 868    let metric_families = prometheus::gather();
 869    let encoded_metrics = encoder
 870        .encode_to_string(&metric_families)
 871        .map_err(|err| anyhow!("{}", err))?;
 872    Ok(encoded_metrics)
 873}
 874
 875#[instrument(err, skip(executor))]
 876async fn connection_lost(
 877    session: Session,
 878    mut teardown: watch::Receiver<()>,
 879    executor: Executor,
 880) -> Result<()> {
 881    session.peer.disconnect(session.connection_id);
 882    session
 883        .connection_pool()
 884        .await
 885        .remove_connection(session.connection_id)?;
 886
 887    session
 888        .db()
 889        .await
 890        .connection_lost(session.connection_id)
 891        .await
 892        .trace_err();
 893
 894    futures::select_biased! {
 895        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 896            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 897            leave_room_for_session(&session).await.trace_err();
 898            leave_channel_buffers_for_session(&session)
 899                .await
 900                .trace_err();
 901
 902            if !session
 903                .connection_pool()
 904                .await
 905                .is_user_online(session.user_id)
 906            {
 907                let db = session.db().await;
 908                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 909                    room_updated(&room, &session.peer);
 910                }
 911            }
 912
 913            update_user_contacts(session.user_id, &session).await?;
 914        }
 915        _ = teardown.changed().fuse() => {}
 916    }
 917
 918    Ok(())
 919}
 920
 921async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 922    response.send(proto::Ack {})?;
 923    Ok(())
 924}
 925
 926async fn create_room(
 927    _request: proto::CreateRoom,
 928    response: Response<proto::CreateRoom>,
 929    session: Session,
 930) -> Result<()> {
 931    let live_kit_room = nanoid::nanoid!(30);
 932
 933    let live_kit_connection_info = {
 934        let live_kit_room = live_kit_room.clone();
 935        let live_kit = session.live_kit_client.as_ref();
 936
 937        util::async_iife!({
 938            let live_kit = live_kit?;
 939
 940            live_kit
 941                .create_room(live_kit_room.clone())
 942                .await
 943                .trace_err()?;
 944
 945            let token = live_kit
 946                .room_token(&live_kit_room, &session.user_id.to_string())
 947                .trace_err()?;
 948
 949            Some(proto::LiveKitConnectionInfo {
 950                server_url: live_kit.url().into(),
 951                token,
 952            })
 953        })
 954    }
 955    .await;
 956
 957    let room = session
 958        .db()
 959        .await
 960        .create_room(session.user_id, session.connection_id, &live_kit_room)
 961        .await?;
 962
 963    response.send(proto::CreateRoomResponse {
 964        room: Some(room.clone()),
 965        live_kit_connection_info,
 966    })?;
 967
 968    update_user_contacts(session.user_id, &session).await?;
 969    Ok(())
 970}
 971
 972async fn join_room(
 973    request: proto::JoinRoom,
 974    response: Response<proto::JoinRoom>,
 975    session: Session,
 976) -> Result<()> {
 977    let room_id = RoomId::from_proto(request.id);
 978    let joined_room = {
 979        let room = session
 980            .db()
 981            .await
 982            .join_room(room_id, session.user_id, session.connection_id)
 983            .await?;
 984        room_updated(&room.room, &session.peer);
 985        room.into_inner()
 986    };
 987
 988    if let Some(channel_id) = joined_room.channel_id {
 989        channel_updated(
 990            channel_id,
 991            &joined_room.room,
 992            &joined_room.channel_members,
 993            &session.peer,
 994            &*session.connection_pool().await,
 995        )
 996    }
 997
 998    for connection_id in session
 999        .connection_pool()
1000        .await
1001        .user_connection_ids(session.user_id)
1002    {
1003        session
1004            .peer
1005            .send(
1006                connection_id,
1007                proto::CallCanceled {
1008                    room_id: room_id.to_proto(),
1009                },
1010            )
1011            .trace_err();
1012    }
1013
1014    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1015        if let Some(token) = live_kit
1016            .room_token(
1017                &joined_room.room.live_kit_room,
1018                &session.user_id.to_string(),
1019            )
1020            .trace_err()
1021        {
1022            Some(proto::LiveKitConnectionInfo {
1023                server_url: live_kit.url().into(),
1024                token,
1025            })
1026        } else {
1027            None
1028        }
1029    } else {
1030        None
1031    };
1032
1033    response.send(proto::JoinRoomResponse {
1034        room: Some(joined_room.room),
1035        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1036        live_kit_connection_info,
1037    })?;
1038
1039    update_user_contacts(session.user_id, &session).await?;
1040    Ok(())
1041}
1042
1043async fn rejoin_room(
1044    request: proto::RejoinRoom,
1045    response: Response<proto::RejoinRoom>,
1046    session: Session,
1047) -> Result<()> {
1048    let room;
1049    let channel_id;
1050    let channel_members;
1051    {
1052        let mut rejoined_room = session
1053            .db()
1054            .await
1055            .rejoin_room(request, session.user_id, session.connection_id)
1056            .await?;
1057
1058        response.send(proto::RejoinRoomResponse {
1059            room: Some(rejoined_room.room.clone()),
1060            reshared_projects: rejoined_room
1061                .reshared_projects
1062                .iter()
1063                .map(|project| proto::ResharedProject {
1064                    id: project.id.to_proto(),
1065                    collaborators: project
1066                        .collaborators
1067                        .iter()
1068                        .map(|collaborator| collaborator.to_proto())
1069                        .collect(),
1070                })
1071                .collect(),
1072            rejoined_projects: rejoined_room
1073                .rejoined_projects
1074                .iter()
1075                .map(|rejoined_project| proto::RejoinedProject {
1076                    id: rejoined_project.id.to_proto(),
1077                    worktrees: rejoined_project
1078                        .worktrees
1079                        .iter()
1080                        .map(|worktree| proto::WorktreeMetadata {
1081                            id: worktree.id,
1082                            root_name: worktree.root_name.clone(),
1083                            visible: worktree.visible,
1084                            abs_path: worktree.abs_path.clone(),
1085                        })
1086                        .collect(),
1087                    collaborators: rejoined_project
1088                        .collaborators
1089                        .iter()
1090                        .map(|collaborator| collaborator.to_proto())
1091                        .collect(),
1092                    language_servers: rejoined_project.language_servers.clone(),
1093                })
1094                .collect(),
1095        })?;
1096        room_updated(&rejoined_room.room, &session.peer);
1097
1098        for project in &rejoined_room.reshared_projects {
1099            for collaborator in &project.collaborators {
1100                session
1101                    .peer
1102                    .send(
1103                        collaborator.connection_id,
1104                        proto::UpdateProjectCollaborator {
1105                            project_id: project.id.to_proto(),
1106                            old_peer_id: Some(project.old_connection_id.into()),
1107                            new_peer_id: Some(session.connection_id.into()),
1108                        },
1109                    )
1110                    .trace_err();
1111            }
1112
1113            broadcast(
1114                Some(session.connection_id),
1115                project
1116                    .collaborators
1117                    .iter()
1118                    .map(|collaborator| collaborator.connection_id),
1119                |connection_id| {
1120                    session.peer.forward_send(
1121                        session.connection_id,
1122                        connection_id,
1123                        proto::UpdateProject {
1124                            project_id: project.id.to_proto(),
1125                            worktrees: project.worktrees.clone(),
1126                        },
1127                    )
1128                },
1129            );
1130        }
1131
1132        for project in &rejoined_room.rejoined_projects {
1133            for collaborator in &project.collaborators {
1134                session
1135                    .peer
1136                    .send(
1137                        collaborator.connection_id,
1138                        proto::UpdateProjectCollaborator {
1139                            project_id: project.id.to_proto(),
1140                            old_peer_id: Some(project.old_connection_id.into()),
1141                            new_peer_id: Some(session.connection_id.into()),
1142                        },
1143                    )
1144                    .trace_err();
1145            }
1146        }
1147
1148        for project in &mut rejoined_room.rejoined_projects {
1149            for worktree in mem::take(&mut project.worktrees) {
1150                #[cfg(any(test, feature = "test-support"))]
1151                const MAX_CHUNK_SIZE: usize = 2;
1152                #[cfg(not(any(test, feature = "test-support")))]
1153                const MAX_CHUNK_SIZE: usize = 256;
1154
1155                // Stream this worktree's entries.
1156                let message = proto::UpdateWorktree {
1157                    project_id: project.id.to_proto(),
1158                    worktree_id: worktree.id,
1159                    abs_path: worktree.abs_path.clone(),
1160                    root_name: worktree.root_name,
1161                    updated_entries: worktree.updated_entries,
1162                    removed_entries: worktree.removed_entries,
1163                    scan_id: worktree.scan_id,
1164                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1165                    updated_repositories: worktree.updated_repositories,
1166                    removed_repositories: worktree.removed_repositories,
1167                };
1168                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1169                    session.peer.send(session.connection_id, update.clone())?;
1170                }
1171
1172                // Stream this worktree's diagnostics.
1173                for summary in worktree.diagnostic_summaries {
1174                    session.peer.send(
1175                        session.connection_id,
1176                        proto::UpdateDiagnosticSummary {
1177                            project_id: project.id.to_proto(),
1178                            worktree_id: worktree.id,
1179                            summary: Some(summary),
1180                        },
1181                    )?;
1182                }
1183
1184                for settings_file in worktree.settings_files {
1185                    session.peer.send(
1186                        session.connection_id,
1187                        proto::UpdateWorktreeSettings {
1188                            project_id: project.id.to_proto(),
1189                            worktree_id: worktree.id,
1190                            path: settings_file.path,
1191                            content: Some(settings_file.content),
1192                        },
1193                    )?;
1194                }
1195            }
1196
1197            for language_server in &project.language_servers {
1198                session.peer.send(
1199                    session.connection_id,
1200                    proto::UpdateLanguageServer {
1201                        project_id: project.id.to_proto(),
1202                        language_server_id: language_server.id,
1203                        variant: Some(
1204                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1205                                proto::LspDiskBasedDiagnosticsUpdated {},
1206                            ),
1207                        ),
1208                    },
1209                )?;
1210            }
1211        }
1212
1213        let rejoined_room = rejoined_room.into_inner();
1214
1215        room = rejoined_room.room;
1216        channel_id = rejoined_room.channel_id;
1217        channel_members = rejoined_room.channel_members;
1218    }
1219
1220    if let Some(channel_id) = channel_id {
1221        channel_updated(
1222            channel_id,
1223            &room,
1224            &channel_members,
1225            &session.peer,
1226            &*session.connection_pool().await,
1227        );
1228    }
1229
1230    update_user_contacts(session.user_id, &session).await?;
1231    Ok(())
1232}
1233
1234async fn leave_room(
1235    _: proto::LeaveRoom,
1236    response: Response<proto::LeaveRoom>,
1237    session: Session,
1238) -> Result<()> {
1239    leave_room_for_session(&session).await?;
1240    response.send(proto::Ack {})?;
1241    Ok(())
1242}
1243
/// Handles `Call`: rings `called_user_id` into the caller's room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database (with the optional `initial_project_id`), the room is broadcast, and
/// the callee's contacts are updated.
///
/// Every connection of the callee is then sent the incoming call concurrently.
/// The first connection that answers successfully completes the request with an
/// `Ack`.
///
/// If no connection answers successfully, including when the callee has no
/// connections:
/// * the call is marked failed in the database;
/// * the room is broadcast again;
/// * the callee's contacts are updated;
/// * an error is returned.
///
/// # Errors
///
/// Fails if the callee is not a contact, on database errors, or when none of the
/// callee's connections could be rung.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the guarded database result is dropped before the callee's
    // contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call message out of the guarded result.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Start one request per callee connection. The pool guard is a temporary of
    // this statement; only the request futures outlive it.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Requests complete in whatever order they finish; the first success wins.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Trace the failed attempt and keep waiting on the others.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure and broadcast it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1311
1312async fn cancel_call(
1313    request: proto::CancelCall,
1314    response: Response<proto::CancelCall>,
1315    session: Session,
1316) -> Result<()> {
1317    let called_user_id = UserId::from_proto(request.called_user_id);
1318    let room_id = RoomId::from_proto(request.room_id);
1319    {
1320        let room = session
1321            .db()
1322            .await
1323            .cancel_call(room_id, session.connection_id, called_user_id)
1324            .await?;
1325        room_updated(&room, &session.peer);
1326    }
1327
1328    for connection_id in session
1329        .connection_pool()
1330        .await
1331        .user_connection_ids(called_user_id)
1332    {
1333        session
1334            .peer
1335            .send(
1336                connection_id,
1337                proto::CallCanceled {
1338                    room_id: room_id.to_proto(),
1339                },
1340            )
1341            .trace_err();
1342    }
1343    response.send(proto::Ack {})?;
1344
1345    update_user_contacts(called_user_id, &session).await?;
1346    Ok(())
1347}
1348
1349async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1350    let room_id = RoomId::from_proto(message.room_id);
1351    {
1352        let room = session
1353            .db()
1354            .await
1355            .decline_call(Some(room_id), session.user_id)
1356            .await?
1357            .ok_or_else(|| anyhow!("failed to decline call"))?;
1358        room_updated(&room, &session.peer);
1359    }
1360
1361    for connection_id in session
1362        .connection_pool()
1363        .await
1364        .user_connection_ids(session.user_id)
1365    {
1366        session
1367            .peer
1368            .send(
1369                connection_id,
1370                proto::CallCanceled {
1371                    room_id: room_id.to_proto(),
1372                },
1373            )
1374            .trace_err();
1375    }
1376    update_user_contacts(session.user_id, &session).await?;
1377    Ok(())
1378}
1379
1380async fn update_participant_location(
1381    request: proto::UpdateParticipantLocation,
1382    response: Response<proto::UpdateParticipantLocation>,
1383    session: Session,
1384) -> Result<()> {
1385    let room_id = RoomId::from_proto(request.room_id);
1386    let location = request
1387        .location
1388        .ok_or_else(|| anyhow!("invalid location"))?;
1389
1390    let db = session.db().await;
1391    let room = db
1392        .update_room_participant_location(room_id, session.connection_id, location)
1393        .await?;
1394
1395    room_updated(&room, &session.peer);
1396    response.send(proto::Ack {})?;
1397    Ok(())
1398}
1399
1400async fn share_project(
1401    request: proto::ShareProject,
1402    response: Response<proto::ShareProject>,
1403    session: Session,
1404) -> Result<()> {
1405    let (project_id, room) = &*session
1406        .db()
1407        .await
1408        .share_project(
1409            RoomId::from_proto(request.room_id),
1410            session.connection_id,
1411            &request.worktrees,
1412        )
1413        .await?;
1414    response.send(proto::ShareProjectResponse {
1415        project_id: project_id.to_proto(),
1416    })?;
1417    room_updated(&room, &session.peer);
1418
1419    Ok(())
1420}
1421
1422async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1423    let project_id = ProjectId::from_proto(message.project_id);
1424
1425    let (room, guest_connection_ids) = &*session
1426        .db()
1427        .await
1428        .unshare_project(project_id, session.connection_id)
1429        .await?;
1430
1431    broadcast(
1432        Some(session.connection_id),
1433        guest_connection_ids.iter().copied(),
1434        |conn_id| session.peer.send(conn_id, message.clone()),
1435    );
1436    room_updated(&room, &session.peer);
1437
1438    Ok(())
1439}
1440
1441async fn join_project(
1442    request: proto::JoinProject,
1443    response: Response<proto::JoinProject>,
1444    session: Session,
1445) -> Result<()> {
1446    let project_id = ProjectId::from_proto(request.project_id);
1447    let guest_user_id = session.user_id;
1448
1449    tracing::info!(%project_id, "join project");
1450
1451    let (project, replica_id) = &mut *session
1452        .db()
1453        .await
1454        .join_project(project_id, session.connection_id)
1455        .await?;
1456
1457    let collaborators = project
1458        .collaborators
1459        .iter()
1460        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1461        .map(|collaborator| collaborator.to_proto())
1462        .collect::<Vec<_>>();
1463
1464    let worktrees = project
1465        .worktrees
1466        .iter()
1467        .map(|(id, worktree)| proto::WorktreeMetadata {
1468            id: *id,
1469            root_name: worktree.root_name.clone(),
1470            visible: worktree.visible,
1471            abs_path: worktree.abs_path.clone(),
1472        })
1473        .collect::<Vec<_>>();
1474
1475    for collaborator in &collaborators {
1476        session
1477            .peer
1478            .send(
1479                collaborator.peer_id.unwrap().into(),
1480                proto::AddProjectCollaborator {
1481                    project_id: project_id.to_proto(),
1482                    collaborator: Some(proto::Collaborator {
1483                        peer_id: Some(session.connection_id.into()),
1484                        replica_id: replica_id.0 as u32,
1485                        user_id: guest_user_id.to_proto(),
1486                    }),
1487                },
1488            )
1489            .trace_err();
1490    }
1491
1492    // First, we send the metadata associated with each worktree.
1493    response.send(proto::JoinProjectResponse {
1494        worktrees: worktrees.clone(),
1495        replica_id: replica_id.0 as u32,
1496        collaborators: collaborators.clone(),
1497        language_servers: project.language_servers.clone(),
1498    })?;
1499
1500    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1501        #[cfg(any(test, feature = "test-support"))]
1502        const MAX_CHUNK_SIZE: usize = 2;
1503        #[cfg(not(any(test, feature = "test-support")))]
1504        const MAX_CHUNK_SIZE: usize = 256;
1505
1506        // Stream this worktree's entries.
1507        let message = proto::UpdateWorktree {
1508            project_id: project_id.to_proto(),
1509            worktree_id,
1510            abs_path: worktree.abs_path.clone(),
1511            root_name: worktree.root_name,
1512            updated_entries: worktree.entries,
1513            removed_entries: Default::default(),
1514            scan_id: worktree.scan_id,
1515            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1516            updated_repositories: worktree.repository_entries.into_values().collect(),
1517            removed_repositories: Default::default(),
1518        };
1519        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1520            session.peer.send(session.connection_id, update.clone())?;
1521        }
1522
1523        // Stream this worktree's diagnostics.
1524        for summary in worktree.diagnostic_summaries {
1525            session.peer.send(
1526                session.connection_id,
1527                proto::UpdateDiagnosticSummary {
1528                    project_id: project_id.to_proto(),
1529                    worktree_id: worktree.id,
1530                    summary: Some(summary),
1531                },
1532            )?;
1533        }
1534
1535        for settings_file in worktree.settings_files {
1536            session.peer.send(
1537                session.connection_id,
1538                proto::UpdateWorktreeSettings {
1539                    project_id: project_id.to_proto(),
1540                    worktree_id: worktree.id,
1541                    path: settings_file.path,
1542                    content: Some(settings_file.content),
1543                },
1544            )?;
1545        }
1546    }
1547
1548    for language_server in &project.language_servers {
1549        session.peer.send(
1550            session.connection_id,
1551            proto::UpdateLanguageServer {
1552                project_id: project_id.to_proto(),
1553                language_server_id: language_server.id,
1554                variant: Some(
1555                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1556                        proto::LspDiskBasedDiagnosticsUpdated {},
1557                    ),
1558                ),
1559            },
1560        )?;
1561    }
1562
1563    Ok(())
1564}
1565
1566async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1567    let sender_id = session.connection_id;
1568    let project_id = ProjectId::from_proto(request.project_id);
1569
1570    let (room, project) = &*session
1571        .db()
1572        .await
1573        .leave_project(project_id, sender_id)
1574        .await?;
1575    tracing::info!(
1576        %project_id,
1577        host_user_id = %project.host_user_id,
1578        host_connection_id = %project.host_connection_id,
1579        "leave project"
1580    );
1581
1582    project_left(&project, &session);
1583    room_updated(&room, &session.peer);
1584
1585    Ok(())
1586}
1587
1588async fn update_project(
1589    request: proto::UpdateProject,
1590    response: Response<proto::UpdateProject>,
1591    session: Session,
1592) -> Result<()> {
1593    let project_id = ProjectId::from_proto(request.project_id);
1594    let (room, guest_connection_ids) = &*session
1595        .db()
1596        .await
1597        .update_project(project_id, session.connection_id, &request.worktrees)
1598        .await?;
1599    broadcast(
1600        Some(session.connection_id),
1601        guest_connection_ids.iter().copied(),
1602        |connection_id| {
1603            session
1604                .peer
1605                .forward_send(session.connection_id, connection_id, request.clone())
1606        },
1607    );
1608    room_updated(&room, &session.peer);
1609    response.send(proto::Ack {})?;
1610
1611    Ok(())
1612}
1613
1614async fn update_worktree(
1615    request: proto::UpdateWorktree,
1616    response: Response<proto::UpdateWorktree>,
1617    session: Session,
1618) -> Result<()> {
1619    let guest_connection_ids = session
1620        .db()
1621        .await
1622        .update_worktree(&request, session.connection_id)
1623        .await?;
1624
1625    broadcast(
1626        Some(session.connection_id),
1627        guest_connection_ids.iter().copied(),
1628        |connection_id| {
1629            session
1630                .peer
1631                .forward_send(session.connection_id, connection_id, request.clone())
1632        },
1633    );
1634    response.send(proto::Ack {})?;
1635    Ok(())
1636}
1637
1638async fn update_diagnostic_summary(
1639    message: proto::UpdateDiagnosticSummary,
1640    session: Session,
1641) -> Result<()> {
1642    let guest_connection_ids = session
1643        .db()
1644        .await
1645        .update_diagnostic_summary(&message, session.connection_id)
1646        .await?;
1647
1648    broadcast(
1649        Some(session.connection_id),
1650        guest_connection_ids.iter().copied(),
1651        |connection_id| {
1652            session
1653                .peer
1654                .forward_send(session.connection_id, connection_id, message.clone())
1655        },
1656    );
1657
1658    Ok(())
1659}
1660
1661async fn update_worktree_settings(
1662    message: proto::UpdateWorktreeSettings,
1663    session: Session,
1664) -> Result<()> {
1665    let guest_connection_ids = session
1666        .db()
1667        .await
1668        .update_worktree_settings(&message, session.connection_id)
1669        .await?;
1670
1671    broadcast(
1672        Some(session.connection_id),
1673        guest_connection_ids.iter().copied(),
1674        |connection_id| {
1675            session
1676                .peer
1677                .forward_send(session.connection_id, connection_id, message.clone())
1678        },
1679    );
1680
1681    Ok(())
1682}
1683
1684async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1685    broadcast_project_message(request.project_id, request, session).await
1686}
1687
1688async fn start_language_server(
1689    request: proto::StartLanguageServer,
1690    session: Session,
1691) -> Result<()> {
1692    let guest_connection_ids = session
1693        .db()
1694        .await
1695        .start_language_server(&request, session.connection_id)
1696        .await?;
1697
1698    broadcast(
1699        Some(session.connection_id),
1700        guest_connection_ids.iter().copied(),
1701        |connection_id| {
1702            session
1703                .peer
1704                .forward_send(session.connection_id, connection_id, request.clone())
1705        },
1706    );
1707    Ok(())
1708}
1709
1710async fn update_language_server(
1711    request: proto::UpdateLanguageServer,
1712    session: Session,
1713) -> Result<()> {
1714    session.executor.record_backtrace();
1715    let project_id = ProjectId::from_proto(request.project_id);
1716    let project_connection_ids = session
1717        .db()
1718        .await
1719        .project_connection_ids(project_id, session.connection_id)
1720        .await?;
1721    broadcast(
1722        Some(session.connection_id),
1723        project_connection_ids.iter().copied(),
1724        |connection_id| {
1725            session
1726                .peer
1727                .forward_send(session.connection_id, connection_id, request.clone())
1728        },
1729    );
1730    Ok(())
1731}
1732
1733async fn forward_project_request<T>(
1734    request: T,
1735    response: Response<T>,
1736    session: Session,
1737) -> Result<()>
1738where
1739    T: EntityMessage + RequestMessage,
1740{
1741    session.executor.record_backtrace();
1742    let project_id = ProjectId::from_proto(request.remote_entity_id());
1743    let host_connection_id = {
1744        let collaborators = session
1745            .db()
1746            .await
1747            .project_collaborators(project_id, session.connection_id)
1748            .await?;
1749        collaborators
1750            .iter()
1751            .find(|collaborator| collaborator.is_host)
1752            .ok_or_else(|| anyhow!("host not found"))?
1753            .connection_id
1754    };
1755
1756    let payload = session
1757        .peer
1758        .forward_request(session.connection_id, host_connection_id, request)
1759        .await?;
1760
1761    response.send(payload)?;
1762    Ok(())
1763}
1764
1765async fn create_buffer_for_peer(
1766    request: proto::CreateBufferForPeer,
1767    session: Session,
1768) -> Result<()> {
1769    session.executor.record_backtrace();
1770    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1771    session
1772        .peer
1773        .forward_send(session.connection_id, peer_id.into(), request)?;
1774    Ok(())
1775}
1776
1777async fn update_buffer(
1778    request: proto::UpdateBuffer,
1779    response: Response<proto::UpdateBuffer>,
1780    session: Session,
1781) -> Result<()> {
1782    session.executor.record_backtrace();
1783    let project_id = ProjectId::from_proto(request.project_id);
1784    let mut guest_connection_ids;
1785    let mut host_connection_id = None;
1786    {
1787        let collaborators = session
1788            .db()
1789            .await
1790            .project_collaborators(project_id, session.connection_id)
1791            .await?;
1792        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1793        for collaborator in collaborators.iter() {
1794            if collaborator.is_host {
1795                host_connection_id = Some(collaborator.connection_id);
1796            } else {
1797                guest_connection_ids.push(collaborator.connection_id);
1798            }
1799        }
1800    }
1801    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1802
1803    session.executor.record_backtrace();
1804    broadcast(
1805        Some(session.connection_id),
1806        guest_connection_ids,
1807        |connection_id| {
1808            session
1809                .peer
1810                .forward_send(session.connection_id, connection_id, request.clone())
1811        },
1812    );
1813    if host_connection_id != session.connection_id {
1814        session
1815            .peer
1816            .forward_request(session.connection_id, host_connection_id, request.clone())
1817            .await?;
1818    }
1819
1820    response.send(proto::Ack {})?;
1821    Ok(())
1822}
1823
1824async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1825    let project_id = ProjectId::from_proto(request.project_id);
1826    let project_connection_ids = session
1827        .db()
1828        .await
1829        .project_connection_ids(project_id, session.connection_id)
1830        .await?;
1831
1832    broadcast(
1833        Some(session.connection_id),
1834        project_connection_ids.iter().copied(),
1835        |connection_id| {
1836            session
1837                .peer
1838                .forward_send(session.connection_id, connection_id, request.clone())
1839        },
1840    );
1841    Ok(())
1842}
1843
1844async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1845    let project_id = ProjectId::from_proto(request.project_id);
1846    let project_connection_ids = session
1847        .db()
1848        .await
1849        .project_connection_ids(project_id, session.connection_id)
1850        .await?;
1851    broadcast(
1852        Some(session.connection_id),
1853        project_connection_ids.iter().copied(),
1854        |connection_id| {
1855            session
1856                .peer
1857                .forward_send(session.connection_id, connection_id, request.clone())
1858        },
1859    );
1860    Ok(())
1861}
1862
1863async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1864    broadcast_project_message(request.project_id, request, session).await
1865}
1866
1867async fn broadcast_project_message<T: EnvelopedMessage>(
1868    project_id: u64,
1869    request: T,
1870    session: Session,
1871) -> Result<()> {
1872    let project_id = ProjectId::from_proto(project_id);
1873    let project_connection_ids = session
1874        .db()
1875        .await
1876        .project_connection_ids(project_id, session.connection_id)
1877        .await?;
1878    broadcast(
1879        Some(session.connection_id),
1880        project_connection_ids.iter().copied(),
1881        |connection_id| {
1882            session
1883                .peer
1884                .forward_send(session.connection_id, connection_id, request.clone())
1885        },
1886    );
1887    Ok(())
1888}
1889
1890async fn follow(
1891    request: proto::Follow,
1892    response: Response<proto::Follow>,
1893    session: Session,
1894) -> Result<()> {
1895    let room_id = RoomId::from_proto(request.room_id);
1896    let project_id = request.project_id.map(ProjectId::from_proto);
1897    let leader_id = request
1898        .leader_id
1899        .ok_or_else(|| anyhow!("invalid leader id"))?
1900        .into();
1901    let follower_id = session.connection_id;
1902
1903    session
1904        .db()
1905        .await
1906        .check_room_participants(room_id, leader_id, session.connection_id)
1907        .await?;
1908
1909    let mut response_payload = session
1910        .peer
1911        .forward_request(session.connection_id, leader_id, request)
1912        .await?;
1913    response_payload
1914        .views
1915        .retain(|view| view.leader_id != Some(follower_id.into()));
1916    response.send(response_payload)?;
1917
1918    if let Some(project_id) = project_id {
1919        let room = session
1920            .db()
1921            .await
1922            .follow(room_id, project_id, leader_id, follower_id)
1923            .await?;
1924        room_updated(&room, &session.peer);
1925    }
1926
1927    Ok(())
1928}
1929
1930async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1931    let room_id = RoomId::from_proto(request.room_id);
1932    let project_id = request.project_id.map(ProjectId::from_proto);
1933    let leader_id = request
1934        .leader_id
1935        .ok_or_else(|| anyhow!("invalid leader id"))?
1936        .into();
1937    let follower_id = session.connection_id;
1938
1939    session
1940        .db()
1941        .await
1942        .check_room_participants(room_id, leader_id, session.connection_id)
1943        .await?;
1944
1945    session
1946        .peer
1947        .forward_send(session.connection_id, leader_id, request)?;
1948
1949    if let Some(project_id) = project_id {
1950        let room = session
1951            .db()
1952            .await
1953            .unfollow(room_id, project_id, leader_id, follower_id)
1954            .await?;
1955        room_updated(&room, &session.peer);
1956    }
1957
1958    Ok(())
1959}
1960
1961async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1962    let room_id = RoomId::from_proto(request.room_id);
1963    let database = session.db.lock().await;
1964
1965    let connection_ids = if let Some(project_id) = request.project_id {
1966        let project_id = ProjectId::from_proto(project_id);
1967        database
1968            .project_connection_ids(project_id, session.connection_id)
1969            .await?
1970    } else {
1971        database
1972            .room_connection_ids(room_id, session.connection_id)
1973            .await?
1974    };
1975
1976    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1977        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1978        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1979        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1980    });
1981    for follower_peer_id in request.follower_ids.iter().copied() {
1982        let follower_connection_id = follower_peer_id.into();
1983        if Some(follower_peer_id) != leader_id && connection_ids.contains(&follower_connection_id) {
1984            session.peer.forward_send(
1985                session.connection_id,
1986                follower_connection_id,
1987                request.clone(),
1988            )?;
1989        }
1990    }
1991    Ok(())
1992}
1993
1994async fn get_users(
1995    request: proto::GetUsers,
1996    response: Response<proto::GetUsers>,
1997    session: Session,
1998) -> Result<()> {
1999    let user_ids = request
2000        .user_ids
2001        .into_iter()
2002        .map(UserId::from_proto)
2003        .collect();
2004    let users = session
2005        .db()
2006        .await
2007        .get_users_by_ids(user_ids)
2008        .await?
2009        .into_iter()
2010        .map(|user| proto::User {
2011            id: user.id.to_proto(),
2012            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2013            github_login: user.github_login,
2014        })
2015        .collect();
2016    response.send(proto::UsersResponse { users })?;
2017    Ok(())
2018}
2019
2020async fn fuzzy_search_users(
2021    request: proto::FuzzySearchUsers,
2022    response: Response<proto::FuzzySearchUsers>,
2023    session: Session,
2024) -> Result<()> {
2025    let query = request.query;
2026    let users = match query.len() {
2027        0 => vec![],
2028        1 | 2 => session
2029            .db()
2030            .await
2031            .get_user_by_github_login(&query)
2032            .await?
2033            .into_iter()
2034            .collect(),
2035        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2036    };
2037    let users = users
2038        .into_iter()
2039        .filter(|user| user.id != session.user_id)
2040        .map(|user| proto::User {
2041            id: user.id.to_proto(),
2042            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2043            github_login: user.github_login,
2044        })
2045        .collect();
2046    response.send(proto::UsersResponse { users })?;
2047    Ok(())
2048}
2049
2050async fn request_contact(
2051    request: proto::RequestContact,
2052    response: Response<proto::RequestContact>,
2053    session: Session,
2054) -> Result<()> {
2055    let requester_id = session.user_id;
2056    let responder_id = UserId::from_proto(request.responder_id);
2057    if requester_id == responder_id {
2058        return Err(anyhow!("cannot add yourself as a contact"))?;
2059    }
2060
2061    session
2062        .db()
2063        .await
2064        .send_contact_request(requester_id, responder_id)
2065        .await?;
2066
2067    // Update outgoing contact requests of requester
2068    let mut update = proto::UpdateContacts::default();
2069    update.outgoing_requests.push(responder_id.to_proto());
2070    for connection_id in session
2071        .connection_pool()
2072        .await
2073        .user_connection_ids(requester_id)
2074    {
2075        session.peer.send(connection_id, update.clone())?;
2076    }
2077
2078    // Update incoming contact requests of responder
2079    let mut update = proto::UpdateContacts::default();
2080    update
2081        .incoming_requests
2082        .push(proto::IncomingContactRequest {
2083            requester_id: requester_id.to_proto(),
2084            should_notify: true,
2085        });
2086    for connection_id in session
2087        .connection_pool()
2088        .await
2089        .user_connection_ids(responder_id)
2090    {
2091        session.peer.send(connection_id, update.clone())?;
2092    }
2093
2094    response.send(proto::Ack {})?;
2095    Ok(())
2096}
2097
2098async fn respond_to_contact_request(
2099    request: proto::RespondToContactRequest,
2100    response: Response<proto::RespondToContactRequest>,
2101    session: Session,
2102) -> Result<()> {
2103    let responder_id = session.user_id;
2104    let requester_id = UserId::from_proto(request.requester_id);
2105    let db = session.db().await;
2106    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2107        db.dismiss_contact_notification(responder_id, requester_id)
2108            .await?;
2109    } else {
2110        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2111
2112        db.respond_to_contact_request(responder_id, requester_id, accept)
2113            .await?;
2114        let requester_busy = db.is_user_busy(requester_id).await?;
2115        let responder_busy = db.is_user_busy(responder_id).await?;
2116
2117        let pool = session.connection_pool().await;
2118        // Update responder with new contact
2119        let mut update = proto::UpdateContacts::default();
2120        if accept {
2121            update
2122                .contacts
2123                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2124        }
2125        update
2126            .remove_incoming_requests
2127            .push(requester_id.to_proto());
2128        for connection_id in pool.user_connection_ids(responder_id) {
2129            session.peer.send(connection_id, update.clone())?;
2130        }
2131
2132        // Update requester with new contact
2133        let mut update = proto::UpdateContacts::default();
2134        if accept {
2135            update
2136                .contacts
2137                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2138        }
2139        update
2140            .remove_outgoing_requests
2141            .push(responder_id.to_proto());
2142        for connection_id in pool.user_connection_ids(requester_id) {
2143            session.peer.send(connection_id, update.clone())?;
2144        }
2145    }
2146
2147    response.send(proto::Ack {})?;
2148    Ok(())
2149}
2150
2151async fn remove_contact(
2152    request: proto::RemoveContact,
2153    response: Response<proto::RemoveContact>,
2154    session: Session,
2155) -> Result<()> {
2156    let requester_id = session.user_id;
2157    let responder_id = UserId::from_proto(request.user_id);
2158    let db = session.db().await;
2159    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2160
2161    let pool = session.connection_pool().await;
2162    // Update outgoing contact requests of requester
2163    let mut update = proto::UpdateContacts::default();
2164    if contact_accepted {
2165        update.remove_contacts.push(responder_id.to_proto());
2166    } else {
2167        update
2168            .remove_outgoing_requests
2169            .push(responder_id.to_proto());
2170    }
2171    for connection_id in pool.user_connection_ids(requester_id) {
2172        session.peer.send(connection_id, update.clone())?;
2173    }
2174
2175    // Update incoming contact requests of responder
2176    let mut update = proto::UpdateContacts::default();
2177    if contact_accepted {
2178        update.remove_contacts.push(requester_id.to_proto());
2179    } else {
2180        update
2181            .remove_incoming_requests
2182            .push(requester_id.to_proto());
2183    }
2184    for connection_id in pool.user_connection_ids(responder_id) {
2185        session.peer.send(connection_id, update.clone())?;
2186    }
2187
2188    response.send(proto::Ack {})?;
2189    Ok(())
2190}
2191
2192async fn create_channel(
2193    request: proto::CreateChannel,
2194    response: Response<proto::CreateChannel>,
2195    session: Session,
2196) -> Result<()> {
2197    let db = session.db().await;
2198    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2199
2200    if let Some(live_kit) = session.live_kit_client.as_ref() {
2201        live_kit.create_room(live_kit_room.clone()).await?;
2202    }
2203
2204    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2205    let id = db
2206        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2207        .await?;
2208
2209    let channel = proto::Channel {
2210        id: id.to_proto(),
2211        name: request.name,
2212    };
2213
2214    response.send(proto::CreateChannelResponse {
2215        channel: Some(channel.clone()),
2216        parent_id: request.parent_id,
2217    })?;
2218
2219    let Some(parent_id) = parent_id else {
2220        return Ok(());
2221    };
2222
2223    let update = proto::UpdateChannels {
2224        channels: vec![channel],
2225        insert_edge: vec![ChannelEdge {
2226            parent_id: parent_id.to_proto(),
2227            channel_id: id.to_proto(),
2228        }],
2229        ..Default::default()
2230    };
2231
2232    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2233
2234    let connection_pool = session.connection_pool().await;
2235    for user_id in user_ids_to_notify {
2236        for connection_id in connection_pool.user_connection_ids(user_id) {
2237            if user_id == session.user_id {
2238                continue;
2239            }
2240            session.peer.send(connection_id, update.clone())?;
2241        }
2242    }
2243
2244    Ok(())
2245}
2246
2247async fn delete_channel(
2248    request: proto::DeleteChannel,
2249    response: Response<proto::DeleteChannel>,
2250    session: Session,
2251) -> Result<()> {
2252    let db = session.db().await;
2253
2254    let channel_id = request.channel_id;
2255    let (removed_channels, member_ids) = db
2256        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2257        .await?;
2258    response.send(proto::Ack {})?;
2259
2260    // Notify members of removed channels
2261    let mut update = proto::UpdateChannels::default();
2262    update
2263        .delete_channels
2264        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2265
2266    let connection_pool = session.connection_pool().await;
2267    for member_id in member_ids {
2268        for connection_id in connection_pool.user_connection_ids(member_id) {
2269            session.peer.send(connection_id, update.clone())?;
2270        }
2271    }
2272
2273    Ok(())
2274}
2275
2276async fn invite_channel_member(
2277    request: proto::InviteChannelMember,
2278    response: Response<proto::InviteChannelMember>,
2279    session: Session,
2280) -> Result<()> {
2281    let db = session.db().await;
2282    let channel_id = ChannelId::from_proto(request.channel_id);
2283    let invitee_id = UserId::from_proto(request.user_id);
2284    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2285        .await?;
2286
2287    let (channel, _) = db
2288        .get_channel(channel_id, session.user_id)
2289        .await?
2290        .ok_or_else(|| anyhow!("channel not found"))?;
2291
2292    let mut update = proto::UpdateChannels::default();
2293    update.channel_invitations.push(proto::Channel {
2294        id: channel.id.to_proto(),
2295        name: channel.name,
2296    });
2297    for connection_id in session
2298        .connection_pool()
2299        .await
2300        .user_connection_ids(invitee_id)
2301    {
2302        session.peer.send(connection_id, update.clone())?;
2303    }
2304
2305    response.send(proto::Ack {})?;
2306    Ok(())
2307}
2308
2309async fn remove_channel_member(
2310    request: proto::RemoveChannelMember,
2311    response: Response<proto::RemoveChannelMember>,
2312    session: Session,
2313) -> Result<()> {
2314    let db = session.db().await;
2315    let channel_id = ChannelId::from_proto(request.channel_id);
2316    let member_id = UserId::from_proto(request.user_id);
2317
2318    db.remove_channel_member(channel_id, member_id, session.user_id)
2319        .await?;
2320
2321    let mut update = proto::UpdateChannels::default();
2322    update.delete_channels.push(channel_id.to_proto());
2323
2324    for connection_id in session
2325        .connection_pool()
2326        .await
2327        .user_connection_ids(member_id)
2328    {
2329        session.peer.send(connection_id, update.clone())?;
2330    }
2331
2332    response.send(proto::Ack {})?;
2333    Ok(())
2334}
2335
2336async fn set_channel_member_admin(
2337    request: proto::SetChannelMemberAdmin,
2338    response: Response<proto::SetChannelMemberAdmin>,
2339    session: Session,
2340) -> Result<()> {
2341    let db = session.db().await;
2342    let channel_id = ChannelId::from_proto(request.channel_id);
2343    let member_id = UserId::from_proto(request.user_id);
2344    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2345        .await?;
2346
2347    let (channel, has_accepted) = db
2348        .get_channel(channel_id, member_id)
2349        .await?
2350        .ok_or_else(|| anyhow!("channel not found"))?;
2351
2352    let mut update = proto::UpdateChannels::default();
2353    if has_accepted {
2354        update.channel_permissions.push(proto::ChannelPermission {
2355            channel_id: channel.id.to_proto(),
2356            is_admin: request.admin,
2357        });
2358    }
2359
2360    for connection_id in session
2361        .connection_pool()
2362        .await
2363        .user_connection_ids(member_id)
2364    {
2365        session.peer.send(connection_id, update.clone())?;
2366    }
2367
2368    response.send(proto::Ack {})?;
2369    Ok(())
2370}
2371
2372async fn rename_channel(
2373    request: proto::RenameChannel,
2374    response: Response<proto::RenameChannel>,
2375    session: Session,
2376) -> Result<()> {
2377    let db = session.db().await;
2378    let channel_id = ChannelId::from_proto(request.channel_id);
2379    let new_name = db
2380        .rename_channel(channel_id, session.user_id, &request.name)
2381        .await?;
2382
2383    let channel = proto::Channel {
2384        id: request.channel_id,
2385        name: new_name,
2386    };
2387    response.send(proto::RenameChannelResponse {
2388        channel: Some(channel.clone()),
2389    })?;
2390    let mut update = proto::UpdateChannels::default();
2391    update.channels.push(channel);
2392
2393    let member_ids = db.get_channel_members(channel_id).await?;
2394
2395    let connection_pool = session.connection_pool().await;
2396    for member_id in member_ids {
2397        for connection_id in connection_pool.user_connection_ids(member_id) {
2398            session.peer.send(connection_id, update.clone())?;
2399        }
2400    }
2401
2402    Ok(())
2403}
2404
2405async fn link_channel(
2406    request: proto::LinkChannel,
2407    response: Response<proto::LinkChannel>,
2408    session: Session,
2409) -> Result<()> {
2410    let db = session.db().await;
2411    let channel_id = ChannelId::from_proto(request.channel_id);
2412    let to = ChannelId::from_proto(request.to);
2413    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2414
2415    let members = db.get_channel_members(to).await?;
2416    let connection_pool = session.connection_pool().await;
2417    let update = proto::UpdateChannels {
2418        channels: channels_to_send
2419            .channels
2420            .into_iter()
2421            .map(|channel| proto::Channel {
2422                id: channel.id.to_proto(),
2423                name: channel.name,
2424            })
2425            .collect(),
2426        insert_edge: channels_to_send.edges,
2427        ..Default::default()
2428    };
2429    for member_id in members {
2430        for connection_id in connection_pool.user_connection_ids(member_id) {
2431            session.peer.send(connection_id, update.clone())?;
2432        }
2433    }
2434
2435    response.send(Ack {})?;
2436
2437    Ok(())
2438}
2439
2440async fn unlink_channel(
2441    request: proto::UnlinkChannel,
2442    response: Response<proto::UnlinkChannel>,
2443    session: Session,
2444) -> Result<()> {
2445    let db = session.db().await;
2446    let channel_id = ChannelId::from_proto(request.channel_id);
2447    let from = ChannelId::from_proto(request.from);
2448
2449    db.unlink_channel(session.user_id, channel_id, from).await?;
2450
2451    let members = db.get_channel_members(from).await?;
2452
2453    let update = proto::UpdateChannels {
2454        delete_edge: vec![proto::ChannelEdge {
2455            channel_id: channel_id.to_proto(),
2456            parent_id: from.to_proto(),
2457        }],
2458        ..Default::default()
2459    };
2460    let connection_pool = session.connection_pool().await;
2461    for member_id in members {
2462        for connection_id in connection_pool.user_connection_ids(member_id) {
2463            session.peer.send(connection_id, update.clone())?;
2464        }
2465    }
2466
2467    response.send(Ack {})?;
2468
2469    Ok(())
2470}
2471
2472async fn move_channel(
2473    request: proto::MoveChannel,
2474    response: Response<proto::MoveChannel>,
2475    session: Session,
2476) -> Result<()> {
2477    let db = session.db().await;
2478    let channel_id = ChannelId::from_proto(request.channel_id);
2479    let from_parent = ChannelId::from_proto(request.from);
2480    let to = ChannelId::from_proto(request.to);
2481
2482    let channels_to_send = db
2483        .move_channel(session.user_id, channel_id, from_parent, to)
2484        .await?;
2485
2486    if channels_to_send.is_empty() {
2487        response.send(Ack {})?;
2488        return Ok(());
2489    }
2490
2491    let members_from = db.get_channel_members(from_parent).await?;
2492    let members_to = db.get_channel_members(to).await?;
2493
2494    let update = proto::UpdateChannels {
2495        delete_edge: vec![proto::ChannelEdge {
2496            channel_id: channel_id.to_proto(),
2497            parent_id: from_parent.to_proto(),
2498        }],
2499        ..Default::default()
2500    };
2501    let connection_pool = session.connection_pool().await;
2502    for member_id in members_from {
2503        for connection_id in connection_pool.user_connection_ids(member_id) {
2504            session.peer.send(connection_id, update.clone())?;
2505        }
2506    }
2507
2508    let update = proto::UpdateChannels {
2509        channels: channels_to_send
2510            .channels
2511            .into_iter()
2512            .map(|channel| proto::Channel {
2513                id: channel.id.to_proto(),
2514                name: channel.name,
2515            })
2516            .collect(),
2517        insert_edge: channels_to_send.edges,
2518        ..Default::default()
2519    };
2520    for member_id in members_to {
2521        for connection_id in connection_pool.user_connection_ids(member_id) {
2522            session.peer.send(connection_id, update.clone())?;
2523        }
2524    }
2525
2526    response.send(Ack {})?;
2527
2528    Ok(())
2529}
2530
2531async fn get_channel_members(
2532    request: proto::GetChannelMembers,
2533    response: Response<proto::GetChannelMembers>,
2534    session: Session,
2535) -> Result<()> {
2536    let db = session.db().await;
2537    let channel_id = ChannelId::from_proto(request.channel_id);
2538    let members = db
2539        .get_channel_member_details(channel_id, session.user_id)
2540        .await?;
2541    response.send(proto::GetChannelMembersResponse { members })?;
2542    Ok(())
2543}
2544
2545async fn respond_to_channel_invite(
2546    request: proto::RespondToChannelInvite,
2547    response: Response<proto::RespondToChannelInvite>,
2548    session: Session,
2549) -> Result<()> {
2550    let db = session.db().await;
2551    let channel_id = ChannelId::from_proto(request.channel_id);
2552    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2553        .await?;
2554
2555    let mut update = proto::UpdateChannels::default();
2556    update
2557        .remove_channel_invitations
2558        .push(channel_id.to_proto());
2559    if request.accept {
2560        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2561        update
2562            .channels
2563            .extend(
2564                result
2565                    .channels
2566                    .channels
2567                    .into_iter()
2568                    .map(|channel| proto::Channel {
2569                        id: channel.id.to_proto(),
2570                        name: channel.name,
2571                    }),
2572            );
2573        update.unseen_channel_messages = result.channel_messages;
2574        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2575        update.insert_edge = result.channels.edges;
2576        update
2577            .channel_participants
2578            .extend(
2579                result
2580                    .channel_participants
2581                    .into_iter()
2582                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2583                        channel_id: channel_id.to_proto(),
2584                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2585                    }),
2586            );
2587        update
2588            .channel_permissions
2589            .extend(
2590                result
2591                    .channels_with_admin_privileges
2592                    .into_iter()
2593                    .map(|channel_id| proto::ChannelPermission {
2594                        channel_id: channel_id.to_proto(),
2595                        is_admin: true,
2596                    }),
2597            );
2598    }
2599    session.peer.send(session.connection_id, update)?;
2600    response.send(proto::Ack {})?;
2601
2602    Ok(())
2603}
2604
/// Moves the requesting connection into the room that belongs to
/// `request.channel_id`.
///
/// Any room the session is currently in is left first. The requester then gets a
/// `JoinRoomResponse`, which carries LiveKit connection info only when a LiveKit
/// client is configured and a room token could be issued. After that:
/// - the room's participants receive `RoomUpdated`,
/// - the channel's members receive the new participant list,
/// - the user's contacts receive an updated contact entry.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Scoped so that `db` is dropped before the connection pool is locked below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // A token failure is logged by `trace_err` and simply omits the LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // NOTE(review): `into_inner` presumably unwraps a room guard returned by
        // `join_room` so the broadcasts below run without it held; confirm
        // against `Database::join_room`.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2659
2660async fn join_channel_buffer(
2661    request: proto::JoinChannelBuffer,
2662    response: Response<proto::JoinChannelBuffer>,
2663    session: Session,
2664) -> Result<()> {
2665    let db = session.db().await;
2666    let channel_id = ChannelId::from_proto(request.channel_id);
2667
2668    let open_response = db
2669        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2670        .await?;
2671
2672    let collaborators = open_response.collaborators.clone();
2673    response.send(open_response)?;
2674
2675    let update = UpdateChannelBufferCollaborators {
2676        channel_id: channel_id.to_proto(),
2677        collaborators: collaborators.clone(),
2678    };
2679    channel_buffer_updated(
2680        session.connection_id,
2681        collaborators
2682            .iter()
2683            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2684        &update,
2685        &session.peer,
2686    );
2687
2688    Ok(())
2689}
2690
2691async fn update_channel_buffer(
2692    request: proto::UpdateChannelBuffer,
2693    session: Session,
2694) -> Result<()> {
2695    let db = session.db().await;
2696    let channel_id = ChannelId::from_proto(request.channel_id);
2697
2698    let (collaborators, non_collaborators, epoch, version) = db
2699        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2700        .await?;
2701
2702    channel_buffer_updated(
2703        session.connection_id,
2704        collaborators,
2705        &proto::UpdateChannelBuffer {
2706            channel_id: channel_id.to_proto(),
2707            operations: request.operations,
2708        },
2709        &session.peer,
2710    );
2711
2712    let pool = &*session.connection_pool().await;
2713
2714    broadcast(
2715        None,
2716        non_collaborators
2717            .iter()
2718            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2719        |peer_id| {
2720            session.peer.send(
2721                peer_id.into(),
2722                proto::UpdateChannels {
2723                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2724                        channel_id: channel_id.to_proto(),
2725                        epoch: epoch as u64,
2726                        version: version.clone(),
2727                    }],
2728                    ..Default::default()
2729                },
2730            )
2731        },
2732    );
2733
2734    Ok(())
2735}
2736
2737async fn rejoin_channel_buffers(
2738    request: proto::RejoinChannelBuffers,
2739    response: Response<proto::RejoinChannelBuffers>,
2740    session: Session,
2741) -> Result<()> {
2742    let db = session.db().await;
2743    let buffers = db
2744        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2745        .await?;
2746
2747    for rejoined_buffer in &buffers {
2748        let collaborators_to_notify = rejoined_buffer
2749            .buffer
2750            .collaborators
2751            .iter()
2752            .filter_map(|c| Some(c.peer_id?.into()));
2753        channel_buffer_updated(
2754            session.connection_id,
2755            collaborators_to_notify,
2756            &proto::UpdateChannelBufferCollaborators {
2757                channel_id: rejoined_buffer.buffer.channel_id,
2758                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2759            },
2760            &session.peer,
2761        );
2762    }
2763
2764    response.send(proto::RejoinChannelBuffersResponse {
2765        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2766    })?;
2767
2768    Ok(())
2769}
2770
2771async fn leave_channel_buffer(
2772    request: proto::LeaveChannelBuffer,
2773    response: Response<proto::LeaveChannelBuffer>,
2774    session: Session,
2775) -> Result<()> {
2776    let db = session.db().await;
2777    let channel_id = ChannelId::from_proto(request.channel_id);
2778
2779    let left_buffer = db
2780        .leave_channel_buffer(channel_id, session.connection_id)
2781        .await?;
2782
2783    response.send(Ack {})?;
2784
2785    channel_buffer_updated(
2786        session.connection_id,
2787        left_buffer.connections,
2788        &proto::UpdateChannelBufferCollaborators {
2789            channel_id: channel_id.to_proto(),
2790            collaborators: left_buffer.collaborators,
2791        },
2792        &session.peer,
2793    );
2794
2795    Ok(())
2796}
2797
2798fn channel_buffer_updated<T: EnvelopedMessage>(
2799    sender_id: ConnectionId,
2800    collaborators: impl IntoIterator<Item = ConnectionId>,
2801    message: &T,
2802    peer: &Peer,
2803) {
2804    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2805        peer.send(peer_id.into(), message.clone())
2806    });
2807}
2808
2809async fn send_channel_message(
2810    request: proto::SendChannelMessage,
2811    response: Response<proto::SendChannelMessage>,
2812    session: Session,
2813) -> Result<()> {
2814    // Validate the message body.
2815    let body = request.body.trim().to_string();
2816    if body.len() > MAX_MESSAGE_LEN {
2817        return Err(anyhow!("message is too long"))?;
2818    }
2819    if body.is_empty() {
2820        return Err(anyhow!("message can't be blank"))?;
2821    }
2822
2823    let timestamp = OffsetDateTime::now_utc();
2824    let nonce = request
2825        .nonce
2826        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2827
2828    let channel_id = ChannelId::from_proto(request.channel_id);
2829    let (message_id, connection_ids, non_participants) = session
2830        .db()
2831        .await
2832        .create_channel_message(
2833            channel_id,
2834            session.user_id,
2835            &body,
2836            timestamp,
2837            nonce.clone().into(),
2838        )
2839        .await?;
2840    let message = proto::ChannelMessage {
2841        sender_id: session.user_id.to_proto(),
2842        id: message_id.to_proto(),
2843        body,
2844        timestamp: timestamp.unix_timestamp() as u64,
2845        nonce: Some(nonce),
2846    };
2847    broadcast(Some(session.connection_id), connection_ids, |connection| {
2848        session.peer.send(
2849            connection,
2850            proto::ChannelMessageSent {
2851                channel_id: channel_id.to_proto(),
2852                message: Some(message.clone()),
2853            },
2854        )
2855    });
2856    response.send(proto::SendChannelMessageResponse {
2857        message: Some(message),
2858    })?;
2859
2860    let pool = &*session.connection_pool().await;
2861    broadcast(
2862        None,
2863        non_participants
2864            .iter()
2865            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2866        |peer_id| {
2867            session.peer.send(
2868                peer_id.into(),
2869                proto::UpdateChannels {
2870                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2871                        channel_id: channel_id.to_proto(),
2872                        message_id: message_id.to_proto(),
2873                    }],
2874                    ..Default::default()
2875                },
2876            )
2877        },
2878    );
2879
2880    Ok(())
2881}
2882
2883async fn remove_channel_message(
2884    request: proto::RemoveChannelMessage,
2885    response: Response<proto::RemoveChannelMessage>,
2886    session: Session,
2887) -> Result<()> {
2888    let channel_id = ChannelId::from_proto(request.channel_id);
2889    let message_id = MessageId::from_proto(request.message_id);
2890    let connection_ids = session
2891        .db()
2892        .await
2893        .remove_channel_message(channel_id, message_id, session.user_id)
2894        .await?;
2895    broadcast(Some(session.connection_id), connection_ids, |connection| {
2896        session.peer.send(connection, request.clone())
2897    });
2898    response.send(proto::Ack {})?;
2899    Ok(())
2900}
2901
2902async fn acknowledge_channel_message(
2903    request: proto::AckChannelMessage,
2904    session: Session,
2905) -> Result<()> {
2906    let channel_id = ChannelId::from_proto(request.channel_id);
2907    let message_id = MessageId::from_proto(request.message_id);
2908    session
2909        .db()
2910        .await
2911        .observe_channel_message(channel_id, session.user_id, message_id)
2912        .await?;
2913    Ok(())
2914}
2915
2916async fn acknowledge_buffer_version(
2917    request: proto::AckBufferOperation,
2918    session: Session,
2919) -> Result<()> {
2920    let buffer_id = BufferId::from_proto(request.buffer_id);
2921    session
2922        .db()
2923        .await
2924        .observe_buffer_version(
2925            buffer_id,
2926            session.user_id,
2927            request.epoch as i32,
2928            &request.version,
2929        )
2930        .await?;
2931    Ok(())
2932}
2933
2934async fn join_channel_chat(
2935    request: proto::JoinChannelChat,
2936    response: Response<proto::JoinChannelChat>,
2937    session: Session,
2938) -> Result<()> {
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940
2941    let db = session.db().await;
2942    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2943        .await?;
2944    let messages = db
2945        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2946        .await?;
2947    response.send(proto::JoinChannelChatResponse {
2948        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2949        messages,
2950    })?;
2951    Ok(())
2952}
2953
2954async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2955    let channel_id = ChannelId::from_proto(request.channel_id);
2956    session
2957        .db()
2958        .await
2959        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2960        .await?;
2961    Ok(())
2962}
2963
2964async fn get_channel_messages(
2965    request: proto::GetChannelMessages,
2966    response: Response<proto::GetChannelMessages>,
2967    session: Session,
2968) -> Result<()> {
2969    let channel_id = ChannelId::from_proto(request.channel_id);
2970    let messages = session
2971        .db()
2972        .await
2973        .get_channel_messages(
2974            channel_id,
2975            session.user_id,
2976            MESSAGE_COUNT_PER_PAGE,
2977            Some(MessageId::from_proto(request.before_message_id)),
2978        )
2979        .await?;
2980    response.send(proto::GetChannelMessagesResponse {
2981        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2982        messages,
2983    })?;
2984    Ok(())
2985}
2986
2987async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2988    let project_id = ProjectId::from_proto(request.project_id);
2989    let project_connection_ids = session
2990        .db()
2991        .await
2992        .project_connection_ids(project_id, session.connection_id)
2993        .await?;
2994    broadcast(
2995        Some(session.connection_id),
2996        project_connection_ids.iter().copied(),
2997        |connection_id| {
2998            session
2999                .peer
3000                .forward_send(session.connection_id, connection_id, request.clone())
3001        },
3002    );
3003    Ok(())
3004}
3005
3006async fn get_private_user_info(
3007    _request: proto::GetPrivateUserInfo,
3008    response: Response<proto::GetPrivateUserInfo>,
3009    session: Session,
3010) -> Result<()> {
3011    let db = session.db().await;
3012
3013    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3014    let user = db
3015        .get_user_by_id(session.user_id)
3016        .await?
3017        .ok_or_else(|| anyhow!("user not found"))?;
3018    let flags = db.get_user_flags(session.user_id).await?;
3019
3020    response.send(proto::GetPrivateUserInfoResponse {
3021        metrics_id,
3022        staff: user.admin,
3023        flags,
3024    })?;
3025    Ok(())
3026}
3027
3028fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3029    match message {
3030        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3031        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3032        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3033        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3034        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3035            code: frame.code.into(),
3036            reason: frame.reason,
3037        })),
3038    }
3039}
3040
3041fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3042    match message {
3043        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3044        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3045        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3046        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3047        AxumMessage::Close(frame) => {
3048            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3049                code: frame.code.into(),
3050                reason: frame.reason,
3051            }))
3052        }
3053    }
3054}
3055
3056fn build_initial_channels_update(
3057    channels: ChannelsForUser,
3058    channel_invites: Vec<db::Channel>,
3059) -> proto::UpdateChannels {
3060    let mut update = proto::UpdateChannels::default();
3061
3062    for channel in channels.channels.channels {
3063        update.channels.push(proto::Channel {
3064            id: channel.id.to_proto(),
3065            name: channel.name,
3066        });
3067    }
3068
3069    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3070    update.unseen_channel_messages = channels.channel_messages;
3071    update.insert_edge = channels.channels.edges;
3072
3073    for (channel_id, participants) in channels.channel_participants {
3074        update
3075            .channel_participants
3076            .push(proto::ChannelParticipants {
3077                channel_id: channel_id.to_proto(),
3078                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3079            });
3080    }
3081
3082    update
3083        .channel_permissions
3084        .extend(
3085            channels
3086                .channels_with_admin_privileges
3087                .into_iter()
3088                .map(|id| proto::ChannelPermission {
3089                    channel_id: id.to_proto(),
3090                    is_admin: true,
3091                }),
3092        );
3093
3094    for channel in channel_invites {
3095        update.channel_invitations.push(proto::Channel {
3096            id: channel.id.to_proto(),
3097            name: channel.name,
3098        });
3099    }
3100
3101    update
3102}
3103
3104fn build_initial_contacts_update(
3105    contacts: Vec<db::Contact>,
3106    pool: &ConnectionPool,
3107) -> proto::UpdateContacts {
3108    let mut update = proto::UpdateContacts::default();
3109
3110    for contact in contacts {
3111        match contact {
3112            db::Contact::Accepted {
3113                user_id,
3114                should_notify,
3115                busy,
3116            } => {
3117                update
3118                    .contacts
3119                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3120            }
3121            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3122            db::Contact::Incoming {
3123                user_id,
3124                should_notify,
3125            } => update
3126                .incoming_requests
3127                .push(proto::IncomingContactRequest {
3128                    requester_id: user_id.to_proto(),
3129                    should_notify,
3130                }),
3131        }
3132    }
3133
3134    update
3135}
3136
3137fn contact_for_user(
3138    user_id: UserId,
3139    should_notify: bool,
3140    busy: bool,
3141    pool: &ConnectionPool,
3142) -> proto::Contact {
3143    proto::Contact {
3144        user_id: user_id.to_proto(),
3145        online: pool.is_user_online(user_id),
3146        busy,
3147        should_notify,
3148    }
3149}
3150
3151fn room_updated(room: &proto::Room, peer: &Peer) {
3152    broadcast(
3153        None,
3154        room.participants
3155            .iter()
3156            .filter_map(|participant| Some(participant.peer_id?.into())),
3157        |peer_id| {
3158            peer.send(
3159                peer_id.into(),
3160                proto::RoomUpdated {
3161                    room: Some(room.clone()),
3162                },
3163            )
3164        },
3165    );
3166}
3167
3168fn channel_updated(
3169    channel_id: ChannelId,
3170    room: &proto::Room,
3171    channel_members: &[UserId],
3172    peer: &Peer,
3173    pool: &ConnectionPool,
3174) {
3175    let participants = room
3176        .participants
3177        .iter()
3178        .map(|p| p.user_id)
3179        .collect::<Vec<_>>();
3180
3181    broadcast(
3182        None,
3183        channel_members
3184            .iter()
3185            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3186        |peer_id| {
3187            peer.send(
3188                peer_id.into(),
3189                proto::UpdateChannels {
3190                    channel_participants: vec![proto::ChannelParticipants {
3191                        channel_id: channel_id.to_proto(),
3192                        participant_user_ids: participants.clone(),
3193                    }],
3194                    ..Default::default()
3195                },
3196            )
3197        },
3198    );
3199}
3200
3201async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3202    let db = session.db().await;
3203
3204    let contacts = db.get_contacts(user_id).await?;
3205    let busy = db.is_user_busy(user_id).await?;
3206
3207    let pool = session.connection_pool().await;
3208    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3209    for contact in contacts {
3210        if let db::Contact::Accepted {
3211            user_id: contact_user_id,
3212            ..
3213        } = contact
3214        {
3215            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3216                session
3217                    .peer
3218                    .send(
3219                        contact_conn_id,
3220                        proto::UpdateContacts {
3221                            contacts: vec![updated_contact.clone()],
3222                            remove_contacts: Default::default(),
3223                            incoming_requests: Default::default(),
3224                            remove_incoming_requests: Default::default(),
3225                            outgoing_requests: Default::default(),
3226                            remove_outgoing_requests: Default::default(),
3227                        },
3228                    )
3229                    .trace_err();
3230            }
3231        }
3232    }
3233    Ok(())
3234}
3235
3236async fn leave_room_for_session(session: &Session) -> Result<()> {
3237    let mut contacts_to_update = HashSet::default();
3238
3239    let room_id;
3240    let canceled_calls_to_user_ids;
3241    let live_kit_room;
3242    let delete_live_kit_room;
3243    let room;
3244    let channel_members;
3245    let channel_id;
3246
3247    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3248        contacts_to_update.insert(session.user_id);
3249
3250        for project in left_room.left_projects.values() {
3251            project_left(project, session);
3252        }
3253
3254        room_id = RoomId::from_proto(left_room.room.id);
3255        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3256        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3257        delete_live_kit_room = left_room.deleted;
3258        room = mem::take(&mut left_room.room);
3259        channel_members = mem::take(&mut left_room.channel_members);
3260        channel_id = left_room.channel_id;
3261
3262        room_updated(&room, &session.peer);
3263    } else {
3264        return Ok(());
3265    }
3266
3267    if let Some(channel_id) = channel_id {
3268        channel_updated(
3269            channel_id,
3270            &room,
3271            &channel_members,
3272            &session.peer,
3273            &*session.connection_pool().await,
3274        );
3275    }
3276
3277    {
3278        let pool = session.connection_pool().await;
3279        for canceled_user_id in canceled_calls_to_user_ids {
3280            for connection_id in pool.user_connection_ids(canceled_user_id) {
3281                session
3282                    .peer
3283                    .send(
3284                        connection_id,
3285                        proto::CallCanceled {
3286                            room_id: room_id.to_proto(),
3287                        },
3288                    )
3289                    .trace_err();
3290            }
3291            contacts_to_update.insert(canceled_user_id);
3292        }
3293    }
3294
3295    for contact_user_id in contacts_to_update {
3296        update_user_contacts(contact_user_id, &session).await?;
3297    }
3298
3299    if let Some(live_kit) = session.live_kit_client.as_ref() {
3300        live_kit
3301            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3302            .await
3303            .trace_err();
3304
3305        if delete_live_kit_room {
3306            live_kit.delete_room(live_kit_room).await.trace_err();
3307        }
3308    }
3309
3310    Ok(())
3311}
3312
3313async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3314    let left_channel_buffers = session
3315        .db()
3316        .await
3317        .leave_channel_buffers(session.connection_id)
3318        .await?;
3319
3320    for left_buffer in left_channel_buffers {
3321        channel_buffer_updated(
3322            session.connection_id,
3323            left_buffer.connections,
3324            &proto::UpdateChannelBufferCollaborators {
3325                channel_id: left_buffer.channel_id.to_proto(),
3326                collaborators: left_buffer.collaborators,
3327            },
3328            &session.peer,
3329        );
3330    }
3331
3332    Ok(())
3333}
3334
3335fn project_left(project: &db::LeftProject, session: &Session) {
3336    for connection_id in &project.connection_ids {
3337        if project.host_user_id == session.user_id {
3338            session
3339                .peer
3340                .send(
3341                    *connection_id,
3342                    proto::UnshareProject {
3343                        project_id: project.id.to_proto(),
3344                    },
3345                )
3346                .trace_err();
3347        } else {
3348            session
3349                .peer
3350                .send(
3351                    *connection_id,
3352                    proto::RemoveProjectCollaborator {
3353                        project_id: project.id.to_proto(),
3354                        peer_id: Some(session.connection_id.into()),
3355                    },
3356                )
3357                .trace_err();
3358        }
3359    }
3360}
3361
3362pub trait ResultExt {
3363    type Ok;
3364
3365    fn trace_err(self) -> Option<Self::Ok>;
3366}
3367
3368impl<T, E> ResultExt for Result<T, E>
3369where
3370    E: std::fmt::Debug,
3371{
3372    type Ok = T;
3373
3374    fn trace_err(self) -> Option<T> {
3375        match self {
3376            Ok(value) => Some(value),
3377            Err(error) => {
3378                tracing::error!("{:?}", error);
3379                None
3380            }
3381        }
3382    }
3383}