rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId,
   7        ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
/// Reconnection window (30 seconds). Not referenced in this part of the file;
/// presumably how long a disconnected client may take to reconnect before its
/// state is released — NOTE(review): confirm against `connection_lost`.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long (10 seconds) the background task spawned by `Server::start` sleeps
/// before asking the database for stale rooms and channel buffers and cleaning
/// them up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

/// Page size of 100. Not referenced in this part of the file; presumably the
/// number of channel messages returned per history request — NOTE(review):
/// confirm against `get_channel_messages`.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Length limit of 1024. Not referenced in this part of the file; presumably
/// the maximum accepted channel message length (unit — bytes vs. chars — not
/// confirmed) — NOTE(review): confirm against `send_channel_message`.
const MAX_MESSAGE_LEN: usize = 1024;
  73
// Prometheus gauges created lazily on first access with `register_int_gauge!`;
// the `unwrap()` panics if registration fails. Where these gauges are updated
// is not visible in this part of the file.
lazy_static! {
    // Gauge "connections": "number of connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge "shared_projects": "number of open projects with one or more guests".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  83
/// Type-erased message handler stored in `Server::handlers`.
///
/// It takes a boxed envelope plus the receiving connection's `Session` and
/// returns a boxed `'static` future that processes the message. These are built
/// by `Server::add_handler`, which downcasts the envelope back to its concrete
/// `TypedEnvelope<M>`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
/// Per-connection state handed (by clone) to every message handler.
///
/// Built in `Server::handle_connection` once the connection has been set up.
#[derive(Clone)]
struct Session {
    /// The user that owns this connection.
    user_id: UserId,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. `handle_connection` creates a
    /// fresh mutex per connection, so clones of one session share it while
    /// other connections have their own. Access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer, used to send messages and replies.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if `AppState` has one configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// The executor that was passed to `handle_connection`.
    executor: Executor,
}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
/// The collaboration RPC server. It owns the peer, the pool of live
/// connections, and the table of registered message handlers.
pub struct Server {
    /// This server's id. It sits behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer, shared with every `Session`.
    peer: Arc<Peer>,
    /// Connections currently tracked by this server, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; every `handle_connection` loop subscribes to it.
    teardown: watch::Sender<()>,
}
 161
/// Lock guard over the connection pool that is deliberately `!Send`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard neither `Send` nor `Sync`.
/// As a result, a future that holds it across an `.await` is not `Send`, which
/// helps catch code that would keep the synchronous pool lock while suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of server state: the peer plus the locked connection pool.
/// The pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as the pool the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
    /// Creates a server with the given `id`, registers a handler for every RPC
    /// message type it serves, and returns it wrapped in an `Arc`.
    ///
    /// `add_request_handler` registers messages that must be answered (the
    /// wrapper reports an error if the handler never replies).
    /// `add_message_handler` registers messages that get no reply on the
    /// handler's behalf. Registering two handlers for the same message type
    /// panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // The peer id is derived from the server id with an `as u32` cast.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests handled by the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 285
    /// Schedules cleanup of resources left behind by stale servers and returns
    /// `Ok(())` immediately.
    ///
    /// A detached background task sleeps for `CLEANUP_TIMEOUT`, then asks the
    /// database (for this environment and server id) which room and channel ids
    /// are stale, and then:
    /// * for each channel, clears stale buffer collaborators and sends
    ///   `UpdateChannelBufferCollaborators` to the connections still on the buffer;
    /// * for each room, clears stale participants and calls `room_updated` (and
    ///   `channel_updated` if the room belongs to a channel). It then sends
    ///   `CallCanceled` to users whose calls were canceled and pushes a refreshed
    ///   contact entry for each affected user to their accepted contacts. If a
    ///   LiveKit client is configured and the room has no participants left, it
    ///   also deletes the LiveKit room;
    /// * finally, it calls `delete_stale_servers`.
    ///
    /// Every fallible step inside the task goes through `trace_err` instead of
    /// being propagated, so one failing step does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then notify everyone affected.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Send `CallCanceled` to every connection of each user
                        // whose pending call into this room was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-send each affected user's contact entry (with fresh
                        // busy status) to every connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once it has no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Stale resources handled; now remove the stale server records.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 437
 438    pub fn teardown(&self) {
 439        self.peer.teardown();
 440        self.connection_pool.lock().reset();
 441        let _ = self.teardown.send(());
 442    }
 443
 444    #[cfg(test)]
 445    pub fn reset(&self, id: ServerId) {
 446        self.teardown();
 447        *self.id.lock() = id;
 448        self.peer.reset(id.0 as u32);
 449    }
 450
 451    #[cfg(test)]
 452    pub fn id(&self) -> ServerId {
 453        *self.id.lock()
 454    }
 455
 456    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 457    where
 458        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 459        Fut: 'static + Send + Future<Output = Result<()>>,
 460        M: EnvelopedMessage,
 461    {
 462        let prev_handler = self.handlers.insert(
 463            TypeId::of::<M>(),
 464            Box::new(move |envelope, session| {
 465                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 466                let span = info_span!(
 467                    "handle message",
 468                    payload_type = envelope.payload_type_name()
 469                );
 470                span.in_scope(|| {
 471                    tracing::info!(
 472                        payload_type = envelope.payload_type_name(),
 473                        "message received"
 474                    );
 475                });
 476                let start_time = Instant::now();
 477                let future = (handler)(*envelope, session);
 478                async move {
 479                    let result = future.await;
 480                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 481                    match result {
 482                        Err(error) => {
 483                            tracing::error!(%error, ?duration_ms, "error handling message")
 484                        }
 485                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 486                    }
 487                }
 488                .instrument(span)
 489                .boxed()
 490            }),
 491        );
 492        if prev_handler.is_some() {
 493            panic!("registered a handler for the same message twice");
 494        }
 495        self
 496    }
 497
 498    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 499    where
 500        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 501        Fut: 'static + Send + Future<Output = Result<()>>,
 502        M: EnvelopedMessage,
 503    {
 504        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 505        self
 506    }
 507
 508    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 511        Fut: Send + Future<Output = Result<()>>,
 512        M: RequestMessage,
 513    {
 514        let handler = Arc::new(handler);
 515        self.add_handler(move |envelope, session| {
 516            let receipt = envelope.receipt();
 517            let handler = handler.clone();
 518            async move {
 519                let peer = session.peer.clone();
 520                let responded = Arc::new(AtomicBool::default());
 521                let response = Response {
 522                    peer: peer.clone(),
 523                    responded: responded.clone(),
 524                    receipt,
 525                };
 526                match (handler)(envelope.payload, response, session).await {
 527                    Ok(()) => {
 528                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 529                            Ok(())
 530                        } else {
 531                            Err(anyhow!("handler did not send a response"))?
 532                        }
 533                    }
 534                    Err(error) => {
 535                        peer.respond_with_error(
 536                            receipt,
 537                            proto::Error {
 538                                message: error.to_string(),
 539                            },
 540                        )?;
 541                        Err(error)
 542                    }
 543                }
 544            }
 545        })
 546    }
 547
    /// Drives a newly established client connection until it closes or the
    /// server is torn down.
    ///
    /// The returned future, instrumented with a "handle connection" span, does the following:
    /// 1. registers `connection` with the peer and sends `Hello`;
    /// 2. reports the new `ConnectionId` through `send_connection_id`, if given;
    /// 3. if `user.connected_once` is false, sends `ShowContacts` and records
    ///    that the user has now connected;
    /// 4. loads contacts, channels and channel invites, adds the connection to
    ///    the pool, sends the initial contacts/channels updates and any pending
    ///    incoming call, then calls `update_user_contacts`;
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or `teardown` is signalled;
    /// 6. after a close or I/O failure (but not after teardown), cancels the
    ///    in-flight foreground handlers and runs `connection_lost`.
    ///
    /// NOTE(review): the `?` early returns in steps 1–4 exit without running
    /// `connection_lost`. A connection already added to the peer (and, from
    /// step 4 on, to the pool) may therefore never be removed if cleanup relies
    /// on `connection_lost`. Confirm whether this is handled elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection. A permit
            // is acquired before the next message is read and released when that
            // message's handler finishes (or immediately if no handler exists).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O, then progress on
                // foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    // Teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run on their own detached task;
                                // all others are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in progress.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 682
 683    pub async fn invite_code_redeemed(
 684        self: &Arc<Self>,
 685        inviter_id: UserId,
 686        invitee_id: UserId,
 687    ) -> Result<()> {
 688        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 689            if let Some(code) = &user.invite_code {
 690                let pool = self.connection_pool.lock();
 691                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 692                for connection_id in pool.user_connection_ids(inviter_id) {
 693                    self.peer.send(
 694                        connection_id,
 695                        proto::UpdateContacts {
 696                            contacts: vec![invitee_contact.clone()],
 697                            ..Default::default()
 698                        },
 699                    )?;
 700                    self.peer.send(
 701                        connection_id,
 702                        proto::UpdateInviteInfo {
 703                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 704                            count: user.invite_count as u32,
 705                        },
 706                    )?;
 707                }
 708            }
 709        }
 710        Ok(())
 711    }
 712
 713    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 714        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 715            if let Some(invite_code) = &user.invite_code {
 716                let pool = self.connection_pool.lock();
 717                for connection_id in pool.user_connection_ids(user_id) {
 718                    self.peer.send(
 719                        connection_id,
 720                        proto::UpdateInviteInfo {
 721                            url: format!(
 722                                "{}{}",
 723                                self.app_state.config.invite_link_prefix, invite_code
 724                            ),
 725                            count: user.invite_count as u32,
 726                        },
 727                    )?;
 728                }
 729            }
 730        }
 731        Ok(())
 732    }
 733
 734    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 735        ServerSnapshot {
 736            connection_pool: ConnectionPoolGuard {
 737                guard: self.connection_pool.lock(),
 738                _not_send: PhantomData,
 739            },
 740            peer: &self.peer,
 741        }
 742    }
 743}
 744
 745impl<'a> Deref for ConnectionPoolGuard<'a> {
 746    type Target = ConnectionPool;
 747
 748    fn deref(&self) -> &Self::Target {
 749        &*self.guard
 750    }
 751}
 752
 753impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 754    fn deref_mut(&mut self) -> &mut Self::Target {
 755        &mut *self.guard
 756    }
 757}
 758
 759impl<'a> Drop for ConnectionPoolGuard<'a> {
 760    fn drop(&mut self) {
 761        #[cfg(test)]
 762        self.check_invariants();
 763    }
 764}
 765
 766fn broadcast<F>(
 767    sender_id: Option<ConnectionId>,
 768    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 769    mut f: F,
 770) where
 771    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 772{
 773    for receiver_id in receiver_ids {
 774        if Some(receiver_id) != sender_id {
 775            if let Err(error) = f(receiver_id) {
 776                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 777            }
 778        }
 779    }
 780}
 781
 782lazy_static! {
 783    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 784}
 785
 786pub struct ProtocolVersion(u32);
 787
 788impl Header for ProtocolVersion {
 789    fn name() -> &'static HeaderName {
 790        &ZED_PROTOCOL_VERSION
 791    }
 792
 793    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 794    where
 795        Self: Sized,
 796        I: Iterator<Item = &'i axum::http::HeaderValue>,
 797    {
 798        let version = values
 799            .next()
 800            .ok_or_else(axum::headers::Error::invalid)?
 801            .to_str()
 802            .map_err(|_| axum::headers::Error::invalid())?
 803            .parse()
 804            .map_err(|_| axum::headers::Error::invalid())?;
 805        Ok(Self(version))
 806    }
 807
 808    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 809        values.extend([self.0.to_string().parse().unwrap()]);
 810    }
 811}
 812
 813pub fn routes(server: Arc<Server>) -> Router<Body> {
 814    Router::new()
 815        .route("/rpc", get(handle_websocket_request))
 816        .layer(
 817            ServiceBuilder::new()
 818                .layer(Extension(server.app_state.clone()))
 819                .layer(middleware::from_fn(auth::validate_header)),
 820        )
 821        .route("/metrics", get(handle_metrics))
 822        .layer(Extension(server))
 823}
 824
 825pub async fn handle_websocket_request(
 826    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 827    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 828    Extension(server): Extension<Arc<Server>>,
 829    Extension(user): Extension<User>,
 830    ws: WebSocketUpgrade,
 831) -> axum::response::Response {
 832    if protocol_version != rpc::PROTOCOL_VERSION {
 833        return (
 834            StatusCode::UPGRADE_REQUIRED,
 835            "client must be upgraded".to_string(),
 836        )
 837            .into_response();
 838    }
 839    let socket_address = socket_address.to_string();
 840    ws.on_upgrade(move |socket| {
 841        use util::ResultExt;
 842        let socket = socket
 843            .map_ok(to_tungstenite_message)
 844            .err_into()
 845            .with(|message| async move { Ok(to_axum_message(message)) });
 846        let connection = Connection::new(Box::pin(socket));
 847        async move {
 848            server
 849                .handle_connection(connection, socket_address, user, None, Executor::Production)
 850                .await
 851                .log_err();
 852        }
 853    })
 854}
 855
 856pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 857    let connections = server
 858        .connection_pool
 859        .lock()
 860        .connections()
 861        .filter(|connection| !connection.admin)
 862        .count();
 863
 864    METRIC_CONNECTIONS.set(connections as _);
 865
 866    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 867    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 868
 869    let encoder = prometheus::TextEncoder::new();
 870    let metric_families = prometheus::gather();
 871    let encoded_metrics = encoder
 872        .encode_to_string(&metric_families)
 873        .map_err(|err| anyhow!("{}", err))?;
 874    Ok(encoded_metrics)
 875}
 876
 877#[instrument(err, skip(executor))]
 878async fn connection_lost(
 879    session: Session,
 880    mut teardown: watch::Receiver<()>,
 881    executor: Executor,
 882) -> Result<()> {
 883    session.peer.disconnect(session.connection_id);
 884    session
 885        .connection_pool()
 886        .await
 887        .remove_connection(session.connection_id)?;
 888
 889    session
 890        .db()
 891        .await
 892        .connection_lost(session.connection_id)
 893        .await
 894        .trace_err();
 895
 896    futures::select_biased! {
 897        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 898            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 899            leave_room_for_session(&session).await.trace_err();
 900            leave_channel_buffers_for_session(&session)
 901                .await
 902                .trace_err();
 903
 904            if !session
 905                .connection_pool()
 906                .await
 907                .is_user_online(session.user_id)
 908            {
 909                let db = session.db().await;
 910                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 911                    room_updated(&room, &session.peer);
 912                }
 913            }
 914
 915            update_user_contacts(session.user_id, &session).await?;
 916        }
 917        _ = teardown.changed().fuse() => {}
 918    }
 919
 920    Ok(())
 921}
 922
 923async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 924    response.send(proto::Ack {})?;
 925    Ok(())
 926}
 927
 928async fn create_room(
 929    _request: proto::CreateRoom,
 930    response: Response<proto::CreateRoom>,
 931    session: Session,
 932) -> Result<()> {
 933    let live_kit_room = nanoid::nanoid!(30);
 934
 935    let live_kit_connection_info = {
 936        let live_kit_room = live_kit_room.clone();
 937        let live_kit = session.live_kit_client.as_ref();
 938
 939        util::async_iife!({
 940            let live_kit = live_kit?;
 941
 942            let token = live_kit
 943                .room_token(&live_kit_room, &session.user_id.to_string())
 944                .trace_err()?;
 945
 946            Some(proto::LiveKitConnectionInfo {
 947                server_url: live_kit.url().into(),
 948                token,
 949            })
 950        })
 951    }
 952    .await;
 953
 954    let room = session
 955        .db()
 956        .await
 957        .create_room(
 958            session.user_id,
 959            session.connection_id,
 960            &live_kit_room,
 961            RELEASE_CHANNEL_NAME.as_str(),
 962        )
 963        .await?;
 964
 965    response.send(proto::CreateRoomResponse {
 966        room: Some(room.clone()),
 967        live_kit_connection_info,
 968    })?;
 969
 970    update_user_contacts(session.user_id, &session).await?;
 971    Ok(())
 972}
 973
 974async fn join_room(
 975    request: proto::JoinRoom,
 976    response: Response<proto::JoinRoom>,
 977    session: Session,
 978) -> Result<()> {
 979    let room_id = RoomId::from_proto(request.id);
 980    let joined_room = {
 981        let room = session
 982            .db()
 983            .await
 984            .join_room(
 985                room_id,
 986                session.user_id,
 987                session.connection_id,
 988                RELEASE_CHANNEL_NAME.as_str(),
 989            )
 990            .await?;
 991        room_updated(&room.room, &session.peer);
 992        room.into_inner()
 993    };
 994
 995    if let Some(channel_id) = joined_room.channel_id {
 996        channel_updated(
 997            channel_id,
 998            &joined_room.room,
 999            &joined_room.channel_members,
1000            &session.peer,
1001            &*session.connection_pool().await,
1002        )
1003    }
1004
1005    for connection_id in session
1006        .connection_pool()
1007        .await
1008        .user_connection_ids(session.user_id)
1009    {
1010        session
1011            .peer
1012            .send(
1013                connection_id,
1014                proto::CallCanceled {
1015                    room_id: room_id.to_proto(),
1016                },
1017            )
1018            .trace_err();
1019    }
1020
1021    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1022        if let Some(token) = live_kit
1023            .room_token(
1024                &joined_room.room.live_kit_room,
1025                &session.user_id.to_string(),
1026            )
1027            .trace_err()
1028        {
1029            Some(proto::LiveKitConnectionInfo {
1030                server_url: live_kit.url().into(),
1031                token,
1032            })
1033        } else {
1034            None
1035        }
1036    } else {
1037        None
1038    };
1039
1040    response.send(proto::JoinRoomResponse {
1041        room: Some(joined_room.room),
1042        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1043        live_kit_connection_info,
1044    })?;
1045
1046    update_user_contacts(session.user_id, &session).await?;
1047    Ok(())
1048}
1049
1050async fn rejoin_room(
1051    request: proto::RejoinRoom,
1052    response: Response<proto::RejoinRoom>,
1053    session: Session,
1054) -> Result<()> {
1055    let room;
1056    let channel_id;
1057    let channel_members;
1058    {
1059        let mut rejoined_room = session
1060            .db()
1061            .await
1062            .rejoin_room(request, session.user_id, session.connection_id)
1063            .await?;
1064
1065        response.send(proto::RejoinRoomResponse {
1066            room: Some(rejoined_room.room.clone()),
1067            reshared_projects: rejoined_room
1068                .reshared_projects
1069                .iter()
1070                .map(|project| proto::ResharedProject {
1071                    id: project.id.to_proto(),
1072                    collaborators: project
1073                        .collaborators
1074                        .iter()
1075                        .map(|collaborator| collaborator.to_proto())
1076                        .collect(),
1077                })
1078                .collect(),
1079            rejoined_projects: rejoined_room
1080                .rejoined_projects
1081                .iter()
1082                .map(|rejoined_project| proto::RejoinedProject {
1083                    id: rejoined_project.id.to_proto(),
1084                    worktrees: rejoined_project
1085                        .worktrees
1086                        .iter()
1087                        .map(|worktree| proto::WorktreeMetadata {
1088                            id: worktree.id,
1089                            root_name: worktree.root_name.clone(),
1090                            visible: worktree.visible,
1091                            abs_path: worktree.abs_path.clone(),
1092                        })
1093                        .collect(),
1094                    collaborators: rejoined_project
1095                        .collaborators
1096                        .iter()
1097                        .map(|collaborator| collaborator.to_proto())
1098                        .collect(),
1099                    language_servers: rejoined_project.language_servers.clone(),
1100                })
1101                .collect(),
1102        })?;
1103        room_updated(&rejoined_room.room, &session.peer);
1104
1105        for project in &rejoined_room.reshared_projects {
1106            for collaborator in &project.collaborators {
1107                session
1108                    .peer
1109                    .send(
1110                        collaborator.connection_id,
1111                        proto::UpdateProjectCollaborator {
1112                            project_id: project.id.to_proto(),
1113                            old_peer_id: Some(project.old_connection_id.into()),
1114                            new_peer_id: Some(session.connection_id.into()),
1115                        },
1116                    )
1117                    .trace_err();
1118            }
1119
1120            broadcast(
1121                Some(session.connection_id),
1122                project
1123                    .collaborators
1124                    .iter()
1125                    .map(|collaborator| collaborator.connection_id),
1126                |connection_id| {
1127                    session.peer.forward_send(
1128                        session.connection_id,
1129                        connection_id,
1130                        proto::UpdateProject {
1131                            project_id: project.id.to_proto(),
1132                            worktrees: project.worktrees.clone(),
1133                        },
1134                    )
1135                },
1136            );
1137        }
1138
1139        for project in &rejoined_room.rejoined_projects {
1140            for collaborator in &project.collaborators {
1141                session
1142                    .peer
1143                    .send(
1144                        collaborator.connection_id,
1145                        proto::UpdateProjectCollaborator {
1146                            project_id: project.id.to_proto(),
1147                            old_peer_id: Some(project.old_connection_id.into()),
1148                            new_peer_id: Some(session.connection_id.into()),
1149                        },
1150                    )
1151                    .trace_err();
1152            }
1153        }
1154
1155        for project in &mut rejoined_room.rejoined_projects {
1156            for worktree in mem::take(&mut project.worktrees) {
1157                #[cfg(any(test, feature = "test-support"))]
1158                const MAX_CHUNK_SIZE: usize = 2;
1159                #[cfg(not(any(test, feature = "test-support")))]
1160                const MAX_CHUNK_SIZE: usize = 256;
1161
1162                // Stream this worktree's entries.
1163                let message = proto::UpdateWorktree {
1164                    project_id: project.id.to_proto(),
1165                    worktree_id: worktree.id,
1166                    abs_path: worktree.abs_path.clone(),
1167                    root_name: worktree.root_name,
1168                    updated_entries: worktree.updated_entries,
1169                    removed_entries: worktree.removed_entries,
1170                    scan_id: worktree.scan_id,
1171                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1172                    updated_repositories: worktree.updated_repositories,
1173                    removed_repositories: worktree.removed_repositories,
1174                };
1175                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1176                    session.peer.send(session.connection_id, update.clone())?;
1177                }
1178
1179                // Stream this worktree's diagnostics.
1180                for summary in worktree.diagnostic_summaries {
1181                    session.peer.send(
1182                        session.connection_id,
1183                        proto::UpdateDiagnosticSummary {
1184                            project_id: project.id.to_proto(),
1185                            worktree_id: worktree.id,
1186                            summary: Some(summary),
1187                        },
1188                    )?;
1189                }
1190
1191                for settings_file in worktree.settings_files {
1192                    session.peer.send(
1193                        session.connection_id,
1194                        proto::UpdateWorktreeSettings {
1195                            project_id: project.id.to_proto(),
1196                            worktree_id: worktree.id,
1197                            path: settings_file.path,
1198                            content: Some(settings_file.content),
1199                        },
1200                    )?;
1201                }
1202            }
1203
1204            for language_server in &project.language_servers {
1205                session.peer.send(
1206                    session.connection_id,
1207                    proto::UpdateLanguageServer {
1208                        project_id: project.id.to_proto(),
1209                        language_server_id: language_server.id,
1210                        variant: Some(
1211                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1212                                proto::LspDiskBasedDiagnosticsUpdated {},
1213                            ),
1214                        ),
1215                    },
1216                )?;
1217            }
1218        }
1219
1220        let rejoined_room = rejoined_room.into_inner();
1221
1222        room = rejoined_room.room;
1223        channel_id = rejoined_room.channel_id;
1224        channel_members = rejoined_room.channel_members;
1225    }
1226
1227    if let Some(channel_id) = channel_id {
1228        channel_updated(
1229            channel_id,
1230            &room,
1231            &channel_members,
1232            &session.peer,
1233            &*session.connection_pool().await,
1234        );
1235    }
1236
1237    update_user_contacts(session.user_id, &session).await?;
1238    Ok(())
1239}
1240
1241async fn leave_room(
1242    _: proto::LeaveRoom,
1243    response: Response<proto::LeaveRoom>,
1244    session: Session,
1245) -> Result<()> {
1246    leave_room_for_session(&session).await?;
1247    response.send(proto::Ack {})?;
1248    Ok(())
1249}
1250
1251async fn call(
1252    request: proto::Call,
1253    response: Response<proto::Call>,
1254    session: Session,
1255) -> Result<()> {
1256    let room_id = RoomId::from_proto(request.room_id);
1257    let calling_user_id = session.user_id;
1258    let calling_connection_id = session.connection_id;
1259    let called_user_id = UserId::from_proto(request.called_user_id);
1260    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1261    if !session
1262        .db()
1263        .await
1264        .has_contact(calling_user_id, called_user_id)
1265        .await?
1266    {
1267        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1268    }
1269
1270    let incoming_call = {
1271        let (room, incoming_call) = &mut *session
1272            .db()
1273            .await
1274            .call(
1275                room_id,
1276                calling_user_id,
1277                calling_connection_id,
1278                called_user_id,
1279                initial_project_id,
1280            )
1281            .await?;
1282        room_updated(&room, &session.peer);
1283        mem::take(incoming_call)
1284    };
1285    update_user_contacts(called_user_id, &session).await?;
1286
1287    let mut calls = session
1288        .connection_pool()
1289        .await
1290        .user_connection_ids(called_user_id)
1291        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1292        .collect::<FuturesUnordered<_>>();
1293
1294    while let Some(call_response) = calls.next().await {
1295        match call_response.as_ref() {
1296            Ok(_) => {
1297                response.send(proto::Ack {})?;
1298                return Ok(());
1299            }
1300            Err(_) => {
1301                call_response.trace_err();
1302            }
1303        }
1304    }
1305
1306    {
1307        let room = session
1308            .db()
1309            .await
1310            .call_failed(room_id, called_user_id)
1311            .await?;
1312        room_updated(&room, &session.peer);
1313    }
1314    update_user_contacts(called_user_id, &session).await?;
1315
1316    Err(anyhow!("failed to ring user"))?
1317}
1318
1319async fn cancel_call(
1320    request: proto::CancelCall,
1321    response: Response<proto::CancelCall>,
1322    session: Session,
1323) -> Result<()> {
1324    let called_user_id = UserId::from_proto(request.called_user_id);
1325    let room_id = RoomId::from_proto(request.room_id);
1326    {
1327        let room = session
1328            .db()
1329            .await
1330            .cancel_call(room_id, session.connection_id, called_user_id)
1331            .await?;
1332        room_updated(&room, &session.peer);
1333    }
1334
1335    for connection_id in session
1336        .connection_pool()
1337        .await
1338        .user_connection_ids(called_user_id)
1339    {
1340        session
1341            .peer
1342            .send(
1343                connection_id,
1344                proto::CallCanceled {
1345                    room_id: room_id.to_proto(),
1346                },
1347            )
1348            .trace_err();
1349    }
1350    response.send(proto::Ack {})?;
1351
1352    update_user_contacts(called_user_id, &session).await?;
1353    Ok(())
1354}
1355
1356async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1357    let room_id = RoomId::from_proto(message.room_id);
1358    {
1359        let room = session
1360            .db()
1361            .await
1362            .decline_call(Some(room_id), session.user_id)
1363            .await?
1364            .ok_or_else(|| anyhow!("failed to decline call"))?;
1365        room_updated(&room, &session.peer);
1366    }
1367
1368    for connection_id in session
1369        .connection_pool()
1370        .await
1371        .user_connection_ids(session.user_id)
1372    {
1373        session
1374            .peer
1375            .send(
1376                connection_id,
1377                proto::CallCanceled {
1378                    room_id: room_id.to_proto(),
1379                },
1380            )
1381            .trace_err();
1382    }
1383    update_user_contacts(session.user_id, &session).await?;
1384    Ok(())
1385}
1386
1387async fn update_participant_location(
1388    request: proto::UpdateParticipantLocation,
1389    response: Response<proto::UpdateParticipantLocation>,
1390    session: Session,
1391) -> Result<()> {
1392    let room_id = RoomId::from_proto(request.room_id);
1393    let location = request
1394        .location
1395        .ok_or_else(|| anyhow!("invalid location"))?;
1396
1397    let db = session.db().await;
1398    let room = db
1399        .update_room_participant_location(room_id, session.connection_id, location)
1400        .await?;
1401
1402    room_updated(&room, &session.peer);
1403    response.send(proto::Ack {})?;
1404    Ok(())
1405}
1406
1407async fn share_project(
1408    request: proto::ShareProject,
1409    response: Response<proto::ShareProject>,
1410    session: Session,
1411) -> Result<()> {
1412    let (project_id, room) = &*session
1413        .db()
1414        .await
1415        .share_project(
1416            RoomId::from_proto(request.room_id),
1417            session.connection_id,
1418            &request.worktrees,
1419        )
1420        .await?;
1421    response.send(proto::ShareProjectResponse {
1422        project_id: project_id.to_proto(),
1423    })?;
1424    room_updated(&room, &session.peer);
1425
1426    Ok(())
1427}
1428
1429async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1430    let project_id = ProjectId::from_proto(message.project_id);
1431
1432    let (room, guest_connection_ids) = &*session
1433        .db()
1434        .await
1435        .unshare_project(project_id, session.connection_id)
1436        .await?;
1437
1438    broadcast(
1439        Some(session.connection_id),
1440        guest_connection_ids.iter().copied(),
1441        |conn_id| session.peer.send(conn_id, message.clone()),
1442    );
1443    room_updated(&room, &session.peer);
1444
1445    Ok(())
1446}
1447
1448async fn join_project(
1449    request: proto::JoinProject,
1450    response: Response<proto::JoinProject>,
1451    session: Session,
1452) -> Result<()> {
1453    let project_id = ProjectId::from_proto(request.project_id);
1454    let guest_user_id = session.user_id;
1455
1456    tracing::info!(%project_id, "join project");
1457
1458    let (project, replica_id) = &mut *session
1459        .db()
1460        .await
1461        .join_project(project_id, session.connection_id)
1462        .await?;
1463
1464    let collaborators = project
1465        .collaborators
1466        .iter()
1467        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1468        .map(|collaborator| collaborator.to_proto())
1469        .collect::<Vec<_>>();
1470
1471    let worktrees = project
1472        .worktrees
1473        .iter()
1474        .map(|(id, worktree)| proto::WorktreeMetadata {
1475            id: *id,
1476            root_name: worktree.root_name.clone(),
1477            visible: worktree.visible,
1478            abs_path: worktree.abs_path.clone(),
1479        })
1480        .collect::<Vec<_>>();
1481
1482    for collaborator in &collaborators {
1483        session
1484            .peer
1485            .send(
1486                collaborator.peer_id.unwrap().into(),
1487                proto::AddProjectCollaborator {
1488                    project_id: project_id.to_proto(),
1489                    collaborator: Some(proto::Collaborator {
1490                        peer_id: Some(session.connection_id.into()),
1491                        replica_id: replica_id.0 as u32,
1492                        user_id: guest_user_id.to_proto(),
1493                    }),
1494                },
1495            )
1496            .trace_err();
1497    }
1498
1499    // First, we send the metadata associated with each worktree.
1500    response.send(proto::JoinProjectResponse {
1501        worktrees: worktrees.clone(),
1502        replica_id: replica_id.0 as u32,
1503        collaborators: collaborators.clone(),
1504        language_servers: project.language_servers.clone(),
1505    })?;
1506
1507    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1508        #[cfg(any(test, feature = "test-support"))]
1509        const MAX_CHUNK_SIZE: usize = 2;
1510        #[cfg(not(any(test, feature = "test-support")))]
1511        const MAX_CHUNK_SIZE: usize = 256;
1512
1513        // Stream this worktree's entries.
1514        let message = proto::UpdateWorktree {
1515            project_id: project_id.to_proto(),
1516            worktree_id,
1517            abs_path: worktree.abs_path.clone(),
1518            root_name: worktree.root_name,
1519            updated_entries: worktree.entries,
1520            removed_entries: Default::default(),
1521            scan_id: worktree.scan_id,
1522            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1523            updated_repositories: worktree.repository_entries.into_values().collect(),
1524            removed_repositories: Default::default(),
1525        };
1526        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1527            session.peer.send(session.connection_id, update.clone())?;
1528        }
1529
1530        // Stream this worktree's diagnostics.
1531        for summary in worktree.diagnostic_summaries {
1532            session.peer.send(
1533                session.connection_id,
1534                proto::UpdateDiagnosticSummary {
1535                    project_id: project_id.to_proto(),
1536                    worktree_id: worktree.id,
1537                    summary: Some(summary),
1538                },
1539            )?;
1540        }
1541
1542        for settings_file in worktree.settings_files {
1543            session.peer.send(
1544                session.connection_id,
1545                proto::UpdateWorktreeSettings {
1546                    project_id: project_id.to_proto(),
1547                    worktree_id: worktree.id,
1548                    path: settings_file.path,
1549                    content: Some(settings_file.content),
1550                },
1551            )?;
1552        }
1553    }
1554
1555    for language_server in &project.language_servers {
1556        session.peer.send(
1557            session.connection_id,
1558            proto::UpdateLanguageServer {
1559                project_id: project_id.to_proto(),
1560                language_server_id: language_server.id,
1561                variant: Some(
1562                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1563                        proto::LspDiskBasedDiagnosticsUpdated {},
1564                    ),
1565                ),
1566            },
1567        )?;
1568    }
1569
1570    Ok(())
1571}
1572
1573async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1574    let sender_id = session.connection_id;
1575    let project_id = ProjectId::from_proto(request.project_id);
1576
1577    let (room, project) = &*session
1578        .db()
1579        .await
1580        .leave_project(project_id, sender_id)
1581        .await?;
1582    tracing::info!(
1583        %project_id,
1584        host_user_id = %project.host_user_id,
1585        host_connection_id = %project.host_connection_id,
1586        "leave project"
1587    );
1588
1589    project_left(&project, &session);
1590    room_updated(&room, &session.peer);
1591
1592    Ok(())
1593}
1594
1595async fn update_project(
1596    request: proto::UpdateProject,
1597    response: Response<proto::UpdateProject>,
1598    session: Session,
1599) -> Result<()> {
1600    let project_id = ProjectId::from_proto(request.project_id);
1601    let (room, guest_connection_ids) = &*session
1602        .db()
1603        .await
1604        .update_project(project_id, session.connection_id, &request.worktrees)
1605        .await?;
1606    broadcast(
1607        Some(session.connection_id),
1608        guest_connection_ids.iter().copied(),
1609        |connection_id| {
1610            session
1611                .peer
1612                .forward_send(session.connection_id, connection_id, request.clone())
1613        },
1614    );
1615    room_updated(&room, &session.peer);
1616    response.send(proto::Ack {})?;
1617
1618    Ok(())
1619}
1620
1621async fn update_worktree(
1622    request: proto::UpdateWorktree,
1623    response: Response<proto::UpdateWorktree>,
1624    session: Session,
1625) -> Result<()> {
1626    let guest_connection_ids = session
1627        .db()
1628        .await
1629        .update_worktree(&request, session.connection_id)
1630        .await?;
1631
1632    broadcast(
1633        Some(session.connection_id),
1634        guest_connection_ids.iter().copied(),
1635        |connection_id| {
1636            session
1637                .peer
1638                .forward_send(session.connection_id, connection_id, request.clone())
1639        },
1640    );
1641    response.send(proto::Ack {})?;
1642    Ok(())
1643}
1644
1645async fn update_diagnostic_summary(
1646    message: proto::UpdateDiagnosticSummary,
1647    session: Session,
1648) -> Result<()> {
1649    let guest_connection_ids = session
1650        .db()
1651        .await
1652        .update_diagnostic_summary(&message, session.connection_id)
1653        .await?;
1654
1655    broadcast(
1656        Some(session.connection_id),
1657        guest_connection_ids.iter().copied(),
1658        |connection_id| {
1659            session
1660                .peer
1661                .forward_send(session.connection_id, connection_id, message.clone())
1662        },
1663    );
1664
1665    Ok(())
1666}
1667
1668async fn update_worktree_settings(
1669    message: proto::UpdateWorktreeSettings,
1670    session: Session,
1671) -> Result<()> {
1672    let guest_connection_ids = session
1673        .db()
1674        .await
1675        .update_worktree_settings(&message, session.connection_id)
1676        .await?;
1677
1678    broadcast(
1679        Some(session.connection_id),
1680        guest_connection_ids.iter().copied(),
1681        |connection_id| {
1682            session
1683                .peer
1684                .forward_send(session.connection_id, connection_id, message.clone())
1685        },
1686    );
1687
1688    Ok(())
1689}
1690
1691async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1692    broadcast_project_message(request.project_id, request, session).await
1693}
1694
1695async fn start_language_server(
1696    request: proto::StartLanguageServer,
1697    session: Session,
1698) -> Result<()> {
1699    let guest_connection_ids = session
1700        .db()
1701        .await
1702        .start_language_server(&request, session.connection_id)
1703        .await?;
1704
1705    broadcast(
1706        Some(session.connection_id),
1707        guest_connection_ids.iter().copied(),
1708        |connection_id| {
1709            session
1710                .peer
1711                .forward_send(session.connection_id, connection_id, request.clone())
1712        },
1713    );
1714    Ok(())
1715}
1716
1717async fn update_language_server(
1718    request: proto::UpdateLanguageServer,
1719    session: Session,
1720) -> Result<()> {
1721    session.executor.record_backtrace();
1722    let project_id = ProjectId::from_proto(request.project_id);
1723    let project_connection_ids = session
1724        .db()
1725        .await
1726        .project_connection_ids(project_id, session.connection_id)
1727        .await?;
1728    broadcast(
1729        Some(session.connection_id),
1730        project_connection_ids.iter().copied(),
1731        |connection_id| {
1732            session
1733                .peer
1734                .forward_send(session.connection_id, connection_id, request.clone())
1735        },
1736    );
1737    Ok(())
1738}
1739
1740async fn forward_project_request<T>(
1741    request: T,
1742    response: Response<T>,
1743    session: Session,
1744) -> Result<()>
1745where
1746    T: EntityMessage + RequestMessage,
1747{
1748    session.executor.record_backtrace();
1749    let project_id = ProjectId::from_proto(request.remote_entity_id());
1750    let host_connection_id = {
1751        let collaborators = session
1752            .db()
1753            .await
1754            .project_collaborators(project_id, session.connection_id)
1755            .await?;
1756        collaborators
1757            .iter()
1758            .find(|collaborator| collaborator.is_host)
1759            .ok_or_else(|| anyhow!("host not found"))?
1760            .connection_id
1761    };
1762
1763    let payload = session
1764        .peer
1765        .forward_request(session.connection_id, host_connection_id, request)
1766        .await?;
1767
1768    response.send(payload)?;
1769    Ok(())
1770}
1771
1772async fn create_buffer_for_peer(
1773    request: proto::CreateBufferForPeer,
1774    session: Session,
1775) -> Result<()> {
1776    session.executor.record_backtrace();
1777    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1778    session
1779        .peer
1780        .forward_send(session.connection_id, peer_id.into(), request)?;
1781    Ok(())
1782}
1783
1784async fn update_buffer(
1785    request: proto::UpdateBuffer,
1786    response: Response<proto::UpdateBuffer>,
1787    session: Session,
1788) -> Result<()> {
1789    session.executor.record_backtrace();
1790    let project_id = ProjectId::from_proto(request.project_id);
1791    let mut guest_connection_ids;
1792    let mut host_connection_id = None;
1793    {
1794        let collaborators = session
1795            .db()
1796            .await
1797            .project_collaborators(project_id, session.connection_id)
1798            .await?;
1799        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1800        for collaborator in collaborators.iter() {
1801            if collaborator.is_host {
1802                host_connection_id = Some(collaborator.connection_id);
1803            } else {
1804                guest_connection_ids.push(collaborator.connection_id);
1805            }
1806        }
1807    }
1808    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1809
1810    session.executor.record_backtrace();
1811    broadcast(
1812        Some(session.connection_id),
1813        guest_connection_ids,
1814        |connection_id| {
1815            session
1816                .peer
1817                .forward_send(session.connection_id, connection_id, request.clone())
1818        },
1819    );
1820    if host_connection_id != session.connection_id {
1821        session
1822            .peer
1823            .forward_request(session.connection_id, host_connection_id, request.clone())
1824            .await?;
1825    }
1826
1827    response.send(proto::Ack {})?;
1828    Ok(())
1829}
1830
1831async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1832    let project_id = ProjectId::from_proto(request.project_id);
1833    let project_connection_ids = session
1834        .db()
1835        .await
1836        .project_connection_ids(project_id, session.connection_id)
1837        .await?;
1838
1839    broadcast(
1840        Some(session.connection_id),
1841        project_connection_ids.iter().copied(),
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    Ok(())
1849}
1850
1851async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1852    let project_id = ProjectId::from_proto(request.project_id);
1853    let project_connection_ids = session
1854        .db()
1855        .await
1856        .project_connection_ids(project_id, session.connection_id)
1857        .await?;
1858    broadcast(
1859        Some(session.connection_id),
1860        project_connection_ids.iter().copied(),
1861        |connection_id| {
1862            session
1863                .peer
1864                .forward_send(session.connection_id, connection_id, request.clone())
1865        },
1866    );
1867    Ok(())
1868}
1869
1870async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1871    broadcast_project_message(request.project_id, request, session).await
1872}
1873
1874async fn broadcast_project_message<T: EnvelopedMessage>(
1875    project_id: u64,
1876    request: T,
1877    session: Session,
1878) -> Result<()> {
1879    let project_id = ProjectId::from_proto(project_id);
1880    let project_connection_ids = session
1881        .db()
1882        .await
1883        .project_connection_ids(project_id, session.connection_id)
1884        .await?;
1885    broadcast(
1886        Some(session.connection_id),
1887        project_connection_ids.iter().copied(),
1888        |connection_id| {
1889            session
1890                .peer
1891                .forward_send(session.connection_id, connection_id, request.clone())
1892        },
1893    );
1894    Ok(())
1895}
1896
1897async fn follow(
1898    request: proto::Follow,
1899    response: Response<proto::Follow>,
1900    session: Session,
1901) -> Result<()> {
1902    let room_id = RoomId::from_proto(request.room_id);
1903    let project_id = request.project_id.map(ProjectId::from_proto);
1904    let leader_id = request
1905        .leader_id
1906        .ok_or_else(|| anyhow!("invalid leader id"))?
1907        .into();
1908    let follower_id = session.connection_id;
1909
1910    session
1911        .db()
1912        .await
1913        .check_room_participants(room_id, leader_id, session.connection_id)
1914        .await?;
1915
1916    let response_payload = session
1917        .peer
1918        .forward_request(session.connection_id, leader_id, request)
1919        .await?;
1920    response.send(response_payload)?;
1921
1922    if let Some(project_id) = project_id {
1923        let room = session
1924            .db()
1925            .await
1926            .follow(room_id, project_id, leader_id, follower_id)
1927            .await?;
1928        room_updated(&room, &session.peer);
1929    }
1930
1931    Ok(())
1932}
1933
1934async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1935    let room_id = RoomId::from_proto(request.room_id);
1936    let project_id = request.project_id.map(ProjectId::from_proto);
1937    let leader_id = request
1938        .leader_id
1939        .ok_or_else(|| anyhow!("invalid leader id"))?
1940        .into();
1941    let follower_id = session.connection_id;
1942
1943    session
1944        .db()
1945        .await
1946        .check_room_participants(room_id, leader_id, session.connection_id)
1947        .await?;
1948
1949    session
1950        .peer
1951        .forward_send(session.connection_id, leader_id, request)?;
1952
1953    if let Some(project_id) = project_id {
1954        let room = session
1955            .db()
1956            .await
1957            .unfollow(room_id, project_id, leader_id, follower_id)
1958            .await?;
1959        room_updated(&room, &session.peer);
1960    }
1961
1962    Ok(())
1963}
1964
1965async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1966    let room_id = RoomId::from_proto(request.room_id);
1967    let database = session.db.lock().await;
1968
1969    let connection_ids = if let Some(project_id) = request.project_id {
1970        let project_id = ProjectId::from_proto(project_id);
1971        database
1972            .project_connection_ids(project_id, session.connection_id)
1973            .await?
1974    } else {
1975        database
1976            .room_connection_ids(room_id, session.connection_id)
1977            .await?
1978    };
1979
1980    // For now, don't send view update messages back to that view's current leader.
1981    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1982        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1983        _ => None,
1984    });
1985
1986    for follower_peer_id in request.follower_ids.iter().copied() {
1987        let follower_connection_id = follower_peer_id.into();
1988        if Some(follower_peer_id) != connection_id_to_omit
1989            && connection_ids.contains(&follower_connection_id)
1990        {
1991            session.peer.forward_send(
1992                session.connection_id,
1993                follower_connection_id,
1994                request.clone(),
1995            )?;
1996        }
1997    }
1998    Ok(())
1999}
2000
2001async fn get_users(
2002    request: proto::GetUsers,
2003    response: Response<proto::GetUsers>,
2004    session: Session,
2005) -> Result<()> {
2006    let user_ids = request
2007        .user_ids
2008        .into_iter()
2009        .map(UserId::from_proto)
2010        .collect();
2011    let users = session
2012        .db()
2013        .await
2014        .get_users_by_ids(user_ids)
2015        .await?
2016        .into_iter()
2017        .map(|user| proto::User {
2018            id: user.id.to_proto(),
2019            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2020            github_login: user.github_login,
2021        })
2022        .collect();
2023    response.send(proto::UsersResponse { users })?;
2024    Ok(())
2025}
2026
2027async fn fuzzy_search_users(
2028    request: proto::FuzzySearchUsers,
2029    response: Response<proto::FuzzySearchUsers>,
2030    session: Session,
2031) -> Result<()> {
2032    let query = request.query;
2033    let users = match query.len() {
2034        0 => vec![],
2035        1 | 2 => session
2036            .db()
2037            .await
2038            .get_user_by_github_login(&query)
2039            .await?
2040            .into_iter()
2041            .collect(),
2042        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2043    };
2044    let users = users
2045        .into_iter()
2046        .filter(|user| user.id != session.user_id)
2047        .map(|user| proto::User {
2048            id: user.id.to_proto(),
2049            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2050            github_login: user.github_login,
2051        })
2052        .collect();
2053    response.send(proto::UsersResponse { users })?;
2054    Ok(())
2055}
2056
2057async fn request_contact(
2058    request: proto::RequestContact,
2059    response: Response<proto::RequestContact>,
2060    session: Session,
2061) -> Result<()> {
2062    let requester_id = session.user_id;
2063    let responder_id = UserId::from_proto(request.responder_id);
2064    if requester_id == responder_id {
2065        return Err(anyhow!("cannot add yourself as a contact"))?;
2066    }
2067
2068    session
2069        .db()
2070        .await
2071        .send_contact_request(requester_id, responder_id)
2072        .await?;
2073
2074    // Update outgoing contact requests of requester
2075    let mut update = proto::UpdateContacts::default();
2076    update.outgoing_requests.push(responder_id.to_proto());
2077    for connection_id in session
2078        .connection_pool()
2079        .await
2080        .user_connection_ids(requester_id)
2081    {
2082        session.peer.send(connection_id, update.clone())?;
2083    }
2084
2085    // Update incoming contact requests of responder
2086    let mut update = proto::UpdateContacts::default();
2087    update
2088        .incoming_requests
2089        .push(proto::IncomingContactRequest {
2090            requester_id: requester_id.to_proto(),
2091            should_notify: true,
2092        });
2093    for connection_id in session
2094        .connection_pool()
2095        .await
2096        .user_connection_ids(responder_id)
2097    {
2098        session.peer.send(connection_id, update.clone())?;
2099    }
2100
2101    response.send(proto::Ack {})?;
2102    Ok(())
2103}
2104
2105async fn respond_to_contact_request(
2106    request: proto::RespondToContactRequest,
2107    response: Response<proto::RespondToContactRequest>,
2108    session: Session,
2109) -> Result<()> {
2110    let responder_id = session.user_id;
2111    let requester_id = UserId::from_proto(request.requester_id);
2112    let db = session.db().await;
2113    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2114        db.dismiss_contact_notification(responder_id, requester_id)
2115            .await?;
2116    } else {
2117        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2118
2119        db.respond_to_contact_request(responder_id, requester_id, accept)
2120            .await?;
2121        let requester_busy = db.is_user_busy(requester_id).await?;
2122        let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124        let pool = session.connection_pool().await;
2125        // Update responder with new contact
2126        let mut update = proto::UpdateContacts::default();
2127        if accept {
2128            update
2129                .contacts
2130                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2131        }
2132        update
2133            .remove_incoming_requests
2134            .push(requester_id.to_proto());
2135        for connection_id in pool.user_connection_ids(responder_id) {
2136            session.peer.send(connection_id, update.clone())?;
2137        }
2138
2139        // Update requester with new contact
2140        let mut update = proto::UpdateContacts::default();
2141        if accept {
2142            update
2143                .contacts
2144                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2145        }
2146        update
2147            .remove_outgoing_requests
2148            .push(responder_id.to_proto());
2149        for connection_id in pool.user_connection_ids(requester_id) {
2150            session.peer.send(connection_id, update.clone())?;
2151        }
2152    }
2153
2154    response.send(proto::Ack {})?;
2155    Ok(())
2156}
2157
2158async fn remove_contact(
2159    request: proto::RemoveContact,
2160    response: Response<proto::RemoveContact>,
2161    session: Session,
2162) -> Result<()> {
2163    let requester_id = session.user_id;
2164    let responder_id = UserId::from_proto(request.user_id);
2165    let db = session.db().await;
2166    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2167
2168    let pool = session.connection_pool().await;
2169    // Update outgoing contact requests of requester
2170    let mut update = proto::UpdateContacts::default();
2171    if contact_accepted {
2172        update.remove_contacts.push(responder_id.to_proto());
2173    } else {
2174        update
2175            .remove_outgoing_requests
2176            .push(responder_id.to_proto());
2177    }
2178    for connection_id in pool.user_connection_ids(requester_id) {
2179        session.peer.send(connection_id, update.clone())?;
2180    }
2181
2182    // Update incoming contact requests of responder
2183    let mut update = proto::UpdateContacts::default();
2184    if contact_accepted {
2185        update.remove_contacts.push(requester_id.to_proto());
2186    } else {
2187        update
2188            .remove_incoming_requests
2189            .push(requester_id.to_proto());
2190    }
2191    for connection_id in pool.user_connection_ids(responder_id) {
2192        session.peer.send(connection_id, update.clone())?;
2193    }
2194
2195    response.send(proto::Ack {})?;
2196    Ok(())
2197}
2198
2199async fn create_channel(
2200    request: proto::CreateChannel,
2201    response: Response<proto::CreateChannel>,
2202    session: Session,
2203) -> Result<()> {
2204    let db = session.db().await;
2205
2206    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2207    let id = db
2208        .create_channel(&request.name, parent_id, session.user_id)
2209        .await?;
2210
2211    let channel = proto::Channel {
2212        id: id.to_proto(),
2213        name: request.name,
2214        visibility: proto::ChannelVisibility::Members as i32,
2215    };
2216
2217    response.send(proto::CreateChannelResponse {
2218        channel: Some(channel.clone()),
2219        parent_id: request.parent_id,
2220    })?;
2221
2222    let Some(parent_id) = parent_id else {
2223        return Ok(());
2224    };
2225
2226    let update = proto::UpdateChannels {
2227        channels: vec![channel],
2228        insert_edge: vec![ChannelEdge {
2229            parent_id: parent_id.to_proto(),
2230            channel_id: id.to_proto(),
2231        }],
2232        ..Default::default()
2233    };
2234
2235    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2236
2237    let connection_pool = session.connection_pool().await;
2238    for user_id in user_ids_to_notify {
2239        for connection_id in connection_pool.user_connection_ids(user_id) {
2240            if user_id == session.user_id {
2241                continue;
2242            }
2243            session.peer.send(connection_id, update.clone())?;
2244        }
2245    }
2246
2247    Ok(())
2248}
2249
2250async fn delete_channel(
2251    request: proto::DeleteChannel,
2252    response: Response<proto::DeleteChannel>,
2253    session: Session,
2254) -> Result<()> {
2255    let db = session.db().await;
2256
2257    let channel_id = request.channel_id;
2258    let (removed_channels, member_ids) = db
2259        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2260        .await?;
2261    response.send(proto::Ack {})?;
2262
2263    // Notify members of removed channels
2264    let mut update = proto::UpdateChannels::default();
2265    update
2266        .delete_channels
2267        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2268
2269    let connection_pool = session.connection_pool().await;
2270    for member_id in member_ids {
2271        for connection_id in connection_pool.user_connection_ids(member_id) {
2272            session.peer.send(connection_id, update.clone())?;
2273        }
2274    }
2275
2276    Ok(())
2277}
2278
2279async fn invite_channel_member(
2280    request: proto::InviteChannelMember,
2281    response: Response<proto::InviteChannelMember>,
2282    session: Session,
2283) -> Result<()> {
2284    let db = session.db().await;
2285    let channel_id = ChannelId::from_proto(request.channel_id);
2286    let invitee_id = UserId::from_proto(request.user_id);
2287    db.invite_channel_member(
2288        channel_id,
2289        invitee_id,
2290        session.user_id,
2291        request.role().into(),
2292    )
2293    .await?;
2294
2295    let (channel, _) = db
2296        .get_channel(channel_id, session.user_id)
2297        .await?
2298        .ok_or_else(|| anyhow!("channel not found"))?;
2299
2300    let mut update = proto::UpdateChannels::default();
2301    update.channel_invitations.push(proto::Channel {
2302        id: channel.id.to_proto(),
2303        visibility: channel.visibility.into(),
2304        name: channel.name,
2305    });
2306    for connection_id in session
2307        .connection_pool()
2308        .await
2309        .user_connection_ids(invitee_id)
2310    {
2311        session.peer.send(connection_id, update.clone())?;
2312    }
2313
2314    response.send(proto::Ack {})?;
2315    Ok(())
2316}
2317
2318async fn remove_channel_member(
2319    request: proto::RemoveChannelMember,
2320    response: Response<proto::RemoveChannelMember>,
2321    session: Session,
2322) -> Result<()> {
2323    let db = session.db().await;
2324    let channel_id = ChannelId::from_proto(request.channel_id);
2325    let member_id = UserId::from_proto(request.user_id);
2326
2327    db.remove_channel_member(channel_id, member_id, session.user_id)
2328        .await?;
2329
2330    let mut update = proto::UpdateChannels::default();
2331    update.delete_channels.push(channel_id.to_proto());
2332
2333    for connection_id in session
2334        .connection_pool()
2335        .await
2336        .user_connection_ids(member_id)
2337    {
2338        session.peer.send(connection_id, update.clone())?;
2339    }
2340
2341    response.send(proto::Ack {})?;
2342    Ok(())
2343}
2344
2345async fn set_channel_visibility(
2346    request: proto::SetChannelVisibility,
2347    response: Response<proto::SetChannelVisibility>,
2348    session: Session,
2349) -> Result<()> {
2350    let db = session.db().await;
2351    let channel_id = ChannelId::from_proto(request.channel_id);
2352    let visibility = request.visibility().into();
2353
2354    let channel = db
2355        .set_channel_visibility(channel_id, visibility, session.user_id)
2356        .await?;
2357
2358    let mut update = proto::UpdateChannels::default();
2359    update.channels.push(proto::Channel {
2360        id: channel.id.to_proto(),
2361        name: channel.name,
2362        visibility: channel.visibility.into(),
2363    });
2364
2365    let member_ids = db.get_channel_members(channel_id).await?;
2366
2367    let connection_pool = session.connection_pool().await;
2368    for member_id in member_ids {
2369        for connection_id in connection_pool.user_connection_ids(member_id) {
2370            session.peer.send(connection_id, update.clone())?;
2371        }
2372    }
2373
2374    response.send(proto::Ack {})?;
2375    Ok(())
2376}
2377
2378async fn set_channel_member_role(
2379    request: proto::SetChannelMemberRole,
2380    response: Response<proto::SetChannelMemberRole>,
2381    session: Session,
2382) -> Result<()> {
2383    let db = session.db().await;
2384    let channel_id = ChannelId::from_proto(request.channel_id);
2385    let member_id = UserId::from_proto(request.user_id);
2386    db.set_channel_member_role(
2387        channel_id,
2388        session.user_id,
2389        member_id,
2390        request.role().into(),
2391    )
2392    .await?;
2393
2394    let (channel, has_accepted) = db
2395        .get_channel(channel_id, member_id)
2396        .await?
2397        .ok_or_else(|| anyhow!("channel not found"))?;
2398
2399    let mut update = proto::UpdateChannels::default();
2400    if has_accepted {
2401        update.channel_permissions.push(proto::ChannelPermission {
2402            channel_id: channel.id.to_proto(),
2403            role: request.role,
2404        });
2405    }
2406
2407    for connection_id in session
2408        .connection_pool()
2409        .await
2410        .user_connection_ids(member_id)
2411    {
2412        session.peer.send(connection_id, update.clone())?;
2413    }
2414
2415    response.send(proto::Ack {})?;
2416    Ok(())
2417}
2418
2419async fn rename_channel(
2420    request: proto::RenameChannel,
2421    response: Response<proto::RenameChannel>,
2422    session: Session,
2423) -> Result<()> {
2424    let db = session.db().await;
2425    let channel_id = ChannelId::from_proto(request.channel_id);
2426    let channel = db
2427        .rename_channel(channel_id, session.user_id, &request.name)
2428        .await?;
2429
2430    let channel = proto::Channel {
2431        id: channel.id.to_proto(),
2432        name: channel.name,
2433        visibility: channel.visibility.into(),
2434    };
2435    response.send(proto::RenameChannelResponse {
2436        channel: Some(channel.clone()),
2437    })?;
2438    let mut update = proto::UpdateChannels::default();
2439    update.channels.push(channel);
2440
2441    let member_ids = db.get_channel_members(channel_id).await?;
2442
2443    let connection_pool = session.connection_pool().await;
2444    for member_id in member_ids {
2445        for connection_id in connection_pool.user_connection_ids(member_id) {
2446            session.peer.send(connection_id, update.clone())?;
2447        }
2448    }
2449
2450    Ok(())
2451}
2452
2453async fn link_channel(
2454    request: proto::LinkChannel,
2455    response: Response<proto::LinkChannel>,
2456    session: Session,
2457) -> Result<()> {
2458    let db = session.db().await;
2459    let channel_id = ChannelId::from_proto(request.channel_id);
2460    let to = ChannelId::from_proto(request.to);
2461    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2462
2463    let members = db.get_channel_members(to).await?;
2464    let connection_pool = session.connection_pool().await;
2465    let update = proto::UpdateChannels {
2466        channels: channels_to_send
2467            .channels
2468            .into_iter()
2469            .map(|channel| proto::Channel {
2470                id: channel.id.to_proto(),
2471                visibility: channel.visibility.into(),
2472                name: channel.name,
2473            })
2474            .collect(),
2475        insert_edge: channels_to_send.edges,
2476        ..Default::default()
2477    };
2478    for member_id in members {
2479        for connection_id in connection_pool.user_connection_ids(member_id) {
2480            session.peer.send(connection_id, update.clone())?;
2481        }
2482    }
2483
2484    response.send(Ack {})?;
2485
2486    Ok(())
2487}
2488
2489async fn unlink_channel(
2490    request: proto::UnlinkChannel,
2491    response: Response<proto::UnlinkChannel>,
2492    session: Session,
2493) -> Result<()> {
2494    let db = session.db().await;
2495    let channel_id = ChannelId::from_proto(request.channel_id);
2496    let from = ChannelId::from_proto(request.from);
2497
2498    db.unlink_channel(session.user_id, channel_id, from).await?;
2499
2500    let members = db.get_channel_members(from).await?;
2501
2502    let update = proto::UpdateChannels {
2503        delete_edge: vec![proto::ChannelEdge {
2504            channel_id: channel_id.to_proto(),
2505            parent_id: from.to_proto(),
2506        }],
2507        ..Default::default()
2508    };
2509    let connection_pool = session.connection_pool().await;
2510    for member_id in members {
2511        for connection_id in connection_pool.user_connection_ids(member_id) {
2512            session.peer.send(connection_id, update.clone())?;
2513        }
2514    }
2515
2516    response.send(Ack {})?;
2517
2518    Ok(())
2519}
2520
2521async fn move_channel(
2522    request: proto::MoveChannel,
2523    response: Response<proto::MoveChannel>,
2524    session: Session,
2525) -> Result<()> {
2526    let db = session.db().await;
2527    let channel_id = ChannelId::from_proto(request.channel_id);
2528    let from_parent = ChannelId::from_proto(request.from);
2529    let to = ChannelId::from_proto(request.to);
2530
2531    let channels_to_send = db
2532        .move_channel(session.user_id, channel_id, from_parent, to)
2533        .await?;
2534
2535    if channels_to_send.is_empty() {
2536        response.send(Ack {})?;
2537        return Ok(());
2538    }
2539
2540    let members_from = db.get_channel_members(from_parent).await?;
2541    let members_to = db.get_channel_members(to).await?;
2542
2543    let update = proto::UpdateChannels {
2544        delete_edge: vec![proto::ChannelEdge {
2545            channel_id: channel_id.to_proto(),
2546            parent_id: from_parent.to_proto(),
2547        }],
2548        ..Default::default()
2549    };
2550    let connection_pool = session.connection_pool().await;
2551    for member_id in members_from {
2552        for connection_id in connection_pool.user_connection_ids(member_id) {
2553            session.peer.send(connection_id, update.clone())?;
2554        }
2555    }
2556
2557    let update = proto::UpdateChannels {
2558        channels: channels_to_send
2559            .channels
2560            .into_iter()
2561            .map(|channel| proto::Channel {
2562                id: channel.id.to_proto(),
2563                visibility: channel.visibility.into(),
2564                name: channel.name,
2565            })
2566            .collect(),
2567        insert_edge: channels_to_send.edges,
2568        ..Default::default()
2569    };
2570    for member_id in members_to {
2571        for connection_id in connection_pool.user_connection_ids(member_id) {
2572            session.peer.send(connection_id, update.clone())?;
2573        }
2574    }
2575
2576    response.send(Ack {})?;
2577
2578    Ok(())
2579}
2580
2581async fn get_channel_members(
2582    request: proto::GetChannelMembers,
2583    response: Response<proto::GetChannelMembers>,
2584    session: Session,
2585) -> Result<()> {
2586    let db = session.db().await;
2587    let channel_id = ChannelId::from_proto(request.channel_id);
2588    let members = db
2589        .get_channel_participant_details(channel_id, session.user_id)
2590        .await?;
2591    response.send(proto::GetChannelMembersResponse { members })?;
2592    Ok(())
2593}
2594
2595async fn respond_to_channel_invite(
2596    request: proto::RespondToChannelInvite,
2597    response: Response<proto::RespondToChannelInvite>,
2598    session: Session,
2599) -> Result<()> {
2600    let db = session.db().await;
2601    let channel_id = ChannelId::from_proto(request.channel_id);
2602    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2603        .await?;
2604
2605    let mut update = proto::UpdateChannels::default();
2606    update
2607        .remove_channel_invitations
2608        .push(channel_id.to_proto());
2609    if request.accept {
2610        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2611        update
2612            .channels
2613            .extend(
2614                result
2615                    .channels
2616                    .channels
2617                    .into_iter()
2618                    .map(|channel| proto::Channel {
2619                        id: channel.id.to_proto(),
2620                        visibility: channel.visibility.into(),
2621                        name: channel.name,
2622                    }),
2623            );
2624        update.unseen_channel_messages = result.channel_messages;
2625        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2626        update.insert_edge = result.channels.edges;
2627        update
2628            .channel_participants
2629            .extend(
2630                result
2631                    .channel_participants
2632                    .into_iter()
2633                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2634                        channel_id: channel_id.to_proto(),
2635                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2636                    }),
2637            );
2638        update
2639            .channel_permissions
2640            .extend(
2641                result
2642                    .channels_with_admin_privileges
2643                    .into_iter()
2644                    .map(|channel_id| proto::ChannelPermission {
2645                        channel_id: channel_id.to_proto(),
2646                        role: proto::ChannelRole::Admin.into(),
2647                    }),
2648            );
2649    }
2650    session.peer.send(session.connection_id, update)?;
2651    response.send(proto::Ack {})?;
2652
2653    Ok(())
2654}
2655
2656async fn join_channel(
2657    request: proto::JoinChannel,
2658    response: Response<proto::JoinChannel>,
2659    session: Session,
2660) -> Result<()> {
2661    let channel_id = ChannelId::from_proto(request.channel_id);
2662    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2663
2664    let joined_room = {
2665        leave_room_for_session(&session).await?;
2666        let db = session.db().await;
2667
2668        let room_id = db
2669            .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
2670            .await?;
2671
2672        let joined_room = db
2673            .join_room(
2674                room_id,
2675                session.user_id,
2676                session.connection_id,
2677                RELEASE_CHANNEL_NAME.as_str(),
2678            )
2679            .await?;
2680
2681        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2682            let token = live_kit
2683                .room_token(
2684                    &joined_room.room.live_kit_room,
2685                    &session.user_id.to_string(),
2686                )
2687                .trace_err()?;
2688
2689            Some(LiveKitConnectionInfo {
2690                server_url: live_kit.url().into(),
2691                token,
2692            })
2693        });
2694
2695        response.send(proto::JoinRoomResponse {
2696            room: Some(joined_room.room.clone()),
2697            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2698            live_kit_connection_info,
2699        })?;
2700
2701        room_updated(&joined_room.room, &session.peer);
2702
2703        joined_room.into_inner()
2704    };
2705
2706    channel_updated(
2707        channel_id,
2708        &joined_room.room,
2709        &joined_room.channel_members,
2710        &session.peer,
2711        &*session.connection_pool().await,
2712    );
2713
2714    update_user_contacts(session.user_id, &session).await?;
2715
2716    Ok(())
2717}
2718
2719async fn join_channel_buffer(
2720    request: proto::JoinChannelBuffer,
2721    response: Response<proto::JoinChannelBuffer>,
2722    session: Session,
2723) -> Result<()> {
2724    let db = session.db().await;
2725    let channel_id = ChannelId::from_proto(request.channel_id);
2726
2727    let open_response = db
2728        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2729        .await?;
2730
2731    let collaborators = open_response.collaborators.clone();
2732    response.send(open_response)?;
2733
2734    let update = UpdateChannelBufferCollaborators {
2735        channel_id: channel_id.to_proto(),
2736        collaborators: collaborators.clone(),
2737    };
2738    channel_buffer_updated(
2739        session.connection_id,
2740        collaborators
2741            .iter()
2742            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2743        &update,
2744        &session.peer,
2745    );
2746
2747    Ok(())
2748}
2749
2750async fn update_channel_buffer(
2751    request: proto::UpdateChannelBuffer,
2752    session: Session,
2753) -> Result<()> {
2754    let db = session.db().await;
2755    let channel_id = ChannelId::from_proto(request.channel_id);
2756
2757    let (collaborators, non_collaborators, epoch, version) = db
2758        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2759        .await?;
2760
2761    channel_buffer_updated(
2762        session.connection_id,
2763        collaborators,
2764        &proto::UpdateChannelBuffer {
2765            channel_id: channel_id.to_proto(),
2766            operations: request.operations,
2767        },
2768        &session.peer,
2769    );
2770
2771    let pool = &*session.connection_pool().await;
2772
2773    broadcast(
2774        None,
2775        non_collaborators
2776            .iter()
2777            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2778        |peer_id| {
2779            session.peer.send(
2780                peer_id.into(),
2781                proto::UpdateChannels {
2782                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2783                        channel_id: channel_id.to_proto(),
2784                        epoch: epoch as u64,
2785                        version: version.clone(),
2786                    }],
2787                    ..Default::default()
2788                },
2789            )
2790        },
2791    );
2792
2793    Ok(())
2794}
2795
2796async fn rejoin_channel_buffers(
2797    request: proto::RejoinChannelBuffers,
2798    response: Response<proto::RejoinChannelBuffers>,
2799    session: Session,
2800) -> Result<()> {
2801    let db = session.db().await;
2802    let buffers = db
2803        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2804        .await?;
2805
2806    for rejoined_buffer in &buffers {
2807        let collaborators_to_notify = rejoined_buffer
2808            .buffer
2809            .collaborators
2810            .iter()
2811            .filter_map(|c| Some(c.peer_id?.into()));
2812        channel_buffer_updated(
2813            session.connection_id,
2814            collaborators_to_notify,
2815            &proto::UpdateChannelBufferCollaborators {
2816                channel_id: rejoined_buffer.buffer.channel_id,
2817                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2818            },
2819            &session.peer,
2820        );
2821    }
2822
2823    response.send(proto::RejoinChannelBuffersResponse {
2824        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2825    })?;
2826
2827    Ok(())
2828}
2829
2830async fn leave_channel_buffer(
2831    request: proto::LeaveChannelBuffer,
2832    response: Response<proto::LeaveChannelBuffer>,
2833    session: Session,
2834) -> Result<()> {
2835    let db = session.db().await;
2836    let channel_id = ChannelId::from_proto(request.channel_id);
2837
2838    let left_buffer = db
2839        .leave_channel_buffer(channel_id, session.connection_id)
2840        .await?;
2841
2842    response.send(Ack {})?;
2843
2844    channel_buffer_updated(
2845        session.connection_id,
2846        left_buffer.connections,
2847        &proto::UpdateChannelBufferCollaborators {
2848            channel_id: channel_id.to_proto(),
2849            collaborators: left_buffer.collaborators,
2850        },
2851        &session.peer,
2852    );
2853
2854    Ok(())
2855}
2856
2857fn channel_buffer_updated<T: EnvelopedMessage>(
2858    sender_id: ConnectionId,
2859    collaborators: impl IntoIterator<Item = ConnectionId>,
2860    message: &T,
2861    peer: &Peer,
2862) {
2863    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2864        peer.send(peer_id.into(), message.clone())
2865    });
2866}
2867
2868async fn send_channel_message(
2869    request: proto::SendChannelMessage,
2870    response: Response<proto::SendChannelMessage>,
2871    session: Session,
2872) -> Result<()> {
2873    // Validate the message body.
2874    let body = request.body.trim().to_string();
2875    if body.len() > MAX_MESSAGE_LEN {
2876        return Err(anyhow!("message is too long"))?;
2877    }
2878    if body.is_empty() {
2879        return Err(anyhow!("message can't be blank"))?;
2880    }
2881
2882    let timestamp = OffsetDateTime::now_utc();
2883    let nonce = request
2884        .nonce
2885        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2886
2887    let channel_id = ChannelId::from_proto(request.channel_id);
2888    let (message_id, connection_ids, non_participants) = session
2889        .db()
2890        .await
2891        .create_channel_message(
2892            channel_id,
2893            session.user_id,
2894            &body,
2895            timestamp,
2896            nonce.clone().into(),
2897        )
2898        .await?;
2899    let message = proto::ChannelMessage {
2900        sender_id: session.user_id.to_proto(),
2901        id: message_id.to_proto(),
2902        body,
2903        timestamp: timestamp.unix_timestamp() as u64,
2904        nonce: Some(nonce),
2905    };
2906    broadcast(Some(session.connection_id), connection_ids, |connection| {
2907        session.peer.send(
2908            connection,
2909            proto::ChannelMessageSent {
2910                channel_id: channel_id.to_proto(),
2911                message: Some(message.clone()),
2912            },
2913        )
2914    });
2915    response.send(proto::SendChannelMessageResponse {
2916        message: Some(message),
2917    })?;
2918
2919    let pool = &*session.connection_pool().await;
2920    broadcast(
2921        None,
2922        non_participants
2923            .iter()
2924            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2925        |peer_id| {
2926            session.peer.send(
2927                peer_id.into(),
2928                proto::UpdateChannels {
2929                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2930                        channel_id: channel_id.to_proto(),
2931                        message_id: message_id.to_proto(),
2932                    }],
2933                    ..Default::default()
2934                },
2935            )
2936        },
2937    );
2938
2939    Ok(())
2940}
2941
2942async fn remove_channel_message(
2943    request: proto::RemoveChannelMessage,
2944    response: Response<proto::RemoveChannelMessage>,
2945    session: Session,
2946) -> Result<()> {
2947    let channel_id = ChannelId::from_proto(request.channel_id);
2948    let message_id = MessageId::from_proto(request.message_id);
2949    let connection_ids = session
2950        .db()
2951        .await
2952        .remove_channel_message(channel_id, message_id, session.user_id)
2953        .await?;
2954    broadcast(Some(session.connection_id), connection_ids, |connection| {
2955        session.peer.send(connection, request.clone())
2956    });
2957    response.send(proto::Ack {})?;
2958    Ok(())
2959}
2960
2961async fn acknowledge_channel_message(
2962    request: proto::AckChannelMessage,
2963    session: Session,
2964) -> Result<()> {
2965    let channel_id = ChannelId::from_proto(request.channel_id);
2966    let message_id = MessageId::from_proto(request.message_id);
2967    session
2968        .db()
2969        .await
2970        .observe_channel_message(channel_id, session.user_id, message_id)
2971        .await?;
2972    Ok(())
2973}
2974
2975async fn acknowledge_buffer_version(
2976    request: proto::AckBufferOperation,
2977    session: Session,
2978) -> Result<()> {
2979    let buffer_id = BufferId::from_proto(request.buffer_id);
2980    session
2981        .db()
2982        .await
2983        .observe_buffer_version(
2984            buffer_id,
2985            session.user_id,
2986            request.epoch as i32,
2987            &request.version,
2988        )
2989        .await?;
2990    Ok(())
2991}
2992
2993async fn join_channel_chat(
2994    request: proto::JoinChannelChat,
2995    response: Response<proto::JoinChannelChat>,
2996    session: Session,
2997) -> Result<()> {
2998    let channel_id = ChannelId::from_proto(request.channel_id);
2999
3000    let db = session.db().await;
3001    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3002        .await?;
3003    let messages = db
3004        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3005        .await?;
3006    response.send(proto::JoinChannelChatResponse {
3007        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3008        messages,
3009    })?;
3010    Ok(())
3011}
3012
3013async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3014    let channel_id = ChannelId::from_proto(request.channel_id);
3015    session
3016        .db()
3017        .await
3018        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3019        .await?;
3020    Ok(())
3021}
3022
3023async fn get_channel_messages(
3024    request: proto::GetChannelMessages,
3025    response: Response<proto::GetChannelMessages>,
3026    session: Session,
3027) -> Result<()> {
3028    let channel_id = ChannelId::from_proto(request.channel_id);
3029    let messages = session
3030        .db()
3031        .await
3032        .get_channel_messages(
3033            channel_id,
3034            session.user_id,
3035            MESSAGE_COUNT_PER_PAGE,
3036            Some(MessageId::from_proto(request.before_message_id)),
3037        )
3038        .await?;
3039    response.send(proto::GetChannelMessagesResponse {
3040        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3041        messages,
3042    })?;
3043    Ok(())
3044}
3045
3046async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3047    let project_id = ProjectId::from_proto(request.project_id);
3048    let project_connection_ids = session
3049        .db()
3050        .await
3051        .project_connection_ids(project_id, session.connection_id)
3052        .await?;
3053    broadcast(
3054        Some(session.connection_id),
3055        project_connection_ids.iter().copied(),
3056        |connection_id| {
3057            session
3058                .peer
3059                .forward_send(session.connection_id, connection_id, request.clone())
3060        },
3061    );
3062    Ok(())
3063}
3064
3065async fn get_private_user_info(
3066    _request: proto::GetPrivateUserInfo,
3067    response: Response<proto::GetPrivateUserInfo>,
3068    session: Session,
3069) -> Result<()> {
3070    let db = session.db().await;
3071
3072    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3073    let user = db
3074        .get_user_by_id(session.user_id)
3075        .await?
3076        .ok_or_else(|| anyhow!("user not found"))?;
3077    let flags = db.get_user_flags(session.user_id).await?;
3078
3079    response.send(proto::GetPrivateUserInfoResponse {
3080        metrics_id,
3081        staff: user.admin,
3082        flags,
3083    })?;
3084    Ok(())
3085}
3086
3087fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3088    match message {
3089        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3090        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3091        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3092        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3093        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3094            code: frame.code.into(),
3095            reason: frame.reason,
3096        })),
3097    }
3098}
3099
3100fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3101    match message {
3102        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3103        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3104        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3105        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3106        AxumMessage::Close(frame) => {
3107            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3108                code: frame.code.into(),
3109                reason: frame.reason,
3110            }))
3111        }
3112    }
3113}
3114
3115fn build_initial_channels_update(
3116    channels: ChannelsForUser,
3117    channel_invites: Vec<db::Channel>,
3118) -> proto::UpdateChannels {
3119    let mut update = proto::UpdateChannels::default();
3120
3121    for channel in channels.channels.channels {
3122        update.channels.push(proto::Channel {
3123            id: channel.id.to_proto(),
3124            name: channel.name,
3125            visibility: channel.visibility.into(),
3126        });
3127    }
3128
3129    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3130    update.unseen_channel_messages = channels.channel_messages;
3131    update.insert_edge = channels.channels.edges;
3132
3133    for (channel_id, participants) in channels.channel_participants {
3134        update
3135            .channel_participants
3136            .push(proto::ChannelParticipants {
3137                channel_id: channel_id.to_proto(),
3138                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3139            });
3140    }
3141
3142    update
3143        .channel_permissions
3144        .extend(
3145            channels
3146                .channels_with_admin_privileges
3147                .into_iter()
3148                .map(|id| proto::ChannelPermission {
3149                    channel_id: id.to_proto(),
3150                    role: proto::ChannelRole::Admin.into(),
3151                }),
3152        );
3153
3154    for channel in channel_invites {
3155        update.channel_invitations.push(proto::Channel {
3156            id: channel.id.to_proto(),
3157            name: channel.name,
3158            // TODO: Visibility
3159            visibility: ChannelVisibility::Public.into(),
3160        });
3161    }
3162
3163    update
3164}
3165
3166fn build_initial_contacts_update(
3167    contacts: Vec<db::Contact>,
3168    pool: &ConnectionPool,
3169) -> proto::UpdateContacts {
3170    let mut update = proto::UpdateContacts::default();
3171
3172    for contact in contacts {
3173        match contact {
3174            db::Contact::Accepted {
3175                user_id,
3176                should_notify,
3177                busy,
3178            } => {
3179                update
3180                    .contacts
3181                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3182            }
3183            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3184            db::Contact::Incoming {
3185                user_id,
3186                should_notify,
3187            } => update
3188                .incoming_requests
3189                .push(proto::IncomingContactRequest {
3190                    requester_id: user_id.to_proto(),
3191                    should_notify,
3192                }),
3193        }
3194    }
3195
3196    update
3197}
3198
3199fn contact_for_user(
3200    user_id: UserId,
3201    should_notify: bool,
3202    busy: bool,
3203    pool: &ConnectionPool,
3204) -> proto::Contact {
3205    proto::Contact {
3206        user_id: user_id.to_proto(),
3207        online: pool.is_user_online(user_id),
3208        busy,
3209        should_notify,
3210    }
3211}
3212
3213fn room_updated(room: &proto::Room, peer: &Peer) {
3214    broadcast(
3215        None,
3216        room.participants
3217            .iter()
3218            .filter_map(|participant| Some(participant.peer_id?.into())),
3219        |peer_id| {
3220            peer.send(
3221                peer_id.into(),
3222                proto::RoomUpdated {
3223                    room: Some(room.clone()),
3224                },
3225            )
3226        },
3227    );
3228}
3229
3230fn channel_updated(
3231    channel_id: ChannelId,
3232    room: &proto::Room,
3233    channel_members: &[UserId],
3234    peer: &Peer,
3235    pool: &ConnectionPool,
3236) {
3237    let participants = room
3238        .participants
3239        .iter()
3240        .map(|p| p.user_id)
3241        .collect::<Vec<_>>();
3242
3243    broadcast(
3244        None,
3245        channel_members
3246            .iter()
3247            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3248        |peer_id| {
3249            peer.send(
3250                peer_id.into(),
3251                proto::UpdateChannels {
3252                    channel_participants: vec![proto::ChannelParticipants {
3253                        channel_id: channel_id.to_proto(),
3254                        participant_user_ids: participants.clone(),
3255                    }],
3256                    ..Default::default()
3257                },
3258            )
3259        },
3260    );
3261}
3262
3263async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3264    let db = session.db().await;
3265
3266    let contacts = db.get_contacts(user_id).await?;
3267    let busy = db.is_user_busy(user_id).await?;
3268
3269    let pool = session.connection_pool().await;
3270    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3271    for contact in contacts {
3272        if let db::Contact::Accepted {
3273            user_id: contact_user_id,
3274            ..
3275        } = contact
3276        {
3277            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3278                session
3279                    .peer
3280                    .send(
3281                        contact_conn_id,
3282                        proto::UpdateContacts {
3283                            contacts: vec![updated_contact.clone()],
3284                            remove_contacts: Default::default(),
3285                            incoming_requests: Default::default(),
3286                            remove_incoming_requests: Default::default(),
3287                            outgoing_requests: Default::default(),
3288                            remove_outgoing_requests: Default::default(),
3289                        },
3290                    )
3291                    .trace_err();
3292            }
3293        }
3294    }
3295    Ok(())
3296}
3297
/// Removes the session's connection from the room it is in, if any, and notifies everyone
/// affected.
///
/// Returns `Ok(())` right away when `leave_room` reports no room. Otherwise, in order:
/// - notifies collaborators of every project left along with the room (`project_left`);
/// - broadcasts the updated room to its remaining participants;
/// - if the room belongs to a channel, sends the channel's participant list to its members;
/// - sends `CallCanceled` to every user whose pending call was canceled;
/// - refreshes contact status for this user and for those callees;
/// - removes the user from the LiveKit room, deleting the room when `left_room.deleted`
///   was set.
///
/// LiveKit failures and individual send failures are logged via `trace_err` and ignored.
/// Database and contact-update errors are propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries must be re-sent once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The `else` branch returns, so all of these are
    // initialized before any later use.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): `session.db().await` is a temporary in this `if let` scrutinee. Under
    // edition-2021 temporary-scope rules it is not dropped until the end of the whole
    // `if let`/`else`. This matters if it is a lock guard.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `left_room.room` itself is taken, so the `room` broadcast below
        // carries a defaulted (empty) `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms also refresh the channel's participant list for its members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so the connection-pool guard is released before `update_user_contacts`,
        // which acquires the pool again.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3374
3375async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3376    let left_channel_buffers = session
3377        .db()
3378        .await
3379        .leave_channel_buffers(session.connection_id)
3380        .await?;
3381
3382    for left_buffer in left_channel_buffers {
3383        channel_buffer_updated(
3384            session.connection_id,
3385            left_buffer.connections,
3386            &proto::UpdateChannelBufferCollaborators {
3387                channel_id: left_buffer.channel_id.to_proto(),
3388                collaborators: left_buffer.collaborators,
3389            },
3390            &session.peer,
3391        );
3392    }
3393
3394    Ok(())
3395}
3396
3397fn project_left(project: &db::LeftProject, session: &Session) {
3398    for connection_id in &project.connection_ids {
3399        if project.host_user_id == session.user_id {
3400            session
3401                .peer
3402                .send(
3403                    *connection_id,
3404                    proto::UnshareProject {
3405                        project_id: project.id.to_proto(),
3406                    },
3407                )
3408                .trace_err();
3409        } else {
3410            session
3411                .peer
3412                .send(
3413                    *connection_id,
3414                    proto::RemoveProjectCollaborator {
3415                        project_id: project.id.to_proto(),
3416                        peer_id: Some(session.connection_id.into()),
3417                    },
3418                )
3419                .trace_err();
3420        }
3421    }
3422}
3423
/// Extension for results whose error should be reported and then discarded rather than
/// propagated.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Converts the result into an `Option`. The error, if any, is dropped after being
    /// reported; the `Result` impl in this file logs it with `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3429
3430impl<T, E> ResultExt for Result<T, E>
3431where
3432    E: std::fmt::Debug,
3433{
3434    type Ok = T;
3435
3436    fn trace_err(self) -> Option<T> {
3437        match self {
3438            Ok(value) => Some(value),
3439            Err(error) => {
3440                tracing::error!("{:?}", error);
3441                None
3442            }
3443        }
3444    }
3445}