rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
   7        ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66
/// Reconnection grace period. Not referenced in this part of the file; presumably consumed by
/// the disconnect handling (`connection_lost`) — confirm against its use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the cleanup task spawned by `Server::start` waits before it looks up and clears
/// resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the two limits below are not used in this part of the file; presumably they are
// the page size for channel message history and the maximum chat message length — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
  72
// Prometheus gauges, registered (via `register_int_gauge!`) the first time each static is
// accessed. The `unwrap` panics if registration fails, e.g. when a metric with the same name
// has already been registered.
lazy_static! {
    /// Gauge `connections`: "number of connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge `shared_projects`: "number of open projects with one or more guests".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  82
/// Type-erased message handler stored in `Server::handlers`, keyed by the payload's `TypeId`.
///
/// It receives the boxed envelope and the sending connection's `Session`, and returns a boxed
/// `'static` future. `Server::add_handler` builds these closures; each one downcasts the
/// envelope back to its concrete `TypedEnvelope<M>`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100#[derive(Clone)]
 101struct Session {
 102    user_id: UserId,
 103    connection_id: ConnectionId,
 104    db: Arc<tokio::sync::Mutex<DbHandle>>,
 105    peer: Arc<Peer>,
 106    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 107    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 108    executor: Executor,
 109}
 110
 111impl Session {
 112    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        let guard = self.db.lock().await;
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        guard
 119    }
 120
 121    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 122        #[cfg(test)]
 123        tokio::task::yield_now().await;
 124        let guard = self.connection_pool.lock();
 125        ConnectionPoolGuard {
 126            guard,
 127            _not_send: PhantomData,
 128        }
 129    }
 130}
 131
 132impl fmt::Debug for Session {
 133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 134        f.debug_struct("Session")
 135            .field("user_id", &self.user_id)
 136            .field("connection_id", &self.connection_id)
 137            .finish()
 138    }
 139}
 140
 141struct DbHandle(Arc<Database>);
 142
 143impl Deref for DbHandle {
 144    type Target = Database;
 145
 146    fn deref(&self) -> &Self::Target {
 147        self.0.as_ref()
 148    }
 149}
 150
/// The collaboration RPC server. It owns the peer, the registry of live connections, and the
/// table of per-message-type handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Registry of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; every `handle_connection` loop subscribes to it and returns when
    /// it fires.
    teardown: watch::Sender<()>,
}

/// Guard over the locked `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result, the compiler rejects
/// holding this synchronous `parking_lot` lock across an `.await` in any future that must be
/// `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and connection pool. The pool stays locked for as
/// long as the snapshot is alive, because the snapshot holds the pool guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target (the `ConnectionPool` itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 172
 173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 174where
 175    S: Serializer,
 176    T: Deref<Target = U>,
 177    U: Serialize,
 178{
 179    Serialize::serialize(value.deref(), serializer)
 180}
 181
 182impl Server {
 183    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 184        let mut server = Self {
 185            id: parking_lot::Mutex::new(id),
 186            peer: Peer::new(id.0 as u32),
 187            app_state,
 188            executor,
 189            connection_pool: Default::default(),
 190            handlers: Default::default(),
 191            teardown: watch::channel(()).0,
 192        };
 193
 194        server
 195            .add_request_handler(ping)
 196            .add_request_handler(create_room)
 197            .add_request_handler(join_room)
 198            .add_request_handler(rejoin_room)
 199            .add_request_handler(leave_room)
 200            .add_request_handler(call)
 201            .add_request_handler(cancel_call)
 202            .add_message_handler(decline_call)
 203            .add_request_handler(update_participant_location)
 204            .add_request_handler(share_project)
 205            .add_message_handler(unshare_project)
 206            .add_request_handler(join_project)
 207            .add_message_handler(leave_project)
 208            .add_request_handler(update_project)
 209            .add_request_handler(update_worktree)
 210            .add_message_handler(start_language_server)
 211            .add_message_handler(update_language_server)
 212            .add_message_handler(update_diagnostic_summary)
 213            .add_message_handler(update_worktree_settings)
 214            .add_message_handler(refresh_inlay_hints)
 215            .add_request_handler(forward_project_request::<proto::GetHover>)
 216            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 217            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetReferences>)
 219            .add_request_handler(forward_project_request::<proto::SearchProject>)
 220            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 221            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 222            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 225            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 226            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 227            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 228            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 229            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 230            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 231            .add_request_handler(forward_project_request::<proto::PerformRename>)
 232            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 233            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 234            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 235            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 236            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 241            .add_request_handler(forward_project_request::<proto::InlayHints>)
 242            .add_message_handler(create_buffer_for_peer)
 243            .add_request_handler(update_buffer)
 244            .add_message_handler(update_buffer_file)
 245            .add_message_handler(buffer_reloaded)
 246            .add_message_handler(buffer_saved)
 247            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 248            .add_request_handler(get_users)
 249            .add_request_handler(fuzzy_search_users)
 250            .add_request_handler(request_contact)
 251            .add_request_handler(remove_contact)
 252            .add_request_handler(respond_to_contact_request)
 253            .add_request_handler(create_channel)
 254            .add_request_handler(delete_channel)
 255            .add_request_handler(invite_channel_member)
 256            .add_request_handler(remove_channel_member)
 257            .add_request_handler(set_channel_member_admin)
 258            .add_request_handler(rename_channel)
 259            .add_request_handler(join_channel_buffer)
 260            .add_request_handler(leave_channel_buffer)
 261            .add_message_handler(update_channel_buffer)
 262            .add_request_handler(rejoin_channel_buffers)
 263            .add_request_handler(get_channel_members)
 264            .add_request_handler(respond_to_channel_invite)
 265            .add_request_handler(join_channel)
 266            .add_request_handler(join_channel_chat)
 267            .add_message_handler(leave_channel_chat)
 268            .add_request_handler(send_channel_message)
 269            .add_request_handler(remove_channel_message)
 270            .add_request_handler(get_channel_messages)
 271            .add_request_handler(link_channel)
 272            .add_request_handler(unlink_channel)
 273            .add_request_handler(move_channel)
 274            .add_request_handler(follow)
 275            .add_message_handler(unfollow)
 276            .add_message_handler(update_followers)
 277            .add_message_handler(update_diff_base)
 278            .add_request_handler(get_private_user_info)
 279            .add_message_handler(acknowledge_channel_message)
 280            .add_message_handler(acknowledge_buffer_version);
 281
 282        Arc::new(server)
 283    }
 284
    /// Spawns a detached background task that cleans up resources left behind by stale
    /// servers, then returns immediately. The returned `Result` is always `Ok(())`.
    ///
    /// The task waits [`CLEANUP_TIMEOUT`], then fetches stale room and channel-buffer ids via
    /// `stale_server_resource_ids` (passing this server's environment and `server_id`) and:
    /// - for each channel buffer, calls `clear_stale_channel_buffer_collaborators` and sends
    ///   the refreshed collaborator list to the connection ids the database returns;
    /// - for each room, calls `clear_stale_room_participants`, broadcasts the updated room
    ///   (and its channel, if it has one), and sends `CallCanceled` to every connection of the
    ///   users in `canceled_calls_to_user_ids`. It also sends a refreshed contact entry for each
    ///   affected user to the connections of that user's accepted contacts, and deletes the
    ///   LiveKit room when the room has no participants left and a LiveKit client is configured.
    ///
    /// Finally it calls `delete_stale_servers`. Every failure is logged with `trace_err` and
    /// skipped rather than aborting; `delete_stale_servers` runs even if the initial lookup
    /// failed.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here; awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Refresh each stale channel buffer's collaborator list and send it to the
                    // connections returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing this room fails: nobody to notify and no
                        // LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                // The pool lock is held only for the duration of this call.
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both stale participants and users whose calls were canceled get
                            // their contact entry refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the `.await`s
                        // below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups must succeed for this user to be updated.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 436
 437    pub fn teardown(&self) {
 438        self.peer.teardown();
 439        self.connection_pool.lock().reset();
 440        let _ = self.teardown.send(());
 441    }
 442
 443    #[cfg(test)]
 444    pub fn reset(&self, id: ServerId) {
 445        self.teardown();
 446        *self.id.lock() = id;
 447        self.peer.reset(id.0 as u32);
 448    }
 449
 450    #[cfg(test)]
 451    pub fn id(&self) -> ServerId {
 452        *self.id.lock()
 453    }
 454
 455    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 458        Fut: 'static + Send + Future<Output = Result<()>>,
 459        M: EnvelopedMessage,
 460    {
 461        let prev_handler = self.handlers.insert(
 462            TypeId::of::<M>(),
 463            Box::new(move |envelope, session| {
 464                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 465                let span = info_span!(
 466                    "handle message",
 467                    payload_type = envelope.payload_type_name()
 468                );
 469                span.in_scope(|| {
 470                    tracing::info!(
 471                        payload_type = envelope.payload_type_name(),
 472                        "message received"
 473                    );
 474                });
 475                let start_time = Instant::now();
 476                let future = (handler)(*envelope, session);
 477                async move {
 478                    let result = future.await;
 479                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 480                    match result {
 481                        Err(error) => {
 482                            tracing::error!(%error, ?duration_ms, "error handling message")
 483                        }
 484                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 485                    }
 486                }
 487                .instrument(span)
 488                .boxed()
 489            }),
 490        );
 491        if prev_handler.is_some() {
 492            panic!("registered a handler for the same message twice");
 493        }
 494        self
 495    }
 496
 497    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 498    where
 499        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 500        Fut: 'static + Send + Future<Output = Result<()>>,
 501        M: EnvelopedMessage,
 502    {
 503        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 504        self
 505    }
 506
 507    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 510        Fut: Send + Future<Output = Result<()>>,
 511        M: RequestMessage,
 512    {
 513        let handler = Arc::new(handler);
 514        self.add_handler(move |envelope, session| {
 515            let receipt = envelope.receipt();
 516            let handler = handler.clone();
 517            async move {
 518                let peer = session.peer.clone();
 519                let responded = Arc::new(AtomicBool::default());
 520                let response = Response {
 521                    peer: peer.clone(),
 522                    responded: responded.clone(),
 523                    receipt,
 524                };
 525                match (handler)(envelope.payload, response, session).await {
 526                    Ok(()) => {
 527                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 528                            Ok(())
 529                        } else {
 530                            Err(anyhow!("handler did not send a response"))?
 531                        }
 532                    }
 533                    Err(error) => {
 534                        peer.respond_with_error(
 535                            receipt,
 536                            proto::Error {
 537                                message: error.to_string(),
 538                            },
 539                        )?;
 540                        Err(error)
 541                    }
 542                }
 543            }
 544        })
 545    }
 546
    /// Serves one client connection for `user` until the connection closes, its I/O fails, or
    /// the server is torn down.
    ///
    /// Setup, in order:
    /// - registers `connection` with the peer and sends `Hello` with the new connection id;
    /// - reports the id through `send_connection_id`, if provided;
    /// - for users with `connected_once == false`, sends `ShowContacts` and records the
    ///   connection;
    /// - adds the connection to the pool and sends the initial contacts and channels updates;
    /// - sends any pending incoming call, then calls `update_user_contacts`.
    ///
    /// It then dispatches each incoming message to the handler registered for its payload type.
    ///
    /// On server teardown the future returns `Ok(())` immediately, without running
    /// `connection_lost`. When the connection closes or its I/O fails, it runs
    /// `connection_lost` (logging rather than returning any error) and returns `Ok(())`.
    ///
    /// # Errors
    /// Returns an error if any setup step fails.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed eagerly (outside the async block), so a teardown signalled after this call
        // is observed even if it happens before the future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            // The closure lets the peer sleep via this executor (presumably for keepalive or
            // timeout handling inside `Peer` — confirm).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // Ignored if the receiver has already been dropped.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // Scoped so the synchronous pool lock is released before the next `.await`.
            // NOTE(review): any `?` exit after this point and before the message loop (e.g. from
            // `incoming_call_for_user` or `update_user_contacts`) skips `connection_lost` below,
            // so nothing in this function removes the pool entry — confirm whether cleanup
            // happens elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // One database mutex per connection; every clone of this session shares it.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is taken before
            // reading each message and released when that message's handler finishes, so
            // reading pauses while all permits are in use.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // `acquire_owned` fails only if the semaphore is closed, which never happens here.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, connection I/O, progress on queued
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Server teardown: return without `connection_lost`
                    // (`Server::teardown` has already reset the pool).
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    // Drives queued foreground handlers; a completion needs no further action.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before moving it; it is re-attached to the
                                // future with `.instrument(span)` below.
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached; foreground handlers are
                                // polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 681
 682    pub async fn invite_code_redeemed(
 683        self: &Arc<Self>,
 684        inviter_id: UserId,
 685        invitee_id: UserId,
 686    ) -> Result<()> {
 687        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 688            if let Some(code) = &user.invite_code {
 689                let pool = self.connection_pool.lock();
 690                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 691                for connection_id in pool.user_connection_ids(inviter_id) {
 692                    self.peer.send(
 693                        connection_id,
 694                        proto::UpdateContacts {
 695                            contacts: vec![invitee_contact.clone()],
 696                            ..Default::default()
 697                        },
 698                    )?;
 699                    self.peer.send(
 700                        connection_id,
 701                        proto::UpdateInviteInfo {
 702                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 703                            count: user.invite_count as u32,
 704                        },
 705                    )?;
 706                }
 707            }
 708        }
 709        Ok(())
 710    }
 711
 712    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 713        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 714            if let Some(invite_code) = &user.invite_code {
 715                let pool = self.connection_pool.lock();
 716                for connection_id in pool.user_connection_ids(user_id) {
 717                    self.peer.send(
 718                        connection_id,
 719                        proto::UpdateInviteInfo {
 720                            url: format!(
 721                                "{}{}",
 722                                self.app_state.config.invite_link_prefix, invite_code
 723                            ),
 724                            count: user.invite_count as u32,
 725                        },
 726                    )?;
 727                }
 728            }
 729        }
 730        Ok(())
 731    }
 732
 733    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 734        ServerSnapshot {
 735            connection_pool: ConnectionPoolGuard {
 736                guard: self.connection_pool.lock(),
 737                _not_send: PhantomData,
 738            },
 739            peer: &self.peer,
 740        }
 741    }
 742}
 743
 744impl<'a> Deref for ConnectionPoolGuard<'a> {
 745    type Target = ConnectionPool;
 746
 747    fn deref(&self) -> &Self::Target {
 748        &*self.guard
 749    }
 750}
 751
 752impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 753    fn deref_mut(&mut self) -> &mut Self::Target {
 754        &mut *self.guard
 755    }
 756}
 757
 758impl<'a> Drop for ConnectionPoolGuard<'a> {
 759    fn drop(&mut self) {
 760        #[cfg(test)]
 761        self.check_invariants();
 762    }
 763}
 764
 765fn broadcast<F>(
 766    sender_id: Option<ConnectionId>,
 767    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 768    mut f: F,
 769) where
 770    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 771{
 772    for receiver_id in receiver_ids {
 773        if Some(receiver_id) != sender_id {
 774            if let Err(error) = f(receiver_id) {
 775                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 776            }
 777        }
 778    }
 779}
 780
 781lazy_static! {
 782    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 783}
 784
 785pub struct ProtocolVersion(u32);
 786
 787impl Header for ProtocolVersion {
 788    fn name() -> &'static HeaderName {
 789        &ZED_PROTOCOL_VERSION
 790    }
 791
 792    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 793    where
 794        Self: Sized,
 795        I: Iterator<Item = &'i axum::http::HeaderValue>,
 796    {
 797        let version = values
 798            .next()
 799            .ok_or_else(axum::headers::Error::invalid)?
 800            .to_str()
 801            .map_err(|_| axum::headers::Error::invalid())?
 802            .parse()
 803            .map_err(|_| axum::headers::Error::invalid())?;
 804        Ok(Self(version))
 805    }
 806
 807    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 808        values.extend([self.0.to_string().parse().unwrap()]);
 809    }
 810}
 811
 812pub fn routes(server: Arc<Server>) -> Router<Body> {
 813    Router::new()
 814        .route("/rpc", get(handle_websocket_request))
 815        .layer(
 816            ServiceBuilder::new()
 817                .layer(Extension(server.app_state.clone()))
 818                .layer(middleware::from_fn(auth::validate_header)),
 819        )
 820        .route("/metrics", get(handle_metrics))
 821        .layer(Extension(server))
 822}
 823
 824pub async fn handle_websocket_request(
 825    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 826    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 827    Extension(server): Extension<Arc<Server>>,
 828    Extension(user): Extension<User>,
 829    ws: WebSocketUpgrade,
 830) -> axum::response::Response {
 831    if protocol_version != rpc::PROTOCOL_VERSION {
 832        return (
 833            StatusCode::UPGRADE_REQUIRED,
 834            "client must be upgraded".to_string(),
 835        )
 836            .into_response();
 837    }
 838    let socket_address = socket_address.to_string();
 839    ws.on_upgrade(move |socket| {
 840        use util::ResultExt;
 841        let socket = socket
 842            .map_ok(to_tungstenite_message)
 843            .err_into()
 844            .with(|message| async move { Ok(to_axum_message(message)) });
 845        let connection = Connection::new(Box::pin(socket));
 846        async move {
 847            server
 848                .handle_connection(connection, socket_address, user, None, Executor::Production)
 849                .await
 850                .log_err();
 851        }
 852    })
 853}
 854
 855pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 856    let connections = server
 857        .connection_pool
 858        .lock()
 859        .connections()
 860        .filter(|connection| !connection.admin)
 861        .count();
 862
 863    METRIC_CONNECTIONS.set(connections as _);
 864
 865    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 866    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 867
 868    let encoder = prometheus::TextEncoder::new();
 869    let metric_families = prometheus::gather();
 870    let encoded_metrics = encoder
 871        .encode_to_string(&metric_families)
 872        .map_err(|err| anyhow!("{}", err))?;
 873    Ok(encoded_metrics)
 874}
 875
 876#[instrument(err, skip(executor))]
 877async fn connection_lost(
 878    session: Session,
 879    mut teardown: watch::Receiver<()>,
 880    executor: Executor,
 881) -> Result<()> {
 882    session.peer.disconnect(session.connection_id);
 883    session
 884        .connection_pool()
 885        .await
 886        .remove_connection(session.connection_id)?;
 887
 888    session
 889        .db()
 890        .await
 891        .connection_lost(session.connection_id)
 892        .await
 893        .trace_err();
 894
 895    futures::select_biased! {
 896        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 897            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 898            leave_room_for_session(&session).await.trace_err();
 899            leave_channel_buffers_for_session(&session)
 900                .await
 901                .trace_err();
 902
 903            if !session
 904                .connection_pool()
 905                .await
 906                .is_user_online(session.user_id)
 907            {
 908                let db = session.db().await;
 909                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 910                    room_updated(&room, &session.peer);
 911                }
 912            }
 913
 914            update_user_contacts(session.user_id, &session).await?;
 915        }
 916        _ = teardown.changed().fuse() => {}
 917    }
 918
 919    Ok(())
 920}
 921
 922async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 923    response.send(proto::Ack {})?;
 924    Ok(())
 925}
 926
 927async fn create_room(
 928    _request: proto::CreateRoom,
 929    response: Response<proto::CreateRoom>,
 930    session: Session,
 931) -> Result<()> {
 932    let live_kit_room = nanoid::nanoid!(30);
 933
 934    let live_kit_connection_info = {
 935        let live_kit_room = live_kit_room.clone();
 936        let live_kit = session.live_kit_client.as_ref();
 937
 938        util::async_iife!({
 939            let live_kit = live_kit?;
 940
 941            live_kit
 942                .create_room(live_kit_room.clone())
 943                .await
 944                .trace_err()?;
 945
 946            let token = live_kit
 947                .room_token(&live_kit_room, &session.user_id.to_string())
 948                .trace_err()?;
 949
 950            Some(proto::LiveKitConnectionInfo {
 951                server_url: live_kit.url().into(),
 952                token,
 953            })
 954        })
 955    }
 956    .await;
 957
 958    let room = session
 959        .db()
 960        .await
 961        .create_room(session.user_id, session.connection_id, &live_kit_room)
 962        .await?;
 963
 964    response.send(proto::CreateRoomResponse {
 965        room: Some(room.clone()),
 966        live_kit_connection_info,
 967    })?;
 968
 969    update_user_contacts(session.user_id, &session).await?;
 970    Ok(())
 971}
 972
 973async fn join_room(
 974    request: proto::JoinRoom,
 975    response: Response<proto::JoinRoom>,
 976    session: Session,
 977) -> Result<()> {
 978    let room_id = RoomId::from_proto(request.id);
 979    let joined_room = {
 980        let room = session
 981            .db()
 982            .await
 983            .join_room(room_id, session.user_id, session.connection_id)
 984            .await?;
 985        room_updated(&room.room, &session.peer);
 986        room.into_inner()
 987    };
 988
 989    if let Some(channel_id) = joined_room.channel_id {
 990        channel_updated(
 991            channel_id,
 992            &joined_room.room,
 993            &joined_room.channel_members,
 994            &session.peer,
 995            &*session.connection_pool().await,
 996        )
 997    }
 998
 999    for connection_id in session
1000        .connection_pool()
1001        .await
1002        .user_connection_ids(session.user_id)
1003    {
1004        session
1005            .peer
1006            .send(
1007                connection_id,
1008                proto::CallCanceled {
1009                    room_id: room_id.to_proto(),
1010                },
1011            )
1012            .trace_err();
1013    }
1014
1015    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1016        if let Some(token) = live_kit
1017            .room_token(
1018                &joined_room.room.live_kit_room,
1019                &session.user_id.to_string(),
1020            )
1021            .trace_err()
1022        {
1023            Some(proto::LiveKitConnectionInfo {
1024                server_url: live_kit.url().into(),
1025                token,
1026            })
1027        } else {
1028            None
1029        }
1030    } else {
1031        None
1032    };
1033
1034    response.send(proto::JoinRoomResponse {
1035        room: Some(joined_room.room),
1036        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1037        live_kit_connection_info,
1038    })?;
1039
1040    update_user_contacts(session.user_id, &session).await?;
1041    Ok(())
1042}
1043
1044async fn rejoin_room(
1045    request: proto::RejoinRoom,
1046    response: Response<proto::RejoinRoom>,
1047    session: Session,
1048) -> Result<()> {
1049    let room;
1050    let channel_id;
1051    let channel_members;
1052    {
1053        let mut rejoined_room = session
1054            .db()
1055            .await
1056            .rejoin_room(request, session.user_id, session.connection_id)
1057            .await?;
1058
1059        response.send(proto::RejoinRoomResponse {
1060            room: Some(rejoined_room.room.clone()),
1061            reshared_projects: rejoined_room
1062                .reshared_projects
1063                .iter()
1064                .map(|project| proto::ResharedProject {
1065                    id: project.id.to_proto(),
1066                    collaborators: project
1067                        .collaborators
1068                        .iter()
1069                        .map(|collaborator| collaborator.to_proto())
1070                        .collect(),
1071                })
1072                .collect(),
1073            rejoined_projects: rejoined_room
1074                .rejoined_projects
1075                .iter()
1076                .map(|rejoined_project| proto::RejoinedProject {
1077                    id: rejoined_project.id.to_proto(),
1078                    worktrees: rejoined_project
1079                        .worktrees
1080                        .iter()
1081                        .map(|worktree| proto::WorktreeMetadata {
1082                            id: worktree.id,
1083                            root_name: worktree.root_name.clone(),
1084                            visible: worktree.visible,
1085                            abs_path: worktree.abs_path.clone(),
1086                        })
1087                        .collect(),
1088                    collaborators: rejoined_project
1089                        .collaborators
1090                        .iter()
1091                        .map(|collaborator| collaborator.to_proto())
1092                        .collect(),
1093                    language_servers: rejoined_project.language_servers.clone(),
1094                })
1095                .collect(),
1096        })?;
1097        room_updated(&rejoined_room.room, &session.peer);
1098
1099        for project in &rejoined_room.reshared_projects {
1100            for collaborator in &project.collaborators {
1101                session
1102                    .peer
1103                    .send(
1104                        collaborator.connection_id,
1105                        proto::UpdateProjectCollaborator {
1106                            project_id: project.id.to_proto(),
1107                            old_peer_id: Some(project.old_connection_id.into()),
1108                            new_peer_id: Some(session.connection_id.into()),
1109                        },
1110                    )
1111                    .trace_err();
1112            }
1113
1114            broadcast(
1115                Some(session.connection_id),
1116                project
1117                    .collaborators
1118                    .iter()
1119                    .map(|collaborator| collaborator.connection_id),
1120                |connection_id| {
1121                    session.peer.forward_send(
1122                        session.connection_id,
1123                        connection_id,
1124                        proto::UpdateProject {
1125                            project_id: project.id.to_proto(),
1126                            worktrees: project.worktrees.clone(),
1127                        },
1128                    )
1129                },
1130            );
1131        }
1132
1133        for project in &rejoined_room.rejoined_projects {
1134            for collaborator in &project.collaborators {
1135                session
1136                    .peer
1137                    .send(
1138                        collaborator.connection_id,
1139                        proto::UpdateProjectCollaborator {
1140                            project_id: project.id.to_proto(),
1141                            old_peer_id: Some(project.old_connection_id.into()),
1142                            new_peer_id: Some(session.connection_id.into()),
1143                        },
1144                    )
1145                    .trace_err();
1146            }
1147        }
1148
1149        for project in &mut rejoined_room.rejoined_projects {
1150            for worktree in mem::take(&mut project.worktrees) {
1151                #[cfg(any(test, feature = "test-support"))]
1152                const MAX_CHUNK_SIZE: usize = 2;
1153                #[cfg(not(any(test, feature = "test-support")))]
1154                const MAX_CHUNK_SIZE: usize = 256;
1155
1156                // Stream this worktree's entries.
1157                let message = proto::UpdateWorktree {
1158                    project_id: project.id.to_proto(),
1159                    worktree_id: worktree.id,
1160                    abs_path: worktree.abs_path.clone(),
1161                    root_name: worktree.root_name,
1162                    updated_entries: worktree.updated_entries,
1163                    removed_entries: worktree.removed_entries,
1164                    scan_id: worktree.scan_id,
1165                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1166                    updated_repositories: worktree.updated_repositories,
1167                    removed_repositories: worktree.removed_repositories,
1168                };
1169                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1170                    session.peer.send(session.connection_id, update.clone())?;
1171                }
1172
1173                // Stream this worktree's diagnostics.
1174                for summary in worktree.diagnostic_summaries {
1175                    session.peer.send(
1176                        session.connection_id,
1177                        proto::UpdateDiagnosticSummary {
1178                            project_id: project.id.to_proto(),
1179                            worktree_id: worktree.id,
1180                            summary: Some(summary),
1181                        },
1182                    )?;
1183                }
1184
1185                for settings_file in worktree.settings_files {
1186                    session.peer.send(
1187                        session.connection_id,
1188                        proto::UpdateWorktreeSettings {
1189                            project_id: project.id.to_proto(),
1190                            worktree_id: worktree.id,
1191                            path: settings_file.path,
1192                            content: Some(settings_file.content),
1193                        },
1194                    )?;
1195                }
1196            }
1197
1198            for language_server in &project.language_servers {
1199                session.peer.send(
1200                    session.connection_id,
1201                    proto::UpdateLanguageServer {
1202                        project_id: project.id.to_proto(),
1203                        language_server_id: language_server.id,
1204                        variant: Some(
1205                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1206                                proto::LspDiskBasedDiagnosticsUpdated {},
1207                            ),
1208                        ),
1209                    },
1210                )?;
1211            }
1212        }
1213
1214        let rejoined_room = rejoined_room.into_inner();
1215
1216        room = rejoined_room.room;
1217        channel_id = rejoined_room.channel_id;
1218        channel_members = rejoined_room.channel_members;
1219    }
1220
1221    if let Some(channel_id) = channel_id {
1222        channel_updated(
1223            channel_id,
1224            &room,
1225            &channel_members,
1226            &session.peer,
1227            &*session.connection_pool().await,
1228        );
1229    }
1230
1231    update_user_contacts(session.user_id, &session).await?;
1232    Ok(())
1233}
1234
1235async fn leave_room(
1236    _: proto::LeaveRoom,
1237    response: Response<proto::LeaveRoom>,
1238    session: Session,
1239) -> Result<()> {
1240    leave_room_for_session(&session).await?;
1241    response.send(proto::Ack {})?;
1242    Ok(())
1243}
1244
1245async fn call(
1246    request: proto::Call,
1247    response: Response<proto::Call>,
1248    session: Session,
1249) -> Result<()> {
1250    let room_id = RoomId::from_proto(request.room_id);
1251    let calling_user_id = session.user_id;
1252    let calling_connection_id = session.connection_id;
1253    let called_user_id = UserId::from_proto(request.called_user_id);
1254    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1255    if !session
1256        .db()
1257        .await
1258        .has_contact(calling_user_id, called_user_id)
1259        .await?
1260    {
1261        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1262    }
1263
1264    let incoming_call = {
1265        let (room, incoming_call) = &mut *session
1266            .db()
1267            .await
1268            .call(
1269                room_id,
1270                calling_user_id,
1271                calling_connection_id,
1272                called_user_id,
1273                initial_project_id,
1274            )
1275            .await?;
1276        room_updated(&room, &session.peer);
1277        mem::take(incoming_call)
1278    };
1279    update_user_contacts(called_user_id, &session).await?;
1280
1281    let mut calls = session
1282        .connection_pool()
1283        .await
1284        .user_connection_ids(called_user_id)
1285        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1286        .collect::<FuturesUnordered<_>>();
1287
1288    while let Some(call_response) = calls.next().await {
1289        match call_response.as_ref() {
1290            Ok(_) => {
1291                response.send(proto::Ack {})?;
1292                return Ok(());
1293            }
1294            Err(_) => {
1295                call_response.trace_err();
1296            }
1297        }
1298    }
1299
1300    {
1301        let room = session
1302            .db()
1303            .await
1304            .call_failed(room_id, called_user_id)
1305            .await?;
1306        room_updated(&room, &session.peer);
1307    }
1308    update_user_contacts(called_user_id, &session).await?;
1309
1310    Err(anyhow!("failed to ring user"))?
1311}
1312
1313async fn cancel_call(
1314    request: proto::CancelCall,
1315    response: Response<proto::CancelCall>,
1316    session: Session,
1317) -> Result<()> {
1318    let called_user_id = UserId::from_proto(request.called_user_id);
1319    let room_id = RoomId::from_proto(request.room_id);
1320    {
1321        let room = session
1322            .db()
1323            .await
1324            .cancel_call(room_id, session.connection_id, called_user_id)
1325            .await?;
1326        room_updated(&room, &session.peer);
1327    }
1328
1329    for connection_id in session
1330        .connection_pool()
1331        .await
1332        .user_connection_ids(called_user_id)
1333    {
1334        session
1335            .peer
1336            .send(
1337                connection_id,
1338                proto::CallCanceled {
1339                    room_id: room_id.to_proto(),
1340                },
1341            )
1342            .trace_err();
1343    }
1344    response.send(proto::Ack {})?;
1345
1346    update_user_contacts(called_user_id, &session).await?;
1347    Ok(())
1348}
1349
1350async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1351    let room_id = RoomId::from_proto(message.room_id);
1352    {
1353        let room = session
1354            .db()
1355            .await
1356            .decline_call(Some(room_id), session.user_id)
1357            .await?
1358            .ok_or_else(|| anyhow!("failed to decline call"))?;
1359        room_updated(&room, &session.peer);
1360    }
1361
1362    for connection_id in session
1363        .connection_pool()
1364        .await
1365        .user_connection_ids(session.user_id)
1366    {
1367        session
1368            .peer
1369            .send(
1370                connection_id,
1371                proto::CallCanceled {
1372                    room_id: room_id.to_proto(),
1373                },
1374            )
1375            .trace_err();
1376    }
1377    update_user_contacts(session.user_id, &session).await?;
1378    Ok(())
1379}
1380
1381async fn update_participant_location(
1382    request: proto::UpdateParticipantLocation,
1383    response: Response<proto::UpdateParticipantLocation>,
1384    session: Session,
1385) -> Result<()> {
1386    let room_id = RoomId::from_proto(request.room_id);
1387    let location = request
1388        .location
1389        .ok_or_else(|| anyhow!("invalid location"))?;
1390
1391    let db = session.db().await;
1392    let room = db
1393        .update_room_participant_location(room_id, session.connection_id, location)
1394        .await?;
1395
1396    room_updated(&room, &session.peer);
1397    response.send(proto::Ack {})?;
1398    Ok(())
1399}
1400
1401async fn share_project(
1402    request: proto::ShareProject,
1403    response: Response<proto::ShareProject>,
1404    session: Session,
1405) -> Result<()> {
1406    let (project_id, room) = &*session
1407        .db()
1408        .await
1409        .share_project(
1410            RoomId::from_proto(request.room_id),
1411            session.connection_id,
1412            &request.worktrees,
1413        )
1414        .await?;
1415    response.send(proto::ShareProjectResponse {
1416        project_id: project_id.to_proto(),
1417    })?;
1418    room_updated(&room, &session.peer);
1419
1420    Ok(())
1421}
1422
1423async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1424    let project_id = ProjectId::from_proto(message.project_id);
1425
1426    let (room, guest_connection_ids) = &*session
1427        .db()
1428        .await
1429        .unshare_project(project_id, session.connection_id)
1430        .await?;
1431
1432    broadcast(
1433        Some(session.connection_id),
1434        guest_connection_ids.iter().copied(),
1435        |conn_id| session.peer.send(conn_id, message.clone()),
1436    );
1437    room_updated(&room, &session.peer);
1438
1439    Ok(())
1440}
1441
1442async fn join_project(
1443    request: proto::JoinProject,
1444    response: Response<proto::JoinProject>,
1445    session: Session,
1446) -> Result<()> {
1447    let project_id = ProjectId::from_proto(request.project_id);
1448    let guest_user_id = session.user_id;
1449
1450    tracing::info!(%project_id, "join project");
1451
1452    let (project, replica_id) = &mut *session
1453        .db()
1454        .await
1455        .join_project(project_id, session.connection_id)
1456        .await?;
1457
1458    let collaborators = project
1459        .collaborators
1460        .iter()
1461        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1462        .map(|collaborator| collaborator.to_proto())
1463        .collect::<Vec<_>>();
1464
1465    let worktrees = project
1466        .worktrees
1467        .iter()
1468        .map(|(id, worktree)| proto::WorktreeMetadata {
1469            id: *id,
1470            root_name: worktree.root_name.clone(),
1471            visible: worktree.visible,
1472            abs_path: worktree.abs_path.clone(),
1473        })
1474        .collect::<Vec<_>>();
1475
1476    for collaborator in &collaborators {
1477        session
1478            .peer
1479            .send(
1480                collaborator.peer_id.unwrap().into(),
1481                proto::AddProjectCollaborator {
1482                    project_id: project_id.to_proto(),
1483                    collaborator: Some(proto::Collaborator {
1484                        peer_id: Some(session.connection_id.into()),
1485                        replica_id: replica_id.0 as u32,
1486                        user_id: guest_user_id.to_proto(),
1487                    }),
1488                },
1489            )
1490            .trace_err();
1491    }
1492
1493    // First, we send the metadata associated with each worktree.
1494    response.send(proto::JoinProjectResponse {
1495        worktrees: worktrees.clone(),
1496        replica_id: replica_id.0 as u32,
1497        collaborators: collaborators.clone(),
1498        language_servers: project.language_servers.clone(),
1499    })?;
1500
1501    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1502        #[cfg(any(test, feature = "test-support"))]
1503        const MAX_CHUNK_SIZE: usize = 2;
1504        #[cfg(not(any(test, feature = "test-support")))]
1505        const MAX_CHUNK_SIZE: usize = 256;
1506
1507        // Stream this worktree's entries.
1508        let message = proto::UpdateWorktree {
1509            project_id: project_id.to_proto(),
1510            worktree_id,
1511            abs_path: worktree.abs_path.clone(),
1512            root_name: worktree.root_name,
1513            updated_entries: worktree.entries,
1514            removed_entries: Default::default(),
1515            scan_id: worktree.scan_id,
1516            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1517            updated_repositories: worktree.repository_entries.into_values().collect(),
1518            removed_repositories: Default::default(),
1519        };
1520        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1521            session.peer.send(session.connection_id, update.clone())?;
1522        }
1523
1524        // Stream this worktree's diagnostics.
1525        for summary in worktree.diagnostic_summaries {
1526            session.peer.send(
1527                session.connection_id,
1528                proto::UpdateDiagnosticSummary {
1529                    project_id: project_id.to_proto(),
1530                    worktree_id: worktree.id,
1531                    summary: Some(summary),
1532                },
1533            )?;
1534        }
1535
1536        for settings_file in worktree.settings_files {
1537            session.peer.send(
1538                session.connection_id,
1539                proto::UpdateWorktreeSettings {
1540                    project_id: project_id.to_proto(),
1541                    worktree_id: worktree.id,
1542                    path: settings_file.path,
1543                    content: Some(settings_file.content),
1544                },
1545            )?;
1546        }
1547    }
1548
1549    for language_server in &project.language_servers {
1550        session.peer.send(
1551            session.connection_id,
1552            proto::UpdateLanguageServer {
1553                project_id: project_id.to_proto(),
1554                language_server_id: language_server.id,
1555                variant: Some(
1556                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1557                        proto::LspDiskBasedDiagnosticsUpdated {},
1558                    ),
1559                ),
1560            },
1561        )?;
1562    }
1563
1564    Ok(())
1565}
1566
1567async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1568    let sender_id = session.connection_id;
1569    let project_id = ProjectId::from_proto(request.project_id);
1570
1571    let (room, project) = &*session
1572        .db()
1573        .await
1574        .leave_project(project_id, sender_id)
1575        .await?;
1576    tracing::info!(
1577        %project_id,
1578        host_user_id = %project.host_user_id,
1579        host_connection_id = %project.host_connection_id,
1580        "leave project"
1581    );
1582
1583    project_left(&project, &session);
1584    room_updated(&room, &session.peer);
1585
1586    Ok(())
1587}
1588
1589async fn update_project(
1590    request: proto::UpdateProject,
1591    response: Response<proto::UpdateProject>,
1592    session: Session,
1593) -> Result<()> {
1594    let project_id = ProjectId::from_proto(request.project_id);
1595    let (room, guest_connection_ids) = &*session
1596        .db()
1597        .await
1598        .update_project(project_id, session.connection_id, &request.worktrees)
1599        .await?;
1600    broadcast(
1601        Some(session.connection_id),
1602        guest_connection_ids.iter().copied(),
1603        |connection_id| {
1604            session
1605                .peer
1606                .forward_send(session.connection_id, connection_id, request.clone())
1607        },
1608    );
1609    room_updated(&room, &session.peer);
1610    response.send(proto::Ack {})?;
1611
1612    Ok(())
1613}
1614
1615async fn update_worktree(
1616    request: proto::UpdateWorktree,
1617    response: Response<proto::UpdateWorktree>,
1618    session: Session,
1619) -> Result<()> {
1620    let guest_connection_ids = session
1621        .db()
1622        .await
1623        .update_worktree(&request, session.connection_id)
1624        .await?;
1625
1626    broadcast(
1627        Some(session.connection_id),
1628        guest_connection_ids.iter().copied(),
1629        |connection_id| {
1630            session
1631                .peer
1632                .forward_send(session.connection_id, connection_id, request.clone())
1633        },
1634    );
1635    response.send(proto::Ack {})?;
1636    Ok(())
1637}
1638
1639async fn update_diagnostic_summary(
1640    message: proto::UpdateDiagnosticSummary,
1641    session: Session,
1642) -> Result<()> {
1643    let guest_connection_ids = session
1644        .db()
1645        .await
1646        .update_diagnostic_summary(&message, session.connection_id)
1647        .await?;
1648
1649    broadcast(
1650        Some(session.connection_id),
1651        guest_connection_ids.iter().copied(),
1652        |connection_id| {
1653            session
1654                .peer
1655                .forward_send(session.connection_id, connection_id, message.clone())
1656        },
1657    );
1658
1659    Ok(())
1660}
1661
1662async fn update_worktree_settings(
1663    message: proto::UpdateWorktreeSettings,
1664    session: Session,
1665) -> Result<()> {
1666    let guest_connection_ids = session
1667        .db()
1668        .await
1669        .update_worktree_settings(&message, session.connection_id)
1670        .await?;
1671
1672    broadcast(
1673        Some(session.connection_id),
1674        guest_connection_ids.iter().copied(),
1675        |connection_id| {
1676            session
1677                .peer
1678                .forward_send(session.connection_id, connection_id, message.clone())
1679        },
1680    );
1681
1682    Ok(())
1683}
1684
1685async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1686    broadcast_project_message(request.project_id, request, session).await
1687}
1688
1689async fn start_language_server(
1690    request: proto::StartLanguageServer,
1691    session: Session,
1692) -> Result<()> {
1693    let guest_connection_ids = session
1694        .db()
1695        .await
1696        .start_language_server(&request, session.connection_id)
1697        .await?;
1698
1699    broadcast(
1700        Some(session.connection_id),
1701        guest_connection_ids.iter().copied(),
1702        |connection_id| {
1703            session
1704                .peer
1705                .forward_send(session.connection_id, connection_id, request.clone())
1706        },
1707    );
1708    Ok(())
1709}
1710
1711async fn update_language_server(
1712    request: proto::UpdateLanguageServer,
1713    session: Session,
1714) -> Result<()> {
1715    session.executor.record_backtrace();
1716    let project_id = ProjectId::from_proto(request.project_id);
1717    let project_connection_ids = session
1718        .db()
1719        .await
1720        .project_connection_ids(project_id, session.connection_id)
1721        .await?;
1722    broadcast(
1723        Some(session.connection_id),
1724        project_connection_ids.iter().copied(),
1725        |connection_id| {
1726            session
1727                .peer
1728                .forward_send(session.connection_id, connection_id, request.clone())
1729        },
1730    );
1731    Ok(())
1732}
1733
1734async fn forward_project_request<T>(
1735    request: T,
1736    response: Response<T>,
1737    session: Session,
1738) -> Result<()>
1739where
1740    T: EntityMessage + RequestMessage,
1741{
1742    session.executor.record_backtrace();
1743    let project_id = ProjectId::from_proto(request.remote_entity_id());
1744    let host_connection_id = {
1745        let collaborators = session
1746            .db()
1747            .await
1748            .project_collaborators(project_id, session.connection_id)
1749            .await?;
1750        collaborators
1751            .iter()
1752            .find(|collaborator| collaborator.is_host)
1753            .ok_or_else(|| anyhow!("host not found"))?
1754            .connection_id
1755    };
1756
1757    let payload = session
1758        .peer
1759        .forward_request(session.connection_id, host_connection_id, request)
1760        .await?;
1761
1762    response.send(payload)?;
1763    Ok(())
1764}
1765
1766async fn create_buffer_for_peer(
1767    request: proto::CreateBufferForPeer,
1768    session: Session,
1769) -> Result<()> {
1770    session.executor.record_backtrace();
1771    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1772    session
1773        .peer
1774        .forward_send(session.connection_id, peer_id.into(), request)?;
1775    Ok(())
1776}
1777
1778async fn update_buffer(
1779    request: proto::UpdateBuffer,
1780    response: Response<proto::UpdateBuffer>,
1781    session: Session,
1782) -> Result<()> {
1783    session.executor.record_backtrace();
1784    let project_id = ProjectId::from_proto(request.project_id);
1785    let mut guest_connection_ids;
1786    let mut host_connection_id = None;
1787    {
1788        let collaborators = session
1789            .db()
1790            .await
1791            .project_collaborators(project_id, session.connection_id)
1792            .await?;
1793        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1794        for collaborator in collaborators.iter() {
1795            if collaborator.is_host {
1796                host_connection_id = Some(collaborator.connection_id);
1797            } else {
1798                guest_connection_ids.push(collaborator.connection_id);
1799            }
1800        }
1801    }
1802    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1803
1804    session.executor.record_backtrace();
1805    broadcast(
1806        Some(session.connection_id),
1807        guest_connection_ids,
1808        |connection_id| {
1809            session
1810                .peer
1811                .forward_send(session.connection_id, connection_id, request.clone())
1812        },
1813    );
1814    if host_connection_id != session.connection_id {
1815        session
1816            .peer
1817            .forward_request(session.connection_id, host_connection_id, request.clone())
1818            .await?;
1819    }
1820
1821    response.send(proto::Ack {})?;
1822    Ok(())
1823}
1824
1825async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1826    let project_id = ProjectId::from_proto(request.project_id);
1827    let project_connection_ids = session
1828        .db()
1829        .await
1830        .project_connection_ids(project_id, session.connection_id)
1831        .await?;
1832
1833    broadcast(
1834        Some(session.connection_id),
1835        project_connection_ids.iter().copied(),
1836        |connection_id| {
1837            session
1838                .peer
1839                .forward_send(session.connection_id, connection_id, request.clone())
1840        },
1841    );
1842    Ok(())
1843}
1844
1845async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1846    let project_id = ProjectId::from_proto(request.project_id);
1847    let project_connection_ids = session
1848        .db()
1849        .await
1850        .project_connection_ids(project_id, session.connection_id)
1851        .await?;
1852    broadcast(
1853        Some(session.connection_id),
1854        project_connection_ids.iter().copied(),
1855        |connection_id| {
1856            session
1857                .peer
1858                .forward_send(session.connection_id, connection_id, request.clone())
1859        },
1860    );
1861    Ok(())
1862}
1863
1864async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1865    broadcast_project_message(request.project_id, request, session).await
1866}
1867
1868async fn broadcast_project_message<T: EnvelopedMessage>(
1869    project_id: u64,
1870    request: T,
1871    session: Session,
1872) -> Result<()> {
1873    let project_id = ProjectId::from_proto(project_id);
1874    let project_connection_ids = session
1875        .db()
1876        .await
1877        .project_connection_ids(project_id, session.connection_id)
1878        .await?;
1879    broadcast(
1880        Some(session.connection_id),
1881        project_connection_ids.iter().copied(),
1882        |connection_id| {
1883            session
1884                .peer
1885                .forward_send(session.connection_id, connection_id, request.clone())
1886        },
1887    );
1888    Ok(())
1889}
1890
1891async fn follow(
1892    request: proto::Follow,
1893    response: Response<proto::Follow>,
1894    session: Session,
1895) -> Result<()> {
1896    let room_id = RoomId::from_proto(request.room_id);
1897    let project_id = request.project_id.map(ProjectId::from_proto);
1898    let leader_id = request
1899        .leader_id
1900        .ok_or_else(|| anyhow!("invalid leader id"))?
1901        .into();
1902    let follower_id = session.connection_id;
1903
1904    session
1905        .db()
1906        .await
1907        .check_room_participants(room_id, leader_id, session.connection_id)
1908        .await?;
1909
1910    let response_payload = session
1911        .peer
1912        .forward_request(session.connection_id, leader_id, request)
1913        .await?;
1914    response.send(response_payload)?;
1915
1916    if let Some(project_id) = project_id {
1917        let room = session
1918            .db()
1919            .await
1920            .follow(room_id, project_id, leader_id, follower_id)
1921            .await?;
1922        room_updated(&room, &session.peer);
1923    }
1924
1925    Ok(())
1926}
1927
1928async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1929    let room_id = RoomId::from_proto(request.room_id);
1930    let project_id = request.project_id.map(ProjectId::from_proto);
1931    let leader_id = request
1932        .leader_id
1933        .ok_or_else(|| anyhow!("invalid leader id"))?
1934        .into();
1935    let follower_id = session.connection_id;
1936
1937    session
1938        .db()
1939        .await
1940        .check_room_participants(room_id, leader_id, session.connection_id)
1941        .await?;
1942
1943    session
1944        .peer
1945        .forward_send(session.connection_id, leader_id, request)?;
1946
1947    if let Some(project_id) = project_id {
1948        let room = session
1949            .db()
1950            .await
1951            .unfollow(room_id, project_id, leader_id, follower_id)
1952            .await?;
1953        room_updated(&room, &session.peer);
1954    }
1955
1956    Ok(())
1957}
1958
1959async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1960    let room_id = RoomId::from_proto(request.room_id);
1961    let database = session.db.lock().await;
1962
1963    let connection_ids = if let Some(project_id) = request.project_id {
1964        let project_id = ProjectId::from_proto(project_id);
1965        database
1966            .project_connection_ids(project_id, session.connection_id)
1967            .await?
1968    } else {
1969        database
1970            .room_connection_ids(room_id, session.connection_id)
1971            .await?
1972    };
1973
1974    // For now, don't send view update messages back to that view's current leader.
1975    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1976        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1977        _ => None,
1978    });
1979
1980    for follower_peer_id in request.follower_ids.iter().copied() {
1981        let follower_connection_id = follower_peer_id.into();
1982        if Some(follower_peer_id) != connection_id_to_omit
1983            && connection_ids.contains(&follower_connection_id)
1984        {
1985            session.peer.forward_send(
1986                session.connection_id,
1987                follower_connection_id,
1988                request.clone(),
1989            )?;
1990        }
1991    }
1992    Ok(())
1993}
1994
1995async fn get_users(
1996    request: proto::GetUsers,
1997    response: Response<proto::GetUsers>,
1998    session: Session,
1999) -> Result<()> {
2000    let user_ids = request
2001        .user_ids
2002        .into_iter()
2003        .map(UserId::from_proto)
2004        .collect();
2005    let users = session
2006        .db()
2007        .await
2008        .get_users_by_ids(user_ids)
2009        .await?
2010        .into_iter()
2011        .map(|user| proto::User {
2012            id: user.id.to_proto(),
2013            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2014            github_login: user.github_login,
2015        })
2016        .collect();
2017    response.send(proto::UsersResponse { users })?;
2018    Ok(())
2019}
2020
2021async fn fuzzy_search_users(
2022    request: proto::FuzzySearchUsers,
2023    response: Response<proto::FuzzySearchUsers>,
2024    session: Session,
2025) -> Result<()> {
2026    let query = request.query;
2027    let users = match query.len() {
2028        0 => vec![],
2029        1 | 2 => session
2030            .db()
2031            .await
2032            .get_user_by_github_login(&query)
2033            .await?
2034            .into_iter()
2035            .collect(),
2036        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2037    };
2038    let users = users
2039        .into_iter()
2040        .filter(|user| user.id != session.user_id)
2041        .map(|user| proto::User {
2042            id: user.id.to_proto(),
2043            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2044            github_login: user.github_login,
2045        })
2046        .collect();
2047    response.send(proto::UsersResponse { users })?;
2048    Ok(())
2049}
2050
2051async fn request_contact(
2052    request: proto::RequestContact,
2053    response: Response<proto::RequestContact>,
2054    session: Session,
2055) -> Result<()> {
2056    let requester_id = session.user_id;
2057    let responder_id = UserId::from_proto(request.responder_id);
2058    if requester_id == responder_id {
2059        return Err(anyhow!("cannot add yourself as a contact"))?;
2060    }
2061
2062    session
2063        .db()
2064        .await
2065        .send_contact_request(requester_id, responder_id)
2066        .await?;
2067
2068    // Update outgoing contact requests of requester
2069    let mut update = proto::UpdateContacts::default();
2070    update.outgoing_requests.push(responder_id.to_proto());
2071    for connection_id in session
2072        .connection_pool()
2073        .await
2074        .user_connection_ids(requester_id)
2075    {
2076        session.peer.send(connection_id, update.clone())?;
2077    }
2078
2079    // Update incoming contact requests of responder
2080    let mut update = proto::UpdateContacts::default();
2081    update
2082        .incoming_requests
2083        .push(proto::IncomingContactRequest {
2084            requester_id: requester_id.to_proto(),
2085            should_notify: true,
2086        });
2087    for connection_id in session
2088        .connection_pool()
2089        .await
2090        .user_connection_ids(responder_id)
2091    {
2092        session.peer.send(connection_id, update.clone())?;
2093    }
2094
2095    response.send(proto::Ack {})?;
2096    Ok(())
2097}
2098
2099async fn respond_to_contact_request(
2100    request: proto::RespondToContactRequest,
2101    response: Response<proto::RespondToContactRequest>,
2102    session: Session,
2103) -> Result<()> {
2104    let responder_id = session.user_id;
2105    let requester_id = UserId::from_proto(request.requester_id);
2106    let db = session.db().await;
2107    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2108        db.dismiss_contact_notification(responder_id, requester_id)
2109            .await?;
2110    } else {
2111        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2112
2113        db.respond_to_contact_request(responder_id, requester_id, accept)
2114            .await?;
2115        let requester_busy = db.is_user_busy(requester_id).await?;
2116        let responder_busy = db.is_user_busy(responder_id).await?;
2117
2118        let pool = session.connection_pool().await;
2119        // Update responder with new contact
2120        let mut update = proto::UpdateContacts::default();
2121        if accept {
2122            update
2123                .contacts
2124                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2125        }
2126        update
2127            .remove_incoming_requests
2128            .push(requester_id.to_proto());
2129        for connection_id in pool.user_connection_ids(responder_id) {
2130            session.peer.send(connection_id, update.clone())?;
2131        }
2132
2133        // Update requester with new contact
2134        let mut update = proto::UpdateContacts::default();
2135        if accept {
2136            update
2137                .contacts
2138                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2139        }
2140        update
2141            .remove_outgoing_requests
2142            .push(responder_id.to_proto());
2143        for connection_id in pool.user_connection_ids(requester_id) {
2144            session.peer.send(connection_id, update.clone())?;
2145        }
2146    }
2147
2148    response.send(proto::Ack {})?;
2149    Ok(())
2150}
2151
2152async fn remove_contact(
2153    request: proto::RemoveContact,
2154    response: Response<proto::RemoveContact>,
2155    session: Session,
2156) -> Result<()> {
2157    let requester_id = session.user_id;
2158    let responder_id = UserId::from_proto(request.user_id);
2159    let db = session.db().await;
2160    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2161
2162    let pool = session.connection_pool().await;
2163    // Update outgoing contact requests of requester
2164    let mut update = proto::UpdateContacts::default();
2165    if contact_accepted {
2166        update.remove_contacts.push(responder_id.to_proto());
2167    } else {
2168        update
2169            .remove_outgoing_requests
2170            .push(responder_id.to_proto());
2171    }
2172    for connection_id in pool.user_connection_ids(requester_id) {
2173        session.peer.send(connection_id, update.clone())?;
2174    }
2175
2176    // Update incoming contact requests of responder
2177    let mut update = proto::UpdateContacts::default();
2178    if contact_accepted {
2179        update.remove_contacts.push(requester_id.to_proto());
2180    } else {
2181        update
2182            .remove_incoming_requests
2183            .push(requester_id.to_proto());
2184    }
2185    for connection_id in pool.user_connection_ids(responder_id) {
2186        session.peer.send(connection_id, update.clone())?;
2187    }
2188
2189    response.send(proto::Ack {})?;
2190    Ok(())
2191}
2192
2193async fn create_channel(
2194    request: proto::CreateChannel,
2195    response: Response<proto::CreateChannel>,
2196    session: Session,
2197) -> Result<()> {
2198    let db = session.db().await;
2199    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2200
2201    if let Some(live_kit) = session.live_kit_client.as_ref() {
2202        live_kit.create_room(live_kit_room.clone()).await?;
2203    }
2204
2205    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2206    let id = db
2207        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2208        .await?;
2209
2210    let channel = proto::Channel {
2211        id: id.to_proto(),
2212        name: request.name,
2213    };
2214
2215    response.send(proto::CreateChannelResponse {
2216        channel: Some(channel.clone()),
2217        parent_id: request.parent_id,
2218    })?;
2219
2220    let Some(parent_id) = parent_id else {
2221        return Ok(());
2222    };
2223
2224    let update = proto::UpdateChannels {
2225        channels: vec![channel],
2226        insert_edge: vec![ChannelEdge {
2227            parent_id: parent_id.to_proto(),
2228            channel_id: id.to_proto(),
2229        }],
2230        ..Default::default()
2231    };
2232
2233    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2234
2235    let connection_pool = session.connection_pool().await;
2236    for user_id in user_ids_to_notify {
2237        for connection_id in connection_pool.user_connection_ids(user_id) {
2238            if user_id == session.user_id {
2239                continue;
2240            }
2241            session.peer.send(connection_id, update.clone())?;
2242        }
2243    }
2244
2245    Ok(())
2246}
2247
2248async fn delete_channel(
2249    request: proto::DeleteChannel,
2250    response: Response<proto::DeleteChannel>,
2251    session: Session,
2252) -> Result<()> {
2253    let db = session.db().await;
2254
2255    let channel_id = request.channel_id;
2256    let (removed_channels, member_ids) = db
2257        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2258        .await?;
2259    response.send(proto::Ack {})?;
2260
2261    // Notify members of removed channels
2262    let mut update = proto::UpdateChannels::default();
2263    update
2264        .delete_channels
2265        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2266
2267    let connection_pool = session.connection_pool().await;
2268    for member_id in member_ids {
2269        for connection_id in connection_pool.user_connection_ids(member_id) {
2270            session.peer.send(connection_id, update.clone())?;
2271        }
2272    }
2273
2274    Ok(())
2275}
2276
2277async fn invite_channel_member(
2278    request: proto::InviteChannelMember,
2279    response: Response<proto::InviteChannelMember>,
2280    session: Session,
2281) -> Result<()> {
2282    let db = session.db().await;
2283    let channel_id = ChannelId::from_proto(request.channel_id);
2284    let invitee_id = UserId::from_proto(request.user_id);
2285    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2286        .await?;
2287
2288    let (channel, _) = db
2289        .get_channel(channel_id, session.user_id)
2290        .await?
2291        .ok_or_else(|| anyhow!("channel not found"))?;
2292
2293    let mut update = proto::UpdateChannels::default();
2294    update.channel_invitations.push(proto::Channel {
2295        id: channel.id.to_proto(),
2296        name: channel.name,
2297    });
2298    for connection_id in session
2299        .connection_pool()
2300        .await
2301        .user_connection_ids(invitee_id)
2302    {
2303        session.peer.send(connection_id, update.clone())?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn remove_channel_member(
2311    request: proto::RemoveChannelMember,
2312    response: Response<proto::RemoveChannelMember>,
2313    session: Session,
2314) -> Result<()> {
2315    let db = session.db().await;
2316    let channel_id = ChannelId::from_proto(request.channel_id);
2317    let member_id = UserId::from_proto(request.user_id);
2318
2319    db.remove_channel_member(channel_id, member_id, session.user_id)
2320        .await?;
2321
2322    let mut update = proto::UpdateChannels::default();
2323    update.delete_channels.push(channel_id.to_proto());
2324
2325    for connection_id in session
2326        .connection_pool()
2327        .await
2328        .user_connection_ids(member_id)
2329    {
2330        session.peer.send(connection_id, update.clone())?;
2331    }
2332
2333    response.send(proto::Ack {})?;
2334    Ok(())
2335}
2336
2337async fn set_channel_member_admin(
2338    request: proto::SetChannelMemberAdmin,
2339    response: Response<proto::SetChannelMemberAdmin>,
2340    session: Session,
2341) -> Result<()> {
2342    let db = session.db().await;
2343    let channel_id = ChannelId::from_proto(request.channel_id);
2344    let member_id = UserId::from_proto(request.user_id);
2345    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2346        .await?;
2347
2348    let (channel, has_accepted) = db
2349        .get_channel(channel_id, member_id)
2350        .await?
2351        .ok_or_else(|| anyhow!("channel not found"))?;
2352
2353    let mut update = proto::UpdateChannels::default();
2354    if has_accepted {
2355        update.channel_permissions.push(proto::ChannelPermission {
2356            channel_id: channel.id.to_proto(),
2357            is_admin: request.admin,
2358        });
2359    }
2360
2361    for connection_id in session
2362        .connection_pool()
2363        .await
2364        .user_connection_ids(member_id)
2365    {
2366        session.peer.send(connection_id, update.clone())?;
2367    }
2368
2369    response.send(proto::Ack {})?;
2370    Ok(())
2371}
2372
2373async fn rename_channel(
2374    request: proto::RenameChannel,
2375    response: Response<proto::RenameChannel>,
2376    session: Session,
2377) -> Result<()> {
2378    let db = session.db().await;
2379    let channel_id = ChannelId::from_proto(request.channel_id);
2380    let new_name = db
2381        .rename_channel(channel_id, session.user_id, &request.name)
2382        .await?;
2383
2384    let channel = proto::Channel {
2385        id: request.channel_id,
2386        name: new_name,
2387    };
2388    response.send(proto::RenameChannelResponse {
2389        channel: Some(channel.clone()),
2390    })?;
2391    let mut update = proto::UpdateChannels::default();
2392    update.channels.push(channel);
2393
2394    let member_ids = db.get_channel_members(channel_id).await?;
2395
2396    let connection_pool = session.connection_pool().await;
2397    for member_id in member_ids {
2398        for connection_id in connection_pool.user_connection_ids(member_id) {
2399            session.peer.send(connection_id, update.clone())?;
2400        }
2401    }
2402
2403    Ok(())
2404}
2405
2406async fn link_channel(
2407    request: proto::LinkChannel,
2408    response: Response<proto::LinkChannel>,
2409    session: Session,
2410) -> Result<()> {
2411    let db = session.db().await;
2412    let channel_id = ChannelId::from_proto(request.channel_id);
2413    let to = ChannelId::from_proto(request.to);
2414    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2415
2416    let members = db.get_channel_members(to).await?;
2417    let connection_pool = session.connection_pool().await;
2418    let update = proto::UpdateChannels {
2419        channels: channels_to_send
2420            .channels
2421            .into_iter()
2422            .map(|channel| proto::Channel {
2423                id: channel.id.to_proto(),
2424                name: channel.name,
2425            })
2426            .collect(),
2427        insert_edge: channels_to_send.edges,
2428        ..Default::default()
2429    };
2430    for member_id in members {
2431        for connection_id in connection_pool.user_connection_ids(member_id) {
2432            session.peer.send(connection_id, update.clone())?;
2433        }
2434    }
2435
2436    response.send(Ack {})?;
2437
2438    Ok(())
2439}
2440
2441async fn unlink_channel(
2442    request: proto::UnlinkChannel,
2443    response: Response<proto::UnlinkChannel>,
2444    session: Session,
2445) -> Result<()> {
2446    let db = session.db().await;
2447    let channel_id = ChannelId::from_proto(request.channel_id);
2448    let from = ChannelId::from_proto(request.from);
2449
2450    db.unlink_channel(session.user_id, channel_id, from).await?;
2451
2452    let members = db.get_channel_members(from).await?;
2453
2454    let update = proto::UpdateChannels {
2455        delete_edge: vec![proto::ChannelEdge {
2456            channel_id: channel_id.to_proto(),
2457            parent_id: from.to_proto(),
2458        }],
2459        ..Default::default()
2460    };
2461    let connection_pool = session.connection_pool().await;
2462    for member_id in members {
2463        for connection_id in connection_pool.user_connection_ids(member_id) {
2464            session.peer.send(connection_id, update.clone())?;
2465        }
2466    }
2467
2468    response.send(Ack {})?;
2469
2470    Ok(())
2471}
2472
2473async fn move_channel(
2474    request: proto::MoveChannel,
2475    response: Response<proto::MoveChannel>,
2476    session: Session,
2477) -> Result<()> {
2478    let db = session.db().await;
2479    let channel_id = ChannelId::from_proto(request.channel_id);
2480    let from_parent = ChannelId::from_proto(request.from);
2481    let to = ChannelId::from_proto(request.to);
2482
2483    let channels_to_send = db
2484        .move_channel(session.user_id, channel_id, from_parent, to)
2485        .await?;
2486
2487    if channels_to_send.is_empty() {
2488        response.send(Ack {})?;
2489        return Ok(());
2490    }
2491
2492    let members_from = db.get_channel_members(from_parent).await?;
2493    let members_to = db.get_channel_members(to).await?;
2494
2495    let update = proto::UpdateChannels {
2496        delete_edge: vec![proto::ChannelEdge {
2497            channel_id: channel_id.to_proto(),
2498            parent_id: from_parent.to_proto(),
2499        }],
2500        ..Default::default()
2501    };
2502    let connection_pool = session.connection_pool().await;
2503    for member_id in members_from {
2504        for connection_id in connection_pool.user_connection_ids(member_id) {
2505            session.peer.send(connection_id, update.clone())?;
2506        }
2507    }
2508
2509    let update = proto::UpdateChannels {
2510        channels: channels_to_send
2511            .channels
2512            .into_iter()
2513            .map(|channel| proto::Channel {
2514                id: channel.id.to_proto(),
2515                name: channel.name,
2516            })
2517            .collect(),
2518        insert_edge: channels_to_send.edges,
2519        ..Default::default()
2520    };
2521    for member_id in members_to {
2522        for connection_id in connection_pool.user_connection_ids(member_id) {
2523            session.peer.send(connection_id, update.clone())?;
2524        }
2525    }
2526
2527    response.send(Ack {})?;
2528
2529    Ok(())
2530}
2531
2532async fn get_channel_members(
2533    request: proto::GetChannelMembers,
2534    response: Response<proto::GetChannelMembers>,
2535    session: Session,
2536) -> Result<()> {
2537    let db = session.db().await;
2538    let channel_id = ChannelId::from_proto(request.channel_id);
2539    let members = db
2540        .get_channel_member_details(channel_id, session.user_id)
2541        .await?;
2542    response.send(proto::GetChannelMembersResponse { members })?;
2543    Ok(())
2544}
2545
2546async fn respond_to_channel_invite(
2547    request: proto::RespondToChannelInvite,
2548    response: Response<proto::RespondToChannelInvite>,
2549    session: Session,
2550) -> Result<()> {
2551    let db = session.db().await;
2552    let channel_id = ChannelId::from_proto(request.channel_id);
2553    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2554        .await?;
2555
2556    let mut update = proto::UpdateChannels::default();
2557    update
2558        .remove_channel_invitations
2559        .push(channel_id.to_proto());
2560    if request.accept {
2561        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2562        update
2563            .channels
2564            .extend(
2565                result
2566                    .channels
2567                    .channels
2568                    .into_iter()
2569                    .map(|channel| proto::Channel {
2570                        id: channel.id.to_proto(),
2571                        name: channel.name,
2572                    }),
2573            );
2574        update.unseen_channel_messages = result.channel_messages;
2575        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2576        update.insert_edge = result.channels.edges;
2577        update
2578            .channel_participants
2579            .extend(
2580                result
2581                    .channel_participants
2582                    .into_iter()
2583                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2584                        channel_id: channel_id.to_proto(),
2585                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2586                    }),
2587            );
2588        update
2589            .channel_permissions
2590            .extend(
2591                result
2592                    .channels_with_admin_privileges
2593                    .into_iter()
2594                    .map(|channel_id| proto::ChannelPermission {
2595                        channel_id: channel_id.to_proto(),
2596                        is_admin: true,
2597                    }),
2598            );
2599    }
2600    session.peer.send(session.connection_id, update)?;
2601    response.send(proto::Ack {})?;
2602
2603    Ok(())
2604}
2605
2606async fn join_channel(
2607    request: proto::JoinChannel,
2608    response: Response<proto::JoinChannel>,
2609    session: Session,
2610) -> Result<()> {
2611    let channel_id = ChannelId::from_proto(request.channel_id);
2612
2613    let joined_room = {
2614        leave_room_for_session(&session).await?;
2615        let db = session.db().await;
2616
2617        let room_id = db.room_id_for_channel(channel_id).await?;
2618
2619        let joined_room = db
2620            .join_room(room_id, session.user_id, session.connection_id)
2621            .await?;
2622
2623        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2624            let token = live_kit
2625                .room_token(
2626                    &joined_room.room.live_kit_room,
2627                    &session.user_id.to_string(),
2628                )
2629                .trace_err()?;
2630
2631            Some(LiveKitConnectionInfo {
2632                server_url: live_kit.url().into(),
2633                token,
2634            })
2635        });
2636
2637        response.send(proto::JoinRoomResponse {
2638            room: Some(joined_room.room.clone()),
2639            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2640            live_kit_connection_info,
2641        })?;
2642
2643        room_updated(&joined_room.room, &session.peer);
2644
2645        joined_room.into_inner()
2646    };
2647
2648    channel_updated(
2649        channel_id,
2650        &joined_room.room,
2651        &joined_room.channel_members,
2652        &session.peer,
2653        &*session.connection_pool().await,
2654    );
2655
2656    update_user_contacts(session.user_id, &session).await?;
2657
2658    Ok(())
2659}
2660
2661async fn join_channel_buffer(
2662    request: proto::JoinChannelBuffer,
2663    response: Response<proto::JoinChannelBuffer>,
2664    session: Session,
2665) -> Result<()> {
2666    let db = session.db().await;
2667    let channel_id = ChannelId::from_proto(request.channel_id);
2668
2669    let open_response = db
2670        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2671        .await?;
2672
2673    let collaborators = open_response.collaborators.clone();
2674    response.send(open_response)?;
2675
2676    let update = UpdateChannelBufferCollaborators {
2677        channel_id: channel_id.to_proto(),
2678        collaborators: collaborators.clone(),
2679    };
2680    channel_buffer_updated(
2681        session.connection_id,
2682        collaborators
2683            .iter()
2684            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2685        &update,
2686        &session.peer,
2687    );
2688
2689    Ok(())
2690}
2691
2692async fn update_channel_buffer(
2693    request: proto::UpdateChannelBuffer,
2694    session: Session,
2695) -> Result<()> {
2696    let db = session.db().await;
2697    let channel_id = ChannelId::from_proto(request.channel_id);
2698
2699    let (collaborators, non_collaborators, epoch, version) = db
2700        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2701        .await?;
2702
2703    channel_buffer_updated(
2704        session.connection_id,
2705        collaborators,
2706        &proto::UpdateChannelBuffer {
2707            channel_id: channel_id.to_proto(),
2708            operations: request.operations,
2709        },
2710        &session.peer,
2711    );
2712
2713    let pool = &*session.connection_pool().await;
2714
2715    broadcast(
2716        None,
2717        non_collaborators
2718            .iter()
2719            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2720        |peer_id| {
2721            session.peer.send(
2722                peer_id.into(),
2723                proto::UpdateChannels {
2724                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2725                        channel_id: channel_id.to_proto(),
2726                        epoch: epoch as u64,
2727                        version: version.clone(),
2728                    }],
2729                    ..Default::default()
2730                },
2731            )
2732        },
2733    );
2734
2735    Ok(())
2736}
2737
2738async fn rejoin_channel_buffers(
2739    request: proto::RejoinChannelBuffers,
2740    response: Response<proto::RejoinChannelBuffers>,
2741    session: Session,
2742) -> Result<()> {
2743    let db = session.db().await;
2744    let buffers = db
2745        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2746        .await?;
2747
2748    for rejoined_buffer in &buffers {
2749        let collaborators_to_notify = rejoined_buffer
2750            .buffer
2751            .collaborators
2752            .iter()
2753            .filter_map(|c| Some(c.peer_id?.into()));
2754        channel_buffer_updated(
2755            session.connection_id,
2756            collaborators_to_notify,
2757            &proto::UpdateChannelBufferCollaborators {
2758                channel_id: rejoined_buffer.buffer.channel_id,
2759                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2760            },
2761            &session.peer,
2762        );
2763    }
2764
2765    response.send(proto::RejoinChannelBuffersResponse {
2766        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2767    })?;
2768
2769    Ok(())
2770}
2771
2772async fn leave_channel_buffer(
2773    request: proto::LeaveChannelBuffer,
2774    response: Response<proto::LeaveChannelBuffer>,
2775    session: Session,
2776) -> Result<()> {
2777    let db = session.db().await;
2778    let channel_id = ChannelId::from_proto(request.channel_id);
2779
2780    let left_buffer = db
2781        .leave_channel_buffer(channel_id, session.connection_id)
2782        .await?;
2783
2784    response.send(Ack {})?;
2785
2786    channel_buffer_updated(
2787        session.connection_id,
2788        left_buffer.connections,
2789        &proto::UpdateChannelBufferCollaborators {
2790            channel_id: channel_id.to_proto(),
2791            collaborators: left_buffer.collaborators,
2792        },
2793        &session.peer,
2794    );
2795
2796    Ok(())
2797}
2798
2799fn channel_buffer_updated<T: EnvelopedMessage>(
2800    sender_id: ConnectionId,
2801    collaborators: impl IntoIterator<Item = ConnectionId>,
2802    message: &T,
2803    peer: &Peer,
2804) {
2805    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2806        peer.send(peer_id.into(), message.clone())
2807    });
2808}
2809
2810async fn send_channel_message(
2811    request: proto::SendChannelMessage,
2812    response: Response<proto::SendChannelMessage>,
2813    session: Session,
2814) -> Result<()> {
2815    // Validate the message body.
2816    let body = request.body.trim().to_string();
2817    if body.len() > MAX_MESSAGE_LEN {
2818        return Err(anyhow!("message is too long"))?;
2819    }
2820    if body.is_empty() {
2821        return Err(anyhow!("message can't be blank"))?;
2822    }
2823
2824    let timestamp = OffsetDateTime::now_utc();
2825    let nonce = request
2826        .nonce
2827        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2828
2829    let channel_id = ChannelId::from_proto(request.channel_id);
2830    let (message_id, connection_ids, non_participants) = session
2831        .db()
2832        .await
2833        .create_channel_message(
2834            channel_id,
2835            session.user_id,
2836            &body,
2837            timestamp,
2838            nonce.clone().into(),
2839        )
2840        .await?;
2841    let message = proto::ChannelMessage {
2842        sender_id: session.user_id.to_proto(),
2843        id: message_id.to_proto(),
2844        body,
2845        timestamp: timestamp.unix_timestamp() as u64,
2846        nonce: Some(nonce),
2847    };
2848    broadcast(Some(session.connection_id), connection_ids, |connection| {
2849        session.peer.send(
2850            connection,
2851            proto::ChannelMessageSent {
2852                channel_id: channel_id.to_proto(),
2853                message: Some(message.clone()),
2854            },
2855        )
2856    });
2857    response.send(proto::SendChannelMessageResponse {
2858        message: Some(message),
2859    })?;
2860
2861    let pool = &*session.connection_pool().await;
2862    broadcast(
2863        None,
2864        non_participants
2865            .iter()
2866            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2867        |peer_id| {
2868            session.peer.send(
2869                peer_id.into(),
2870                proto::UpdateChannels {
2871                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2872                        channel_id: channel_id.to_proto(),
2873                        message_id: message_id.to_proto(),
2874                    }],
2875                    ..Default::default()
2876                },
2877            )
2878        },
2879    );
2880
2881    Ok(())
2882}
2883
2884async fn remove_channel_message(
2885    request: proto::RemoveChannelMessage,
2886    response: Response<proto::RemoveChannelMessage>,
2887    session: Session,
2888) -> Result<()> {
2889    let channel_id = ChannelId::from_proto(request.channel_id);
2890    let message_id = MessageId::from_proto(request.message_id);
2891    let connection_ids = session
2892        .db()
2893        .await
2894        .remove_channel_message(channel_id, message_id, session.user_id)
2895        .await?;
2896    broadcast(Some(session.connection_id), connection_ids, |connection| {
2897        session.peer.send(connection, request.clone())
2898    });
2899    response.send(proto::Ack {})?;
2900    Ok(())
2901}
2902
2903async fn acknowledge_channel_message(
2904    request: proto::AckChannelMessage,
2905    session: Session,
2906) -> Result<()> {
2907    let channel_id = ChannelId::from_proto(request.channel_id);
2908    let message_id = MessageId::from_proto(request.message_id);
2909    session
2910        .db()
2911        .await
2912        .observe_channel_message(channel_id, session.user_id, message_id)
2913        .await?;
2914    Ok(())
2915}
2916
2917async fn acknowledge_buffer_version(
2918    request: proto::AckBufferOperation,
2919    session: Session,
2920) -> Result<()> {
2921    let buffer_id = BufferId::from_proto(request.buffer_id);
2922    session
2923        .db()
2924        .await
2925        .observe_buffer_version(
2926            buffer_id,
2927            session.user_id,
2928            request.epoch as i32,
2929            &request.version,
2930        )
2931        .await?;
2932    Ok(())
2933}
2934
2935async fn join_channel_chat(
2936    request: proto::JoinChannelChat,
2937    response: Response<proto::JoinChannelChat>,
2938    session: Session,
2939) -> Result<()> {
2940    let channel_id = ChannelId::from_proto(request.channel_id);
2941
2942    let db = session.db().await;
2943    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2944        .await?;
2945    let messages = db
2946        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2947        .await?;
2948    response.send(proto::JoinChannelChatResponse {
2949        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2950        messages,
2951    })?;
2952    Ok(())
2953}
2954
2955async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2956    let channel_id = ChannelId::from_proto(request.channel_id);
2957    session
2958        .db()
2959        .await
2960        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2961        .await?;
2962    Ok(())
2963}
2964
2965async fn get_channel_messages(
2966    request: proto::GetChannelMessages,
2967    response: Response<proto::GetChannelMessages>,
2968    session: Session,
2969) -> Result<()> {
2970    let channel_id = ChannelId::from_proto(request.channel_id);
2971    let messages = session
2972        .db()
2973        .await
2974        .get_channel_messages(
2975            channel_id,
2976            session.user_id,
2977            MESSAGE_COUNT_PER_PAGE,
2978            Some(MessageId::from_proto(request.before_message_id)),
2979        )
2980        .await?;
2981    response.send(proto::GetChannelMessagesResponse {
2982        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2983        messages,
2984    })?;
2985    Ok(())
2986}
2987
2988async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2989    let project_id = ProjectId::from_proto(request.project_id);
2990    let project_connection_ids = session
2991        .db()
2992        .await
2993        .project_connection_ids(project_id, session.connection_id)
2994        .await?;
2995    broadcast(
2996        Some(session.connection_id),
2997        project_connection_ids.iter().copied(),
2998        |connection_id| {
2999            session
3000                .peer
3001                .forward_send(session.connection_id, connection_id, request.clone())
3002        },
3003    );
3004    Ok(())
3005}
3006
3007async fn get_private_user_info(
3008    _request: proto::GetPrivateUserInfo,
3009    response: Response<proto::GetPrivateUserInfo>,
3010    session: Session,
3011) -> Result<()> {
3012    let db = session.db().await;
3013
3014    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3015    let user = db
3016        .get_user_by_id(session.user_id)
3017        .await?
3018        .ok_or_else(|| anyhow!("user not found"))?;
3019    let flags = db.get_user_flags(session.user_id).await?;
3020
3021    response.send(proto::GetPrivateUserInfoResponse {
3022        metrics_id,
3023        staff: user.admin,
3024        flags,
3025    })?;
3026    Ok(())
3027}
3028
3029fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3030    match message {
3031        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3032        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3033        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3034        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3035        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3036            code: frame.code.into(),
3037            reason: frame.reason,
3038        })),
3039    }
3040}
3041
3042fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3043    match message {
3044        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3045        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3046        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3047        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3048        AxumMessage::Close(frame) => {
3049            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3050                code: frame.code.into(),
3051                reason: frame.reason,
3052            }))
3053        }
3054    }
3055}
3056
3057fn build_initial_channels_update(
3058    channels: ChannelsForUser,
3059    channel_invites: Vec<db::Channel>,
3060) -> proto::UpdateChannels {
3061    let mut update = proto::UpdateChannels::default();
3062
3063    for channel in channels.channels.channels {
3064        update.channels.push(proto::Channel {
3065            id: channel.id.to_proto(),
3066            name: channel.name,
3067        });
3068    }
3069
3070    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3071    update.unseen_channel_messages = channels.channel_messages;
3072    update.insert_edge = channels.channels.edges;
3073
3074    for (channel_id, participants) in channels.channel_participants {
3075        update
3076            .channel_participants
3077            .push(proto::ChannelParticipants {
3078                channel_id: channel_id.to_proto(),
3079                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3080            });
3081    }
3082
3083    update
3084        .channel_permissions
3085        .extend(
3086            channels
3087                .channels_with_admin_privileges
3088                .into_iter()
3089                .map(|id| proto::ChannelPermission {
3090                    channel_id: id.to_proto(),
3091                    is_admin: true,
3092                }),
3093        );
3094
3095    for channel in channel_invites {
3096        update.channel_invitations.push(proto::Channel {
3097            id: channel.id.to_proto(),
3098            name: channel.name,
3099        });
3100    }
3101
3102    update
3103}
3104
3105fn build_initial_contacts_update(
3106    contacts: Vec<db::Contact>,
3107    pool: &ConnectionPool,
3108) -> proto::UpdateContacts {
3109    let mut update = proto::UpdateContacts::default();
3110
3111    for contact in contacts {
3112        match contact {
3113            db::Contact::Accepted {
3114                user_id,
3115                should_notify,
3116                busy,
3117            } => {
3118                update
3119                    .contacts
3120                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3121            }
3122            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3123            db::Contact::Incoming {
3124                user_id,
3125                should_notify,
3126            } => update
3127                .incoming_requests
3128                .push(proto::IncomingContactRequest {
3129                    requester_id: user_id.to_proto(),
3130                    should_notify,
3131                }),
3132        }
3133    }
3134
3135    update
3136}
3137
3138fn contact_for_user(
3139    user_id: UserId,
3140    should_notify: bool,
3141    busy: bool,
3142    pool: &ConnectionPool,
3143) -> proto::Contact {
3144    proto::Contact {
3145        user_id: user_id.to_proto(),
3146        online: pool.is_user_online(user_id),
3147        busy,
3148        should_notify,
3149    }
3150}
3151
3152fn room_updated(room: &proto::Room, peer: &Peer) {
3153    broadcast(
3154        None,
3155        room.participants
3156            .iter()
3157            .filter_map(|participant| Some(participant.peer_id?.into())),
3158        |peer_id| {
3159            peer.send(
3160                peer_id.into(),
3161                proto::RoomUpdated {
3162                    room: Some(room.clone()),
3163                },
3164            )
3165        },
3166    );
3167}
3168
3169fn channel_updated(
3170    channel_id: ChannelId,
3171    room: &proto::Room,
3172    channel_members: &[UserId],
3173    peer: &Peer,
3174    pool: &ConnectionPool,
3175) {
3176    let participants = room
3177        .participants
3178        .iter()
3179        .map(|p| p.user_id)
3180        .collect::<Vec<_>>();
3181
3182    broadcast(
3183        None,
3184        channel_members
3185            .iter()
3186            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3187        |peer_id| {
3188            peer.send(
3189                peer_id.into(),
3190                proto::UpdateChannels {
3191                    channel_participants: vec![proto::ChannelParticipants {
3192                        channel_id: channel_id.to_proto(),
3193                        participant_user_ids: participants.clone(),
3194                    }],
3195                    ..Default::default()
3196                },
3197            )
3198        },
3199    );
3200}
3201
3202async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3203    let db = session.db().await;
3204
3205    let contacts = db.get_contacts(user_id).await?;
3206    let busy = db.is_user_busy(user_id).await?;
3207
3208    let pool = session.connection_pool().await;
3209    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3210    for contact in contacts {
3211        if let db::Contact::Accepted {
3212            user_id: contact_user_id,
3213            ..
3214        } = contact
3215        {
3216            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3217                session
3218                    .peer
3219                    .send(
3220                        contact_conn_id,
3221                        proto::UpdateContacts {
3222                            contacts: vec![updated_contact.clone()],
3223                            remove_contacts: Default::default(),
3224                            incoming_requests: Default::default(),
3225                            remove_incoming_requests: Default::default(),
3226                            outgoing_requests: Default::default(),
3227                            remove_outgoing_requests: Default::default(),
3228                        },
3229                    )
3230                    .trace_err();
3231            }
3232        }
3233    }
3234    Ok(())
3235}
3236
3237async fn leave_room_for_session(session: &Session) -> Result<()> {
3238    let mut contacts_to_update = HashSet::default();
3239
3240    let room_id;
3241    let canceled_calls_to_user_ids;
3242    let live_kit_room;
3243    let delete_live_kit_room;
3244    let room;
3245    let channel_members;
3246    let channel_id;
3247
3248    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3249        contacts_to_update.insert(session.user_id);
3250
3251        for project in left_room.left_projects.values() {
3252            project_left(project, session);
3253        }
3254
3255        room_id = RoomId::from_proto(left_room.room.id);
3256        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3257        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3258        delete_live_kit_room = left_room.deleted;
3259        room = mem::take(&mut left_room.room);
3260        channel_members = mem::take(&mut left_room.channel_members);
3261        channel_id = left_room.channel_id;
3262
3263        room_updated(&room, &session.peer);
3264    } else {
3265        return Ok(());
3266    }
3267
3268    if let Some(channel_id) = channel_id {
3269        channel_updated(
3270            channel_id,
3271            &room,
3272            &channel_members,
3273            &session.peer,
3274            &*session.connection_pool().await,
3275        );
3276    }
3277
3278    {
3279        let pool = session.connection_pool().await;
3280        for canceled_user_id in canceled_calls_to_user_ids {
3281            for connection_id in pool.user_connection_ids(canceled_user_id) {
3282                session
3283                    .peer
3284                    .send(
3285                        connection_id,
3286                        proto::CallCanceled {
3287                            room_id: room_id.to_proto(),
3288                        },
3289                    )
3290                    .trace_err();
3291            }
3292            contacts_to_update.insert(canceled_user_id);
3293        }
3294    }
3295
3296    for contact_user_id in contacts_to_update {
3297        update_user_contacts(contact_user_id, &session).await?;
3298    }
3299
3300    if let Some(live_kit) = session.live_kit_client.as_ref() {
3301        live_kit
3302            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3303            .await
3304            .trace_err();
3305
3306        if delete_live_kit_room {
3307            live_kit.delete_room(live_kit_room).await.trace_err();
3308        }
3309    }
3310
3311    Ok(())
3312}
3313
3314async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3315    let left_channel_buffers = session
3316        .db()
3317        .await
3318        .leave_channel_buffers(session.connection_id)
3319        .await?;
3320
3321    for left_buffer in left_channel_buffers {
3322        channel_buffer_updated(
3323            session.connection_id,
3324            left_buffer.connections,
3325            &proto::UpdateChannelBufferCollaborators {
3326                channel_id: left_buffer.channel_id.to_proto(),
3327                collaborators: left_buffer.collaborators,
3328            },
3329            &session.peer,
3330        );
3331    }
3332
3333    Ok(())
3334}
3335
3336fn project_left(project: &db::LeftProject, session: &Session) {
3337    for connection_id in &project.connection_ids {
3338        if project.host_user_id == session.user_id {
3339            session
3340                .peer
3341                .send(
3342                    *connection_id,
3343                    proto::UnshareProject {
3344                        project_id: project.id.to_proto(),
3345                    },
3346                )
3347                .trace_err();
3348        } else {
3349            session
3350                .peer
3351                .send(
3352                    *connection_id,
3353                    proto::RemoveProjectCollaborator {
3354                        project_id: project.id.to_proto(),
3355                        peer_id: Some(session.connection_id.into()),
3356                    },
3357                )
3358                .trace_err();
3359        }
3360    }
3361}
3362
/// Extension for fallible values whose errors should be reported rather than
/// propagated.
pub trait ResultExt {
    /// The success value produced when there is no error.
    type Ok;

    /// Converts `self` into an `Option`, discarding any error.
    ///
    /// The `Result` implementation in this module logs the error before
    /// returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3368
3369impl<T, E> ResultExt for Result<T, E>
3370where
3371    E: std::fmt::Debug,
3372{
3373    type Ok = T;
3374
3375    fn trace_err(self) -> Option<T> {
3376        match self {
3377            Ok(value) => Some(value),
3378            Err(error) => {
3379                tracing::error!("{:?}", error);
3380                None
3381            }
3382        }
3383    }
3384}