rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69
  70pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  71pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  72
  73const MESSAGE_COUNT_PER_PAGE: usize = 100;
  74const MAX_MESSAGE_LEN: usize = 1024;
  75const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
  77lazy_static! {
  78    static ref METRIC_CONNECTIONS: IntGauge =
  79        register_int_gauge!("connections", "number of connections").unwrap();
  80    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  81        "shared_projects",
  82        "number of open projects with one or more guests"
  83    )
  84    .unwrap();
  85}
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone)]
 105struct Session {
 106    zed_environment: Arc<str>,
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    _executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
 188    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 189        let mut server = Self {
 190            id: parking_lot::Mutex::new(id),
 191            peer: Peer::new(id.0 as u32),
 192            app_state,
 193            executor,
 194            connection_pool: Default::default(),
 195            handlers: Default::default(),
 196            teardown: watch::channel(()).0,
 197        };
 198
 199        server
 200            .add_request_handler(ping)
 201            .add_request_handler(create_room)
 202            .add_request_handler(join_room)
 203            .add_request_handler(rejoin_room)
 204            .add_request_handler(leave_room)
 205            .add_request_handler(call)
 206            .add_request_handler(cancel_call)
 207            .add_message_handler(decline_call)
 208            .add_request_handler(update_participant_location)
 209            .add_request_handler(share_project)
 210            .add_message_handler(unshare_project)
 211            .add_request_handler(join_project)
 212            .add_message_handler(leave_project)
 213            .add_request_handler(update_project)
 214            .add_request_handler(update_worktree)
 215            .add_message_handler(start_language_server)
 216            .add_message_handler(update_language_server)
 217            .add_message_handler(update_diagnostic_summary)
 218            .add_message_handler(update_worktree_settings)
 219            .add_message_handler(refresh_inlay_hints)
 220            .add_request_handler(forward_project_request::<proto::GetHover>)
 221            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 222            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 223            .add_request_handler(forward_project_request::<proto::GetReferences>)
 224            .add_request_handler(forward_project_request::<proto::SearchProject>)
 225            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 226            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 227            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 228            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 229            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 230            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 231            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 232            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 233            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 234            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 235            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 236            .add_request_handler(forward_project_request::<proto::PerformRename>)
 237            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 238            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 239            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 240            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 242            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 243            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 244            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 245            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 246            .add_request_handler(forward_project_request::<proto::InlayHints>)
 247            .add_message_handler(create_buffer_for_peer)
 248            .add_request_handler(update_buffer)
 249            .add_message_handler(update_buffer_file)
 250            .add_message_handler(buffer_reloaded)
 251            .add_message_handler(buffer_saved)
 252            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 253            .add_request_handler(get_users)
 254            .add_request_handler(fuzzy_search_users)
 255            .add_request_handler(request_contact)
 256            .add_request_handler(remove_contact)
 257            .add_request_handler(respond_to_contact_request)
 258            .add_request_handler(create_channel)
 259            .add_request_handler(delete_channel)
 260            .add_request_handler(invite_channel_member)
 261            .add_request_handler(remove_channel_member)
 262            .add_request_handler(set_channel_member_role)
 263            .add_request_handler(set_channel_visibility)
 264            .add_request_handler(rename_channel)
 265            .add_request_handler(join_channel_buffer)
 266            .add_request_handler(leave_channel_buffer)
 267            .add_message_handler(update_channel_buffer)
 268            .add_request_handler(rejoin_channel_buffers)
 269            .add_request_handler(get_channel_members)
 270            .add_request_handler(respond_to_channel_invite)
 271            .add_request_handler(join_channel)
 272            .add_request_handler(join_channel_chat)
 273            .add_message_handler(leave_channel_chat)
 274            .add_request_handler(send_channel_message)
 275            .add_request_handler(remove_channel_message)
 276            .add_request_handler(get_channel_messages)
 277            .add_request_handler(get_channel_messages_by_id)
 278            .add_request_handler(get_notifications)
 279            .add_request_handler(mark_notification_as_read)
 280            .add_request_handler(move_channel)
 281            .add_request_handler(follow)
 282            .add_message_handler(unfollow)
 283            .add_message_handler(update_followers)
 284            .add_message_handler(update_diff_base)
 285            .add_request_handler(get_private_user_info)
 286            .add_message_handler(acknowledge_channel_message)
 287            .add_message_handler(acknowledge_buffer_version);
 288
 289        Arc::new(server)
 290    }
 291
    /// Spawns a detached background task that cleans up resources the database reports as
    /// stale for this server.
    ///
    /// The task waits `CLEANUP_TIMEOUT`, then fetches the room ids and channel ids returned
    /// by `stale_server_resource_ids` for this environment and server id. For each channel
    /// buffer, it clears stale collaborators and broadcasts the refreshed collaborator list.
    /// For each room, it clears stale participants and then:
    /// - broadcasts the room (and its channel, if any);
    /// - sends `CallCanceled` to canceled callees;
    /// - pushes updated contact entries to the affected users' accepted contacts;
    /// - deletes the LiveKit room if no participants remain.
    ///
    /// Finally it calls `delete_stale_servers`. Failures inside the task are logged via
    /// `trace_err` and never reach the caller; this function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators of each channel buffer, then send the refreshed
                    // collaborator list to every connection the database returned.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nobody to notify and no
                        // LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            // Broadcast the refreshed room (and its channel, if any).
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and canceled callees need fresh
                            // contact entries.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify canceled callees. The pool guard is scoped to this block so it
                        // is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry (including busy status)
                        // and push it to every connection of their accepted contacts. Database
                        // reads happen before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only if a LiveKit client exists and the
                        // refreshed room ended up with no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs whether or not the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 443
 444    pub fn teardown(&self) {
 445        self.peer.teardown();
 446        self.connection_pool.lock().reset();
 447        let _ = self.teardown.send(());
 448    }
 449
 450    #[cfg(test)]
 451    pub fn reset(&self, id: ServerId) {
 452        self.teardown();
 453        *self.id.lock() = id;
 454        self.peer.reset(id.0 as u32);
 455    }
 456
 457    #[cfg(test)]
 458    pub fn id(&self) -> ServerId {
 459        *self.id.lock()
 460    }
 461
 462    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 463    where
 464        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 465        Fut: 'static + Send + Future<Output = Result<()>>,
 466        M: EnvelopedMessage,
 467    {
 468        let prev_handler = self.handlers.insert(
 469            TypeId::of::<M>(),
 470            Box::new(move |envelope, session| {
 471                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 472                let span = info_span!(
 473                    "handle message",
 474                    payload_type = envelope.payload_type_name()
 475                );
 476                span.in_scope(|| {
 477                    tracing::info!(
 478                        payload_type = envelope.payload_type_name(),
 479                        "message received"
 480                    );
 481                });
 482                let start_time = Instant::now();
 483                let future = (handler)(*envelope, session);
 484                async move {
 485                    let result = future.await;
 486                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 487                    match result {
 488                        Err(error) => {
 489                            tracing::error!(%error, ?duration_ms, "error handling message")
 490                        }
 491                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 492                    }
 493                }
 494                .instrument(span)
 495                .boxed()
 496            }),
 497        );
 498        if prev_handler.is_some() {
 499            panic!("registered a handler for the same message twice");
 500        }
 501        self
 502    }
 503
 504    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 505    where
 506        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 507        Fut: 'static + Send + Future<Output = Result<()>>,
 508        M: EnvelopedMessage,
 509    {
 510        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 511        self
 512    }
 513
 514    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 515    where
 516        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 517        Fut: Send + Future<Output = Result<()>>,
 518        M: RequestMessage,
 519    {
 520        let handler = Arc::new(handler);
 521        self.add_handler(move |envelope, session| {
 522            let receipt = envelope.receipt();
 523            let handler = handler.clone();
 524            async move {
 525                let peer = session.peer.clone();
 526                let responded = Arc::new(AtomicBool::default());
 527                let response = Response {
 528                    peer: peer.clone(),
 529                    responded: responded.clone(),
 530                    receipt,
 531                };
 532                match (handler)(envelope.payload, response, session).await {
 533                    Ok(()) => {
 534                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 535                            Ok(())
 536                        } else {
 537                            Err(anyhow!("handler did not send a response"))?
 538                        }
 539                    }
 540                    Err(error) => {
 541                        peer.respond_with_error(
 542                            receipt,
 543                            proto::Error {
 544                                message: error.to_string(),
 545                            },
 546                        )?;
 547                        Err(error)
 548                    }
 549                }
 550            }
 551        })
 552    }
 553
    /// Runs the lifecycle of one authenticated client connection.
    ///
    /// The returned future, instrumented with a "handle connection" span, does the following:
    /// 1. Registers `connection` with the peer and sends `Hello`. If `send_connection_id`
    ///    was supplied, it reports the assigned `ConnectionId` through it.
    /// 2. If the user has never connected before (`!user.connected_once`), sends
    ///    `ShowContacts` and records `connected_once` in the database.
    /// 3. Loads contacts, channels, and channel invites, and adds the connection to the pool.
    ///    It then sends the initial contacts and channels updates, plus any pending incoming
    ///    call.
    /// 4. Builds the per-connection `Session` and calls `update_user_contacts`. It then
    ///    dispatches incoming messages until the I/O future completes, the incoming stream
    ///    ends, or the server is torn down.
    /// 5. Calls `connection_lost` before resolving, except on teardown, which returns at
    ///    once.
    ///
    /// NOTE(review): a `?` failure after `pool.add_connection` returns early without
    /// reaching `connection_lost`. Confirm the connection is removed from the pool elsewhere
    /// in that case.
    ///
    /// # Errors
    /// Returns send or database errors from setup (steps 1–4). I/O errors and sign-out
    /// errors are logged, not returned.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed up front so a teardown sent while this future runs is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // The peer is given a sleep function backed by this connection's executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; that is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and build the initial contacts update (which reads the
            // pool) under a single lock.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // Each connection gets its own database mutex; the peer and pool are shared.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground + background) may be in flight for this
            // connection; each message holds a permit until its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before the next message is read, so a saturated
                // connection stops pulling messages until a handler completes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are checked in order: teardown, I/O completion, finished
                // foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before it is moved into `instrument` below.
                                drop(span_enter);

                                // Release the permit only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers still pending for this connection;
            // detached background handlers keep running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 689
 690    pub async fn invite_code_redeemed(
 691        self: &Arc<Self>,
 692        inviter_id: UserId,
 693        invitee_id: UserId,
 694    ) -> Result<()> {
 695        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 696            if let Some(code) = &user.invite_code {
 697                let pool = self.connection_pool.lock();
 698                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 699                for connection_id in pool.user_connection_ids(inviter_id) {
 700                    self.peer.send(
 701                        connection_id,
 702                        proto::UpdateContacts {
 703                            contacts: vec![invitee_contact.clone()],
 704                            ..Default::default()
 705                        },
 706                    )?;
 707                    self.peer.send(
 708                        connection_id,
 709                        proto::UpdateInviteInfo {
 710                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 711                            count: user.invite_count as u32,
 712                        },
 713                    )?;
 714                }
 715            }
 716        }
 717        Ok(())
 718    }
 719
 720    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 721        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 722            if let Some(invite_code) = &user.invite_code {
 723                let pool = self.connection_pool.lock();
 724                for connection_id in pool.user_connection_ids(user_id) {
 725                    self.peer.send(
 726                        connection_id,
 727                        proto::UpdateInviteInfo {
 728                            url: format!(
 729                                "{}{}",
 730                                self.app_state.config.invite_link_prefix, invite_code
 731                            ),
 732                            count: user.invite_count as u32,
 733                        },
 734                    )?;
 735                }
 736            }
 737        }
 738        Ok(())
 739    }
 740
 741    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 742        ServerSnapshot {
 743            connection_pool: ConnectionPoolGuard {
 744                guard: self.connection_pool.lock(),
 745                _not_send: PhantomData,
 746            },
 747            peer: &self.peer,
 748        }
 749    }
 750}
 751
 752impl<'a> Deref for ConnectionPoolGuard<'a> {
 753    type Target = ConnectionPool;
 754
 755    fn deref(&self) -> &Self::Target {
 756        &*self.guard
 757    }
 758}
 759
 760impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 761    fn deref_mut(&mut self) -> &mut Self::Target {
 762        &mut *self.guard
 763    }
 764}
 765
 766impl<'a> Drop for ConnectionPoolGuard<'a> {
 767    fn drop(&mut self) {
 768        #[cfg(test)]
 769        self.check_invariants();
 770    }
 771}
 772
 773fn broadcast<F>(
 774    sender_id: Option<ConnectionId>,
 775    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 776    mut f: F,
 777) where
 778    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 779{
 780    for receiver_id in receiver_ids {
 781        if Some(receiver_id) != sender_id {
 782            if let Err(error) = f(receiver_id) {
 783                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 784            }
 785        }
 786    }
 787}
 788
 789lazy_static! {
 790    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 791}
 792
 793pub struct ProtocolVersion(u32);
 794
 795impl Header for ProtocolVersion {
 796    fn name() -> &'static HeaderName {
 797        &ZED_PROTOCOL_VERSION
 798    }
 799
 800    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 801    where
 802        Self: Sized,
 803        I: Iterator<Item = &'i axum::http::HeaderValue>,
 804    {
 805        let version = values
 806            .next()
 807            .ok_or_else(axum::headers::Error::invalid)?
 808            .to_str()
 809            .map_err(|_| axum::headers::Error::invalid())?
 810            .parse()
 811            .map_err(|_| axum::headers::Error::invalid())?;
 812        Ok(Self(version))
 813    }
 814
 815    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 816        values.extend([self.0.to_string().parse().unwrap()]);
 817    }
 818}
 819
 820pub fn routes(server: Arc<Server>) -> Router<Body> {
 821    Router::new()
 822        .route("/rpc", get(handle_websocket_request))
 823        .layer(
 824            ServiceBuilder::new()
 825                .layer(Extension(server.app_state.clone()))
 826                .layer(middleware::from_fn(auth::validate_header)),
 827        )
 828        .route("/metrics", get(handle_metrics))
 829        .layer(Extension(server))
 830}
 831
/// Axum handler for `GET /rpc`. It upgrades the request to a websocket and passes the
/// resulting `Connection` to `Server::handle_connection`.
///
/// Clients whose `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`
/// are rejected with `426 Upgrade Required`, and no upgrade is performed. The `User`
/// extension is presumably inserted by the `auth::validate_header` middleware that
/// `routes` installs on this route. Confirm this if the route wiring changes.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket to the tungstenite message types that `Connection`
        // works with:
        // * incoming messages are converted with `to_tungstenite_message`;
        // * errors are converted via `err_into`;
        // * outgoing messages are converted back with `to_axum_message` before being
        //   written.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Any error from the connection's lifetime is logged rather than returned.
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 862
 863pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 864    let connections = server
 865        .connection_pool
 866        .lock()
 867        .connections()
 868        .filter(|connection| !connection.admin)
 869        .count();
 870
 871    METRIC_CONNECTIONS.set(connections as _);
 872
 873    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 874    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 875
 876    let encoder = prometheus::TextEncoder::new();
 877    let metric_families = prometheus::gather();
 878    let encoded_metrics = encoder
 879        .encode_to_string(&metric_families)
 880        .map_err(|err| anyhow!("{}", err))?;
 881    Ok(encoded_metrics)
 882}
 883
/// Cleans up after a connection's message loop has ended.
///
/// Immediate steps:
/// * disconnect the peer;
/// * remove the connection from the pool (an error here is returned);
/// * report the lost connection to the database (errors are only traced).
///
/// It then waits for whichever of these happens first:
///
/// * `RECONNECT_TIMEOUT` elapses on `executor`. The session then leaves its room and
///   its channel buffers. If the user has no other online connection, it calls
///   `decline_call(None, ..)` (presumably declining any pending call) and broadcasts the
///   resulting room, if any. Finally it refreshes the user's contacts.
/// * `teardown` reports a change. No further cleanup is done here.
///
/// `select_biased!` polls its branches in the order written, so if both are ready the
/// timeout branch wins.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 929
 930async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 931    response.send(proto::Ack {})?;
 932    Ok(())
 933}
 934
 935async fn create_room(
 936    _request: proto::CreateRoom,
 937    response: Response<proto::CreateRoom>,
 938    session: Session,
 939) -> Result<()> {
 940    let live_kit_room = nanoid::nanoid!(30);
 941
 942    let live_kit_connection_info = {
 943        let live_kit_room = live_kit_room.clone();
 944        let live_kit = session.live_kit_client.as_ref();
 945
 946        util::async_maybe!({
 947            let live_kit = live_kit?;
 948
 949            let token = live_kit
 950                .room_token(&live_kit_room, &session.user_id.to_string())
 951                .trace_err()?;
 952
 953            Some(proto::LiveKitConnectionInfo {
 954                server_url: live_kit.url().into(),
 955                token,
 956                can_publish: true,
 957            })
 958        })
 959    }
 960    .await;
 961
 962    let room = session
 963        .db()
 964        .await
 965        .create_room(
 966            session.user_id,
 967            session.connection_id,
 968            &live_kit_room,
 969            &session.zed_environment,
 970        )
 971        .await?;
 972
 973    response.send(proto::CreateRoomResponse {
 974        room: Some(room.clone()),
 975        live_kit_connection_info,
 976    })?;
 977
 978    update_user_contacts(session.user_id, &session).await?;
 979    Ok(())
 980}
 981
 982async fn join_room(
 983    request: proto::JoinRoom,
 984    response: Response<proto::JoinRoom>,
 985    session: Session,
 986) -> Result<()> {
 987    let room_id = RoomId::from_proto(request.id);
 988
 989    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 990
 991    if let Some(channel_id) = channel_id {
 992        return join_channel_internal(channel_id, Box::new(response), session).await;
 993    }
 994
 995    let joined_room = {
 996        let room = session
 997            .db()
 998            .await
 999            .join_room(
1000                room_id,
1001                session.user_id,
1002                session.connection_id,
1003                session.zed_environment.as_ref(),
1004            )
1005            .await?;
1006        room_updated(&room.room, &session.peer);
1007        room.into_inner()
1008    };
1009
1010    for connection_id in session
1011        .connection_pool()
1012        .await
1013        .user_connection_ids(session.user_id)
1014    {
1015        session
1016            .peer
1017            .send(
1018                connection_id,
1019                proto::CallCanceled {
1020                    room_id: room_id.to_proto(),
1021                },
1022            )
1023            .trace_err();
1024    }
1025
1026    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1027        if let Some(token) = live_kit
1028            .room_token(
1029                &joined_room.room.live_kit_room,
1030                &session.user_id.to_string(),
1031            )
1032            .trace_err()
1033        {
1034            Some(proto::LiveKitConnectionInfo {
1035                server_url: live_kit.url().into(),
1036                token,
1037                can_publish: true,
1038            })
1039        } else {
1040            None
1041        }
1042    } else {
1043        None
1044    };
1045
1046    response.send(proto::JoinRoomResponse {
1047        room: Some(joined_room.room),
1048        channel_id: None,
1049        live_kit_connection_info,
1050    })?;
1051
1052    update_user_contacts(session.user_id, &session).await?;
1053    Ok(())
1054}
1055
1056async fn rejoin_room(
1057    request: proto::RejoinRoom,
1058    response: Response<proto::RejoinRoom>,
1059    session: Session,
1060) -> Result<()> {
1061    let room;
1062    let channel_id;
1063    let channel_members;
1064    {
1065        let mut rejoined_room = session
1066            .db()
1067            .await
1068            .rejoin_room(request, session.user_id, session.connection_id)
1069            .await?;
1070
1071        response.send(proto::RejoinRoomResponse {
1072            room: Some(rejoined_room.room.clone()),
1073            reshared_projects: rejoined_room
1074                .reshared_projects
1075                .iter()
1076                .map(|project| proto::ResharedProject {
1077                    id: project.id.to_proto(),
1078                    collaborators: project
1079                        .collaborators
1080                        .iter()
1081                        .map(|collaborator| collaborator.to_proto())
1082                        .collect(),
1083                })
1084                .collect(),
1085            rejoined_projects: rejoined_room
1086                .rejoined_projects
1087                .iter()
1088                .map(|rejoined_project| proto::RejoinedProject {
1089                    id: rejoined_project.id.to_proto(),
1090                    worktrees: rejoined_project
1091                        .worktrees
1092                        .iter()
1093                        .map(|worktree| proto::WorktreeMetadata {
1094                            id: worktree.id,
1095                            root_name: worktree.root_name.clone(),
1096                            visible: worktree.visible,
1097                            abs_path: worktree.abs_path.clone(),
1098                        })
1099                        .collect(),
1100                    collaborators: rejoined_project
1101                        .collaborators
1102                        .iter()
1103                        .map(|collaborator| collaborator.to_proto())
1104                        .collect(),
1105                    language_servers: rejoined_project.language_servers.clone(),
1106                })
1107                .collect(),
1108        })?;
1109        room_updated(&rejoined_room.room, &session.peer);
1110
1111        for project in &rejoined_room.reshared_projects {
1112            for collaborator in &project.collaborators {
1113                session
1114                    .peer
1115                    .send(
1116                        collaborator.connection_id,
1117                        proto::UpdateProjectCollaborator {
1118                            project_id: project.id.to_proto(),
1119                            old_peer_id: Some(project.old_connection_id.into()),
1120                            new_peer_id: Some(session.connection_id.into()),
1121                        },
1122                    )
1123                    .trace_err();
1124            }
1125
1126            broadcast(
1127                Some(session.connection_id),
1128                project
1129                    .collaborators
1130                    .iter()
1131                    .map(|collaborator| collaborator.connection_id),
1132                |connection_id| {
1133                    session.peer.forward_send(
1134                        session.connection_id,
1135                        connection_id,
1136                        proto::UpdateProject {
1137                            project_id: project.id.to_proto(),
1138                            worktrees: project.worktrees.clone(),
1139                        },
1140                    )
1141                },
1142            );
1143        }
1144
1145        for project in &rejoined_room.rejoined_projects {
1146            for collaborator in &project.collaborators {
1147                session
1148                    .peer
1149                    .send(
1150                        collaborator.connection_id,
1151                        proto::UpdateProjectCollaborator {
1152                            project_id: project.id.to_proto(),
1153                            old_peer_id: Some(project.old_connection_id.into()),
1154                            new_peer_id: Some(session.connection_id.into()),
1155                        },
1156                    )
1157                    .trace_err();
1158            }
1159        }
1160
1161        for project in &mut rejoined_room.rejoined_projects {
1162            for worktree in mem::take(&mut project.worktrees) {
1163                #[cfg(any(test, feature = "test-support"))]
1164                const MAX_CHUNK_SIZE: usize = 2;
1165                #[cfg(not(any(test, feature = "test-support")))]
1166                const MAX_CHUNK_SIZE: usize = 256;
1167
1168                // Stream this worktree's entries.
1169                let message = proto::UpdateWorktree {
1170                    project_id: project.id.to_proto(),
1171                    worktree_id: worktree.id,
1172                    abs_path: worktree.abs_path.clone(),
1173                    root_name: worktree.root_name,
1174                    updated_entries: worktree.updated_entries,
1175                    removed_entries: worktree.removed_entries,
1176                    scan_id: worktree.scan_id,
1177                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1178                    updated_repositories: worktree.updated_repositories,
1179                    removed_repositories: worktree.removed_repositories,
1180                };
1181                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1182                    session.peer.send(session.connection_id, update.clone())?;
1183                }
1184
1185                // Stream this worktree's diagnostics.
1186                for summary in worktree.diagnostic_summaries {
1187                    session.peer.send(
1188                        session.connection_id,
1189                        proto::UpdateDiagnosticSummary {
1190                            project_id: project.id.to_proto(),
1191                            worktree_id: worktree.id,
1192                            summary: Some(summary),
1193                        },
1194                    )?;
1195                }
1196
1197                for settings_file in worktree.settings_files {
1198                    session.peer.send(
1199                        session.connection_id,
1200                        proto::UpdateWorktreeSettings {
1201                            project_id: project.id.to_proto(),
1202                            worktree_id: worktree.id,
1203                            path: settings_file.path,
1204                            content: Some(settings_file.content),
1205                        },
1206                    )?;
1207                }
1208            }
1209
1210            for language_server in &project.language_servers {
1211                session.peer.send(
1212                    session.connection_id,
1213                    proto::UpdateLanguageServer {
1214                        project_id: project.id.to_proto(),
1215                        language_server_id: language_server.id,
1216                        variant: Some(
1217                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1218                                proto::LspDiskBasedDiagnosticsUpdated {},
1219                            ),
1220                        ),
1221                    },
1222                )?;
1223            }
1224        }
1225
1226        let rejoined_room = rejoined_room.into_inner();
1227
1228        room = rejoined_room.room;
1229        channel_id = rejoined_room.channel_id;
1230        channel_members = rejoined_room.channel_members;
1231    }
1232
1233    if let Some(channel_id) = channel_id {
1234        channel_updated(
1235            channel_id,
1236            &room,
1237            &channel_members,
1238            &session.peer,
1239            &*session.connection_pool().await,
1240        );
1241    }
1242
1243    update_user_contacts(session.user_id, &session).await?;
1244    Ok(())
1245}
1246
1247async fn leave_room(
1248    _: proto::LeaveRoom,
1249    response: Response<proto::LeaveRoom>,
1250    session: Session,
1251) -> Result<()> {
1252    leave_room_for_session(&session).await?;
1253    response.send(proto::Ack {})?;
1254    Ok(())
1255}
1256
1257async fn call(
1258    request: proto::Call,
1259    response: Response<proto::Call>,
1260    session: Session,
1261) -> Result<()> {
1262    let room_id = RoomId::from_proto(request.room_id);
1263    let calling_user_id = session.user_id;
1264    let calling_connection_id = session.connection_id;
1265    let called_user_id = UserId::from_proto(request.called_user_id);
1266    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1267    if !session
1268        .db()
1269        .await
1270        .has_contact(calling_user_id, called_user_id)
1271        .await?
1272    {
1273        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1274    }
1275
1276    let incoming_call = {
1277        let (room, incoming_call) = &mut *session
1278            .db()
1279            .await
1280            .call(
1281                room_id,
1282                calling_user_id,
1283                calling_connection_id,
1284                called_user_id,
1285                initial_project_id,
1286            )
1287            .await?;
1288        room_updated(&room, &session.peer);
1289        mem::take(incoming_call)
1290    };
1291    update_user_contacts(called_user_id, &session).await?;
1292
1293    let mut calls = session
1294        .connection_pool()
1295        .await
1296        .user_connection_ids(called_user_id)
1297        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1298        .collect::<FuturesUnordered<_>>();
1299
1300    while let Some(call_response) = calls.next().await {
1301        match call_response.as_ref() {
1302            Ok(_) => {
1303                response.send(proto::Ack {})?;
1304                return Ok(());
1305            }
1306            Err(_) => {
1307                call_response.trace_err();
1308            }
1309        }
1310    }
1311
1312    {
1313        let room = session
1314            .db()
1315            .await
1316            .call_failed(room_id, called_user_id)
1317            .await?;
1318        room_updated(&room, &session.peer);
1319    }
1320    update_user_contacts(called_user_id, &session).await?;
1321
1322    Err(anyhow!("failed to ring user"))?
1323}
1324
1325async fn cancel_call(
1326    request: proto::CancelCall,
1327    response: Response<proto::CancelCall>,
1328    session: Session,
1329) -> Result<()> {
1330    let called_user_id = UserId::from_proto(request.called_user_id);
1331    let room_id = RoomId::from_proto(request.room_id);
1332    {
1333        let room = session
1334            .db()
1335            .await
1336            .cancel_call(room_id, session.connection_id, called_user_id)
1337            .await?;
1338        room_updated(&room, &session.peer);
1339    }
1340
1341    for connection_id in session
1342        .connection_pool()
1343        .await
1344        .user_connection_ids(called_user_id)
1345    {
1346        session
1347            .peer
1348            .send(
1349                connection_id,
1350                proto::CallCanceled {
1351                    room_id: room_id.to_proto(),
1352                },
1353            )
1354            .trace_err();
1355    }
1356    response.send(proto::Ack {})?;
1357
1358    update_user_contacts(called_user_id, &session).await?;
1359    Ok(())
1360}
1361
1362async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1363    let room_id = RoomId::from_proto(message.room_id);
1364    {
1365        let room = session
1366            .db()
1367            .await
1368            .decline_call(Some(room_id), session.user_id)
1369            .await?
1370            .ok_or_else(|| anyhow!("failed to decline call"))?;
1371        room_updated(&room, &session.peer);
1372    }
1373
1374    for connection_id in session
1375        .connection_pool()
1376        .await
1377        .user_connection_ids(session.user_id)
1378    {
1379        session
1380            .peer
1381            .send(
1382                connection_id,
1383                proto::CallCanceled {
1384                    room_id: room_id.to_proto(),
1385                },
1386            )
1387            .trace_err();
1388    }
1389    update_user_contacts(session.user_id, &session).await?;
1390    Ok(())
1391}
1392
1393async fn update_participant_location(
1394    request: proto::UpdateParticipantLocation,
1395    response: Response<proto::UpdateParticipantLocation>,
1396    session: Session,
1397) -> Result<()> {
1398    let room_id = RoomId::from_proto(request.room_id);
1399    let location = request
1400        .location
1401        .ok_or_else(|| anyhow!("invalid location"))?;
1402
1403    let db = session.db().await;
1404    let room = db
1405        .update_room_participant_location(room_id, session.connection_id, location)
1406        .await?;
1407
1408    room_updated(&room, &session.peer);
1409    response.send(proto::Ack {})?;
1410    Ok(())
1411}
1412
1413async fn share_project(
1414    request: proto::ShareProject,
1415    response: Response<proto::ShareProject>,
1416    session: Session,
1417) -> Result<()> {
1418    let (project_id, room) = &*session
1419        .db()
1420        .await
1421        .share_project(
1422            RoomId::from_proto(request.room_id),
1423            session.connection_id,
1424            &request.worktrees,
1425        )
1426        .await?;
1427    response.send(proto::ShareProjectResponse {
1428        project_id: project_id.to_proto(),
1429    })?;
1430    room_updated(&room, &session.peer);
1431
1432    Ok(())
1433}
1434
1435async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1436    let project_id = ProjectId::from_proto(message.project_id);
1437
1438    let (room, guest_connection_ids) = &*session
1439        .db()
1440        .await
1441        .unshare_project(project_id, session.connection_id)
1442        .await?;
1443
1444    broadcast(
1445        Some(session.connection_id),
1446        guest_connection_ids.iter().copied(),
1447        |conn_id| session.peer.send(conn_id, message.clone()),
1448    );
1449    room_updated(&room, &session.peer);
1450
1451    Ok(())
1452}
1453
1454async fn join_project(
1455    request: proto::JoinProject,
1456    response: Response<proto::JoinProject>,
1457    session: Session,
1458) -> Result<()> {
1459    let project_id = ProjectId::from_proto(request.project_id);
1460    let guest_user_id = session.user_id;
1461
1462    tracing::info!(%project_id, "join project");
1463
1464    let (project, replica_id) = &mut *session
1465        .db()
1466        .await
1467        .join_project(project_id, session.connection_id)
1468        .await?;
1469
1470    let collaborators = project
1471        .collaborators
1472        .iter()
1473        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1474        .map(|collaborator| collaborator.to_proto())
1475        .collect::<Vec<_>>();
1476
1477    let worktrees = project
1478        .worktrees
1479        .iter()
1480        .map(|(id, worktree)| proto::WorktreeMetadata {
1481            id: *id,
1482            root_name: worktree.root_name.clone(),
1483            visible: worktree.visible,
1484            abs_path: worktree.abs_path.clone(),
1485        })
1486        .collect::<Vec<_>>();
1487
1488    for collaborator in &collaborators {
1489        session
1490            .peer
1491            .send(
1492                collaborator.peer_id.unwrap().into(),
1493                proto::AddProjectCollaborator {
1494                    project_id: project_id.to_proto(),
1495                    collaborator: Some(proto::Collaborator {
1496                        peer_id: Some(session.connection_id.into()),
1497                        replica_id: replica_id.0 as u32,
1498                        user_id: guest_user_id.to_proto(),
1499                    }),
1500                },
1501            )
1502            .trace_err();
1503    }
1504
1505    // First, we send the metadata associated with each worktree.
1506    response.send(proto::JoinProjectResponse {
1507        worktrees: worktrees.clone(),
1508        replica_id: replica_id.0 as u32,
1509        collaborators: collaborators.clone(),
1510        language_servers: project.language_servers.clone(),
1511    })?;
1512
1513    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1514        #[cfg(any(test, feature = "test-support"))]
1515        const MAX_CHUNK_SIZE: usize = 2;
1516        #[cfg(not(any(test, feature = "test-support")))]
1517        const MAX_CHUNK_SIZE: usize = 256;
1518
1519        // Stream this worktree's entries.
1520        let message = proto::UpdateWorktree {
1521            project_id: project_id.to_proto(),
1522            worktree_id,
1523            abs_path: worktree.abs_path.clone(),
1524            root_name: worktree.root_name,
1525            updated_entries: worktree.entries,
1526            removed_entries: Default::default(),
1527            scan_id: worktree.scan_id,
1528            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1529            updated_repositories: worktree.repository_entries.into_values().collect(),
1530            removed_repositories: Default::default(),
1531        };
1532        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1533            session.peer.send(session.connection_id, update.clone())?;
1534        }
1535
1536        // Stream this worktree's diagnostics.
1537        for summary in worktree.diagnostic_summaries {
1538            session.peer.send(
1539                session.connection_id,
1540                proto::UpdateDiagnosticSummary {
1541                    project_id: project_id.to_proto(),
1542                    worktree_id: worktree.id,
1543                    summary: Some(summary),
1544                },
1545            )?;
1546        }
1547
1548        for settings_file in worktree.settings_files {
1549            session.peer.send(
1550                session.connection_id,
1551                proto::UpdateWorktreeSettings {
1552                    project_id: project_id.to_proto(),
1553                    worktree_id: worktree.id,
1554                    path: settings_file.path,
1555                    content: Some(settings_file.content),
1556                },
1557            )?;
1558        }
1559    }
1560
1561    for language_server in &project.language_servers {
1562        session.peer.send(
1563            session.connection_id,
1564            proto::UpdateLanguageServer {
1565                project_id: project_id.to_proto(),
1566                language_server_id: language_server.id,
1567                variant: Some(
1568                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1569                        proto::LspDiskBasedDiagnosticsUpdated {},
1570                    ),
1571                ),
1572            },
1573        )?;
1574    }
1575
1576    Ok(())
1577}
1578
1579async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1580    let sender_id = session.connection_id;
1581    let project_id = ProjectId::from_proto(request.project_id);
1582
1583    let (room, project) = &*session
1584        .db()
1585        .await
1586        .leave_project(project_id, sender_id)
1587        .await?;
1588    tracing::info!(
1589        %project_id,
1590        host_user_id = %project.host_user_id,
1591        host_connection_id = %project.host_connection_id,
1592        "leave project"
1593    );
1594
1595    project_left(&project, &session);
1596    room_updated(&room, &session.peer);
1597
1598    Ok(())
1599}
1600
1601async fn update_project(
1602    request: proto::UpdateProject,
1603    response: Response<proto::UpdateProject>,
1604    session: Session,
1605) -> Result<()> {
1606    let project_id = ProjectId::from_proto(request.project_id);
1607    let (room, guest_connection_ids) = &*session
1608        .db()
1609        .await
1610        .update_project(project_id, session.connection_id, &request.worktrees)
1611        .await?;
1612    broadcast(
1613        Some(session.connection_id),
1614        guest_connection_ids.iter().copied(),
1615        |connection_id| {
1616            session
1617                .peer
1618                .forward_send(session.connection_id, connection_id, request.clone())
1619        },
1620    );
1621    room_updated(&room, &session.peer);
1622    response.send(proto::Ack {})?;
1623
1624    Ok(())
1625}
1626
1627async fn update_worktree(
1628    request: proto::UpdateWorktree,
1629    response: Response<proto::UpdateWorktree>,
1630    session: Session,
1631) -> Result<()> {
1632    let guest_connection_ids = session
1633        .db()
1634        .await
1635        .update_worktree(&request, session.connection_id)
1636        .await?;
1637
1638    broadcast(
1639        Some(session.connection_id),
1640        guest_connection_ids.iter().copied(),
1641        |connection_id| {
1642            session
1643                .peer
1644                .forward_send(session.connection_id, connection_id, request.clone())
1645        },
1646    );
1647    response.send(proto::Ack {})?;
1648    Ok(())
1649}
1650
1651async fn update_diagnostic_summary(
1652    message: proto::UpdateDiagnosticSummary,
1653    session: Session,
1654) -> Result<()> {
1655    let guest_connection_ids = session
1656        .db()
1657        .await
1658        .update_diagnostic_summary(&message, session.connection_id)
1659        .await?;
1660
1661    broadcast(
1662        Some(session.connection_id),
1663        guest_connection_ids.iter().copied(),
1664        |connection_id| {
1665            session
1666                .peer
1667                .forward_send(session.connection_id, connection_id, message.clone())
1668        },
1669    );
1670
1671    Ok(())
1672}
1673
1674async fn update_worktree_settings(
1675    message: proto::UpdateWorktreeSettings,
1676    session: Session,
1677) -> Result<()> {
1678    let guest_connection_ids = session
1679        .db()
1680        .await
1681        .update_worktree_settings(&message, session.connection_id)
1682        .await?;
1683
1684    broadcast(
1685        Some(session.connection_id),
1686        guest_connection_ids.iter().copied(),
1687        |connection_id| {
1688            session
1689                .peer
1690                .forward_send(session.connection_id, connection_id, message.clone())
1691        },
1692    );
1693
1694    Ok(())
1695}
1696
1697async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1698    broadcast_project_message(request.project_id, request, session).await
1699}
1700
1701async fn start_language_server(
1702    request: proto::StartLanguageServer,
1703    session: Session,
1704) -> Result<()> {
1705    let guest_connection_ids = session
1706        .db()
1707        .await
1708        .start_language_server(&request, session.connection_id)
1709        .await?;
1710
1711    broadcast(
1712        Some(session.connection_id),
1713        guest_connection_ids.iter().copied(),
1714        |connection_id| {
1715            session
1716                .peer
1717                .forward_send(session.connection_id, connection_id, request.clone())
1718        },
1719    );
1720    Ok(())
1721}
1722
1723async fn update_language_server(
1724    request: proto::UpdateLanguageServer,
1725    session: Session,
1726) -> Result<()> {
1727    let project_id = ProjectId::from_proto(request.project_id);
1728    let project_connection_ids = session
1729        .db()
1730        .await
1731        .project_connection_ids(project_id, session.connection_id)
1732        .await?;
1733    broadcast(
1734        Some(session.connection_id),
1735        project_connection_ids.iter().copied(),
1736        |connection_id| {
1737            session
1738                .peer
1739                .forward_send(session.connection_id, connection_id, request.clone())
1740        },
1741    );
1742    Ok(())
1743}
1744
1745async fn forward_project_request<T>(
1746    request: T,
1747    response: Response<T>,
1748    session: Session,
1749) -> Result<()>
1750where
1751    T: EntityMessage + RequestMessage,
1752{
1753    let project_id = ProjectId::from_proto(request.remote_entity_id());
1754    let host_connection_id = {
1755        let collaborators = session
1756            .db()
1757            .await
1758            .project_collaborators(project_id, session.connection_id)
1759            .await?;
1760        collaborators
1761            .iter()
1762            .find(|collaborator| collaborator.is_host)
1763            .ok_or_else(|| anyhow!("host not found"))?
1764            .connection_id
1765    };
1766
1767    let payload = session
1768        .peer
1769        .forward_request(session.connection_id, host_connection_id, request)
1770        .await?;
1771
1772    response.send(payload)?;
1773    Ok(())
1774}
1775
1776async fn create_buffer_for_peer(
1777    request: proto::CreateBufferForPeer,
1778    session: Session,
1779) -> Result<()> {
1780    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1781    session
1782        .peer
1783        .forward_send(session.connection_id, peer_id.into(), request)?;
1784    Ok(())
1785}
1786
1787async fn update_buffer(
1788    request: proto::UpdateBuffer,
1789    response: Response<proto::UpdateBuffer>,
1790    session: Session,
1791) -> Result<()> {
1792    let project_id = ProjectId::from_proto(request.project_id);
1793    let mut guest_connection_ids;
1794    let mut host_connection_id = None;
1795    {
1796        let collaborators = session
1797            .db()
1798            .await
1799            .project_collaborators(project_id, session.connection_id)
1800            .await?;
1801        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1802        for collaborator in collaborators.iter() {
1803            if collaborator.is_host {
1804                host_connection_id = Some(collaborator.connection_id);
1805            } else {
1806                guest_connection_ids.push(collaborator.connection_id);
1807            }
1808        }
1809    }
1810    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1811
1812    broadcast(
1813        Some(session.connection_id),
1814        guest_connection_ids,
1815        |connection_id| {
1816            session
1817                .peer
1818                .forward_send(session.connection_id, connection_id, request.clone())
1819        },
1820    );
1821    if host_connection_id != session.connection_id {
1822        session
1823            .peer
1824            .forward_request(session.connection_id, host_connection_id, request.clone())
1825            .await?;
1826    }
1827
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839
1840    broadcast(
1841        Some(session.connection_id),
1842        project_connection_ids.iter().copied(),
1843        |connection_id| {
1844            session
1845                .peer
1846                .forward_send(session.connection_id, connection_id, request.clone())
1847        },
1848    );
1849    Ok(())
1850}
1851
1852async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1853    let project_id = ProjectId::from_proto(request.project_id);
1854    let project_connection_ids = session
1855        .db()
1856        .await
1857        .project_connection_ids(project_id, session.connection_id)
1858        .await?;
1859    broadcast(
1860        Some(session.connection_id),
1861        project_connection_ids.iter().copied(),
1862        |connection_id| {
1863            session
1864                .peer
1865                .forward_send(session.connection_id, connection_id, request.clone())
1866        },
1867    );
1868    Ok(())
1869}
1870
1871async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1872    broadcast_project_message(request.project_id, request, session).await
1873}
1874
1875async fn broadcast_project_message<T: EnvelopedMessage>(
1876    project_id: u64,
1877    request: T,
1878    session: Session,
1879) -> Result<()> {
1880    let project_id = ProjectId::from_proto(project_id);
1881    let project_connection_ids = session
1882        .db()
1883        .await
1884        .project_connection_ids(project_id, session.connection_id)
1885        .await?;
1886    broadcast(
1887        Some(session.connection_id),
1888        project_connection_ids.iter().copied(),
1889        |connection_id| {
1890            session
1891                .peer
1892                .forward_send(session.connection_id, connection_id, request.clone())
1893        },
1894    );
1895    Ok(())
1896}
1897
1898async fn follow(
1899    request: proto::Follow,
1900    response: Response<proto::Follow>,
1901    session: Session,
1902) -> Result<()> {
1903    let room_id = RoomId::from_proto(request.room_id);
1904    let project_id = request.project_id.map(ProjectId::from_proto);
1905    let leader_id = request
1906        .leader_id
1907        .ok_or_else(|| anyhow!("invalid leader id"))?
1908        .into();
1909    let follower_id = session.connection_id;
1910
1911    session
1912        .db()
1913        .await
1914        .check_room_participants(room_id, leader_id, session.connection_id)
1915        .await?;
1916
1917    let response_payload = session
1918        .peer
1919        .forward_request(session.connection_id, leader_id, request)
1920        .await?;
1921    response.send(response_payload)?;
1922
1923    if let Some(project_id) = project_id {
1924        let room = session
1925            .db()
1926            .await
1927            .follow(room_id, project_id, leader_id, follower_id)
1928            .await?;
1929        room_updated(&room, &session.peer);
1930    }
1931
1932    Ok(())
1933}
1934
1935async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1936    let room_id = RoomId::from_proto(request.room_id);
1937    let project_id = request.project_id.map(ProjectId::from_proto);
1938    let leader_id = request
1939        .leader_id
1940        .ok_or_else(|| anyhow!("invalid leader id"))?
1941        .into();
1942    let follower_id = session.connection_id;
1943
1944    session
1945        .db()
1946        .await
1947        .check_room_participants(room_id, leader_id, session.connection_id)
1948        .await?;
1949
1950    session
1951        .peer
1952        .forward_send(session.connection_id, leader_id, request)?;
1953
1954    if let Some(project_id) = project_id {
1955        let room = session
1956            .db()
1957            .await
1958            .unfollow(room_id, project_id, leader_id, follower_id)
1959            .await?;
1960        room_updated(&room, &session.peer);
1961    }
1962
1963    Ok(())
1964}
1965
1966async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1967    let room_id = RoomId::from_proto(request.room_id);
1968    let database = session.db.lock().await;
1969
1970    let connection_ids = if let Some(project_id) = request.project_id {
1971        let project_id = ProjectId::from_proto(project_id);
1972        database
1973            .project_connection_ids(project_id, session.connection_id)
1974            .await?
1975    } else {
1976        database
1977            .room_connection_ids(room_id, session.connection_id)
1978            .await?
1979    };
1980
1981    // For now, don't send view update messages back to that view's current leader.
1982    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1983        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1984        _ => None,
1985    });
1986
1987    for follower_peer_id in request.follower_ids.iter().copied() {
1988        let follower_connection_id = follower_peer_id.into();
1989        if Some(follower_peer_id) != connection_id_to_omit
1990            && connection_ids.contains(&follower_connection_id)
1991        {
1992            session.peer.forward_send(
1993                session.connection_id,
1994                follower_connection_id,
1995                request.clone(),
1996            )?;
1997        }
1998    }
1999    Ok(())
2000}
2001
2002async fn get_users(
2003    request: proto::GetUsers,
2004    response: Response<proto::GetUsers>,
2005    session: Session,
2006) -> Result<()> {
2007    let user_ids = request
2008        .user_ids
2009        .into_iter()
2010        .map(UserId::from_proto)
2011        .collect();
2012    let users = session
2013        .db()
2014        .await
2015        .get_users_by_ids(user_ids)
2016        .await?
2017        .into_iter()
2018        .map(|user| proto::User {
2019            id: user.id.to_proto(),
2020            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2021            github_login: user.github_login,
2022        })
2023        .collect();
2024    response.send(proto::UsersResponse { users })?;
2025    Ok(())
2026}
2027
2028async fn fuzzy_search_users(
2029    request: proto::FuzzySearchUsers,
2030    response: Response<proto::FuzzySearchUsers>,
2031    session: Session,
2032) -> Result<()> {
2033    let query = request.query;
2034    let users = match query.len() {
2035        0 => vec![],
2036        1 | 2 => session
2037            .db()
2038            .await
2039            .get_user_by_github_login(&query)
2040            .await?
2041            .into_iter()
2042            .collect(),
2043        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2044    };
2045    let users = users
2046        .into_iter()
2047        .filter(|user| user.id != session.user_id)
2048        .map(|user| proto::User {
2049            id: user.id.to_proto(),
2050            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2051            github_login: user.github_login,
2052        })
2053        .collect();
2054    response.send(proto::UsersResponse { users })?;
2055    Ok(())
2056}
2057
2058async fn request_contact(
2059    request: proto::RequestContact,
2060    response: Response<proto::RequestContact>,
2061    session: Session,
2062) -> Result<()> {
2063    let requester_id = session.user_id;
2064    let responder_id = UserId::from_proto(request.responder_id);
2065    if requester_id == responder_id {
2066        return Err(anyhow!("cannot add yourself as a contact"))?;
2067    }
2068
2069    let notifications = session
2070        .db()
2071        .await
2072        .send_contact_request(requester_id, responder_id)
2073        .await?;
2074
2075    // Update outgoing contact requests of requester
2076    let mut update = proto::UpdateContacts::default();
2077    update.outgoing_requests.push(responder_id.to_proto());
2078    for connection_id in session
2079        .connection_pool()
2080        .await
2081        .user_connection_ids(requester_id)
2082    {
2083        session.peer.send(connection_id, update.clone())?;
2084    }
2085
2086    // Update incoming contact requests of responder
2087    let mut update = proto::UpdateContacts::default();
2088    update
2089        .incoming_requests
2090        .push(proto::IncomingContactRequest {
2091            requester_id: requester_id.to_proto(),
2092        });
2093    let connection_pool = session.connection_pool().await;
2094    for connection_id in connection_pool.user_connection_ids(responder_id) {
2095        session.peer.send(connection_id, update.clone())?;
2096    }
2097
2098    send_notifications(&*connection_pool, &session.peer, notifications);
2099
2100    response.send(proto::Ack {})?;
2101    Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105    request: proto::RespondToContactRequest,
2106    response: Response<proto::RespondToContactRequest>,
2107    session: Session,
2108) -> Result<()> {
2109    let responder_id = session.user_id;
2110    let requester_id = UserId::from_proto(request.requester_id);
2111    let db = session.db().await;
2112    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113        db.dismiss_contact_notification(responder_id, requester_id)
2114            .await?;
2115    } else {
2116        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118        let notifications = db
2119            .respond_to_contact_request(responder_id, requester_id, accept)
2120            .await?;
2121        let requester_busy = db.is_user_busy(requester_id).await?;
2122        let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124        let pool = session.connection_pool().await;
2125        // Update responder with new contact
2126        let mut update = proto::UpdateContacts::default();
2127        if accept {
2128            update
2129                .contacts
2130                .push(contact_for_user(requester_id, requester_busy, &pool));
2131        }
2132        update
2133            .remove_incoming_requests
2134            .push(requester_id.to_proto());
2135        for connection_id in pool.user_connection_ids(responder_id) {
2136            session.peer.send(connection_id, update.clone())?;
2137        }
2138
2139        // Update requester with new contact
2140        let mut update = proto::UpdateContacts::default();
2141        if accept {
2142            update
2143                .contacts
2144                .push(contact_for_user(responder_id, responder_busy, &pool));
2145        }
2146        update
2147            .remove_outgoing_requests
2148            .push(responder_id.to_proto());
2149
2150        for connection_id in pool.user_connection_ids(requester_id) {
2151            session.peer.send(connection_id, update.clone())?;
2152        }
2153
2154        send_notifications(&*pool, &session.peer, notifications);
2155    }
2156
2157    response.send(proto::Ack {})?;
2158    Ok(())
2159}
2160
2161async fn remove_contact(
2162    request: proto::RemoveContact,
2163    response: Response<proto::RemoveContact>,
2164    session: Session,
2165) -> Result<()> {
2166    let requester_id = session.user_id;
2167    let responder_id = UserId::from_proto(request.user_id);
2168    let db = session.db().await;
2169    let (contact_accepted, deleted_notification_id) =
2170        db.remove_contact(requester_id, responder_id).await?;
2171
2172    let pool = session.connection_pool().await;
2173    // Update outgoing contact requests of requester
2174    let mut update = proto::UpdateContacts::default();
2175    if contact_accepted {
2176        update.remove_contacts.push(responder_id.to_proto());
2177    } else {
2178        update
2179            .remove_outgoing_requests
2180            .push(responder_id.to_proto());
2181    }
2182    for connection_id in pool.user_connection_ids(requester_id) {
2183        session.peer.send(connection_id, update.clone())?;
2184    }
2185
2186    // Update incoming contact requests of responder
2187    let mut update = proto::UpdateContacts::default();
2188    if contact_accepted {
2189        update.remove_contacts.push(requester_id.to_proto());
2190    } else {
2191        update
2192            .remove_incoming_requests
2193            .push(requester_id.to_proto());
2194    }
2195    for connection_id in pool.user_connection_ids(responder_id) {
2196        session.peer.send(connection_id, update.clone())?;
2197        if let Some(notification_id) = deleted_notification_id {
2198            session.peer.send(
2199                connection_id,
2200                proto::DeleteNotification {
2201                    notification_id: notification_id.to_proto(),
2202                },
2203            )?;
2204        }
2205    }
2206
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
2211async fn create_channel(
2212    request: proto::CreateChannel,
2213    response: Response<proto::CreateChannel>,
2214    session: Session,
2215) -> Result<()> {
2216    let db = session.db().await;
2217
2218    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2219    let CreateChannelResult {
2220        channel,
2221        participants_to_update,
2222    } = db
2223        .create_channel(&request.name, parent_id, session.user_id)
2224        .await?;
2225
2226    response.send(proto::CreateChannelResponse {
2227        channel: Some(channel.to_proto()),
2228        parent_id: request.parent_id,
2229    })?;
2230
2231    let connection_pool = session.connection_pool().await;
2232    for (user_id, channels) in participants_to_update {
2233        let update = build_channels_update(channels, vec![]);
2234        for connection_id in connection_pool.user_connection_ids(user_id) {
2235            if user_id == session.user_id {
2236                continue;
2237            }
2238            session.peer.send(connection_id, update.clone())?;
2239        }
2240    }
2241
2242    Ok(())
2243}
2244
2245async fn delete_channel(
2246    request: proto::DeleteChannel,
2247    response: Response<proto::DeleteChannel>,
2248    session: Session,
2249) -> Result<()> {
2250    let db = session.db().await;
2251
2252    let channel_id = request.channel_id;
2253    let (removed_channels, member_ids) = db
2254        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2255        .await?;
2256    response.send(proto::Ack {})?;
2257
2258    // Notify members of removed channels
2259    let mut update = proto::UpdateChannels::default();
2260    update
2261        .delete_channels
2262        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2263
2264    let connection_pool = session.connection_pool().await;
2265    for member_id in member_ids {
2266        for connection_id in connection_pool.user_connection_ids(member_id) {
2267            session.peer.send(connection_id, update.clone())?;
2268        }
2269    }
2270
2271    Ok(())
2272}
2273
2274async fn invite_channel_member(
2275    request: proto::InviteChannelMember,
2276    response: Response<proto::InviteChannelMember>,
2277    session: Session,
2278) -> Result<()> {
2279    let db = session.db().await;
2280    let channel_id = ChannelId::from_proto(request.channel_id);
2281    let invitee_id = UserId::from_proto(request.user_id);
2282    let InviteMemberResult {
2283        channel,
2284        notifications,
2285    } = db
2286        .invite_channel_member(
2287            channel_id,
2288            invitee_id,
2289            session.user_id,
2290            request.role().into(),
2291        )
2292        .await?;
2293
2294    let update = proto::UpdateChannels {
2295        channel_invitations: vec![channel.to_proto()],
2296        ..Default::default()
2297    };
2298
2299    let connection_pool = session.connection_pool().await;
2300    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2301        session.peer.send(connection_id, update.clone())?;
2302    }
2303
2304    send_notifications(&*connection_pool, &session.peer, notifications);
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn remove_channel_member(
2311    request: proto::RemoveChannelMember,
2312    response: Response<proto::RemoveChannelMember>,
2313    session: Session,
2314) -> Result<()> {
2315    let db = session.db().await;
2316    let channel_id = ChannelId::from_proto(request.channel_id);
2317    let member_id = UserId::from_proto(request.user_id);
2318
2319    let RemoveChannelMemberResult {
2320        membership_update,
2321        notification_id,
2322    } = db
2323        .remove_channel_member(channel_id, member_id, session.user_id)
2324        .await?;
2325
2326    let connection_pool = &session.connection_pool().await;
2327    notify_membership_updated(
2328        &connection_pool,
2329        membership_update,
2330        member_id,
2331        &session.peer,
2332    );
2333    for connection_id in connection_pool.user_connection_ids(member_id) {
2334        if let Some(notification_id) = notification_id {
2335            session
2336                .peer
2337                .send(
2338                    connection_id,
2339                    proto::DeleteNotification {
2340                        notification_id: notification_id.to_proto(),
2341                    },
2342                )
2343                .trace_err();
2344        }
2345    }
2346
2347    response.send(proto::Ack {})?;
2348    Ok(())
2349}
2350
2351async fn set_channel_visibility(
2352    request: proto::SetChannelVisibility,
2353    response: Response<proto::SetChannelVisibility>,
2354    session: Session,
2355) -> Result<()> {
2356    let db = session.db().await;
2357    let channel_id = ChannelId::from_proto(request.channel_id);
2358    let visibility = request.visibility().into();
2359
2360    let SetChannelVisibilityResult {
2361        participants_to_update,
2362        participants_to_remove,
2363        channels_to_remove,
2364    } = db
2365        .set_channel_visibility(channel_id, visibility, session.user_id)
2366        .await?;
2367
2368    let connection_pool = session.connection_pool().await;
2369    for (user_id, channels) in participants_to_update {
2370        let update = build_channels_update(channels, vec![]);
2371        for connection_id in connection_pool.user_connection_ids(user_id) {
2372            session.peer.send(connection_id, update.clone())?;
2373        }
2374    }
2375    for user_id in participants_to_remove {
2376        let update = proto::UpdateChannels {
2377            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2378            ..Default::default()
2379        };
2380        for connection_id in connection_pool.user_connection_ids(user_id) {
2381            session.peer.send(connection_id, update.clone())?;
2382        }
2383    }
2384
2385    response.send(proto::Ack {})?;
2386    Ok(())
2387}
2388
2389async fn set_channel_member_role(
2390    request: proto::SetChannelMemberRole,
2391    response: Response<proto::SetChannelMemberRole>,
2392    session: Session,
2393) -> Result<()> {
2394    let db = session.db().await;
2395    let channel_id = ChannelId::from_proto(request.channel_id);
2396    let member_id = UserId::from_proto(request.user_id);
2397    let result = db
2398        .set_channel_member_role(
2399            channel_id,
2400            session.user_id,
2401            member_id,
2402            request.role().into(),
2403        )
2404        .await?;
2405
2406    match result {
2407        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2408            let connection_pool = session.connection_pool().await;
2409            notify_membership_updated(
2410                &connection_pool,
2411                membership_update,
2412                member_id,
2413                &session.peer,
2414            )
2415        }
2416        db::SetMemberRoleResult::InviteUpdated(channel) => {
2417            let update = proto::UpdateChannels {
2418                channel_invitations: vec![channel.to_proto()],
2419                ..Default::default()
2420            };
2421
2422            for connection_id in session
2423                .connection_pool()
2424                .await
2425                .user_connection_ids(member_id)
2426            {
2427                session.peer.send(connection_id, update.clone())?;
2428            }
2429        }
2430    }
2431
2432    response.send(proto::Ack {})?;
2433    Ok(())
2434}
2435
2436async fn rename_channel(
2437    request: proto::RenameChannel,
2438    response: Response<proto::RenameChannel>,
2439    session: Session,
2440) -> Result<()> {
2441    let db = session.db().await;
2442    let channel_id = ChannelId::from_proto(request.channel_id);
2443    let RenameChannelResult {
2444        channel,
2445        participants_to_update,
2446    } = db
2447        .rename_channel(channel_id, session.user_id, &request.name)
2448        .await?;
2449
2450    response.send(proto::RenameChannelResponse {
2451        channel: Some(channel.to_proto()),
2452    })?;
2453
2454    let connection_pool = session.connection_pool().await;
2455    for (user_id, channel) in participants_to_update {
2456        for connection_id in connection_pool.user_connection_ids(user_id) {
2457            let update = proto::UpdateChannels {
2458                channels: vec![channel.to_proto()],
2459                ..Default::default()
2460            };
2461
2462            session.peer.send(connection_id, update.clone())?;
2463        }
2464    }
2465
2466    Ok(())
2467}
2468
2469async fn move_channel(
2470    request: proto::MoveChannel,
2471    response: Response<proto::MoveChannel>,
2472    session: Session,
2473) -> Result<()> {
2474    let channel_id = ChannelId::from_proto(request.channel_id);
2475    let to = request.to.map(ChannelId::from_proto);
2476
2477    let result = session
2478        .db()
2479        .await
2480        .move_channel(channel_id, to, session.user_id)
2481        .await?;
2482
2483    notify_channel_moved(result, session).await?;
2484
2485    response.send(Ack {})?;
2486    Ok(())
2487}
2488
2489async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2490    let Some(MoveChannelResult {
2491        participants_to_remove,
2492        participants_to_update,
2493        moved_channels,
2494    }) = result
2495    else {
2496        return Ok(());
2497    };
2498    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2499
2500    let connection_pool = session.connection_pool().await;
2501    for (user_id, channels) in participants_to_update {
2502        let mut update = build_channels_update(channels, vec![]);
2503        update.delete_channels = moved_channels.clone();
2504        for connection_id in connection_pool.user_connection_ids(user_id) {
2505            session.peer.send(connection_id, update.clone())?;
2506        }
2507    }
2508
2509    for user_id in participants_to_remove {
2510        let update = proto::UpdateChannels {
2511            delete_channels: moved_channels.clone(),
2512            ..Default::default()
2513        };
2514        for connection_id in connection_pool.user_connection_ids(user_id) {
2515            session.peer.send(connection_id, update.clone())?;
2516        }
2517    }
2518    Ok(())
2519}
2520
2521async fn get_channel_members(
2522    request: proto::GetChannelMembers,
2523    response: Response<proto::GetChannelMembers>,
2524    session: Session,
2525) -> Result<()> {
2526    let db = session.db().await;
2527    let channel_id = ChannelId::from_proto(request.channel_id);
2528    let members = db
2529        .get_channel_participant_details(channel_id, session.user_id)
2530        .await?;
2531    response.send(proto::GetChannelMembersResponse { members })?;
2532    Ok(())
2533}
2534
2535async fn respond_to_channel_invite(
2536    request: proto::RespondToChannelInvite,
2537    response: Response<proto::RespondToChannelInvite>,
2538    session: Session,
2539) -> Result<()> {
2540    let db = session.db().await;
2541    let channel_id = ChannelId::from_proto(request.channel_id);
2542    let RespondToChannelInvite {
2543        membership_update,
2544        notifications,
2545    } = db
2546        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2547        .await?;
2548
2549    let connection_pool = session.connection_pool().await;
2550    if let Some(membership_update) = membership_update {
2551        notify_membership_updated(
2552            &connection_pool,
2553            membership_update,
2554            session.user_id,
2555            &session.peer,
2556        );
2557    } else {
2558        let update = proto::UpdateChannels {
2559            remove_channel_invitations: vec![channel_id.to_proto()],
2560            ..Default::default()
2561        };
2562
2563        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2564            session.peer.send(connection_id, update.clone())?;
2565        }
2566    };
2567
2568    send_notifications(&*connection_pool, &session.peer, notifications);
2569
2570    response.send(proto::Ack {})?;
2571
2572    Ok(())
2573}
2574
2575async fn join_channel(
2576    request: proto::JoinChannel,
2577    response: Response<proto::JoinChannel>,
2578    session: Session,
2579) -> Result<()> {
2580    let channel_id = ChannelId::from_proto(request.channel_id);
2581    join_channel_internal(channel_id, Box::new(response), session).await
2582}
2583
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and the `JoinRoom` response handles,
/// so `join_channel_internal` can reply to either kind of request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the
    /// handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::send` for the concrete request type. The
// fully-qualified path keeps the call from resolving back to this trait's
// `send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation for `JoinRoom`, whose reply is also a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2597
/// Moves the session's connection into the room that belongs to `channel_id`
/// and notifies everyone affected.
///
/// Steps, in order:
/// 1. Leave whatever room this connection is currently in.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token for the room:
///    - for `ChannelRole::Guest`, a guest token with `can_publish: false`;
///    - otherwise, a room token with `can_publish: true`.
///
///    A token error is logged and leaves `live_kit_connection_info` as `None`.
/// 4. Reply to the caller, forward any membership update to the user's
///    connections, and broadcast the updated room to its participants.
/// 5. Broadcast the room's participant list to the channel's members and
///    refresh the user's contacts.
///
/// The inner block scopes the database and connection-pool guards. They are
/// therefore dropped before `channel_updated` re-acquires the pool and before
/// `update_user_contacts` acquires both again.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        // Leave any room this connection is in before joining the channel's room.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `?` inside this closure short-circuits to `None` when token
        // creation fails (after `trace_err` has logged the error).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and may not publish; every other
            // role receives a room token and may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` are dropped when this block ends.
        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2678
2679async fn join_channel_buffer(
2680    request: proto::JoinChannelBuffer,
2681    response: Response<proto::JoinChannelBuffer>,
2682    session: Session,
2683) -> Result<()> {
2684    let db = session.db().await;
2685    let channel_id = ChannelId::from_proto(request.channel_id);
2686
2687    let open_response = db
2688        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2689        .await?;
2690
2691    let collaborators = open_response.collaborators.clone();
2692    response.send(open_response)?;
2693
2694    let update = UpdateChannelBufferCollaborators {
2695        channel_id: channel_id.to_proto(),
2696        collaborators: collaborators.clone(),
2697    };
2698    channel_buffer_updated(
2699        session.connection_id,
2700        collaborators
2701            .iter()
2702            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2703        &update,
2704        &session.peer,
2705    );
2706
2707    Ok(())
2708}
2709
2710async fn update_channel_buffer(
2711    request: proto::UpdateChannelBuffer,
2712    session: Session,
2713) -> Result<()> {
2714    let db = session.db().await;
2715    let channel_id = ChannelId::from_proto(request.channel_id);
2716
2717    let (collaborators, non_collaborators, epoch, version) = db
2718        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2719        .await?;
2720
2721    channel_buffer_updated(
2722        session.connection_id,
2723        collaborators,
2724        &proto::UpdateChannelBuffer {
2725            channel_id: channel_id.to_proto(),
2726            operations: request.operations,
2727        },
2728        &session.peer,
2729    );
2730
2731    let pool = &*session.connection_pool().await;
2732
2733    broadcast(
2734        None,
2735        non_collaborators
2736            .iter()
2737            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2738        |peer_id| {
2739            session.peer.send(
2740                peer_id.into(),
2741                proto::UpdateChannels {
2742                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2743                        channel_id: channel_id.to_proto(),
2744                        epoch: epoch as u64,
2745                        version: version.clone(),
2746                    }],
2747                    ..Default::default()
2748                },
2749            )
2750        },
2751    );
2752
2753    Ok(())
2754}
2755
2756async fn rejoin_channel_buffers(
2757    request: proto::RejoinChannelBuffers,
2758    response: Response<proto::RejoinChannelBuffers>,
2759    session: Session,
2760) -> Result<()> {
2761    let db = session.db().await;
2762    let buffers = db
2763        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2764        .await?;
2765
2766    for rejoined_buffer in &buffers {
2767        let collaborators_to_notify = rejoined_buffer
2768            .buffer
2769            .collaborators
2770            .iter()
2771            .filter_map(|c| Some(c.peer_id?.into()));
2772        channel_buffer_updated(
2773            session.connection_id,
2774            collaborators_to_notify,
2775            &proto::UpdateChannelBufferCollaborators {
2776                channel_id: rejoined_buffer.buffer.channel_id,
2777                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2778            },
2779            &session.peer,
2780        );
2781    }
2782
2783    response.send(proto::RejoinChannelBuffersResponse {
2784        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2785    })?;
2786
2787    Ok(())
2788}
2789
2790async fn leave_channel_buffer(
2791    request: proto::LeaveChannelBuffer,
2792    response: Response<proto::LeaveChannelBuffer>,
2793    session: Session,
2794) -> Result<()> {
2795    let db = session.db().await;
2796    let channel_id = ChannelId::from_proto(request.channel_id);
2797
2798    let left_buffer = db
2799        .leave_channel_buffer(channel_id, session.connection_id)
2800        .await?;
2801
2802    response.send(Ack {})?;
2803
2804    channel_buffer_updated(
2805        session.connection_id,
2806        left_buffer.connections,
2807        &proto::UpdateChannelBufferCollaborators {
2808            channel_id: channel_id.to_proto(),
2809            collaborators: left_buffer.collaborators,
2810        },
2811        &session.peer,
2812    );
2813
2814    Ok(())
2815}
2816
2817fn channel_buffer_updated<T: EnvelopedMessage>(
2818    sender_id: ConnectionId,
2819    collaborators: impl IntoIterator<Item = ConnectionId>,
2820    message: &T,
2821    peer: &Peer,
2822) {
2823    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2824        peer.send(peer_id.into(), message.clone())
2825    });
2826}
2827
2828fn send_notifications(
2829    connection_pool: &ConnectionPool,
2830    peer: &Peer,
2831    notifications: db::NotificationBatch,
2832) {
2833    for (user_id, notification) in notifications {
2834        for connection_id in connection_pool.user_connection_ids(user_id) {
2835            if let Err(error) = peer.send(
2836                connection_id,
2837                proto::AddNotification {
2838                    notification: Some(notification.clone()),
2839                },
2840            ) {
2841                tracing::error!(
2842                    "failed to send notification to {:?} {}",
2843                    connection_id,
2844                    error
2845                );
2846            }
2847        }
2848    }
2849}
2850
2851async fn send_channel_message(
2852    request: proto::SendChannelMessage,
2853    response: Response<proto::SendChannelMessage>,
2854    session: Session,
2855) -> Result<()> {
2856    // Validate the message body.
2857    let body = request.body.trim().to_string();
2858    if body.len() > MAX_MESSAGE_LEN {
2859        return Err(anyhow!("message is too long"))?;
2860    }
2861    if body.is_empty() {
2862        return Err(anyhow!("message can't be blank"))?;
2863    }
2864
2865    // TODO: adjust mentions if body is trimmed
2866
2867    let timestamp = OffsetDateTime::now_utc();
2868    let nonce = request
2869        .nonce
2870        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2871
2872    let channel_id = ChannelId::from_proto(request.channel_id);
2873    let CreatedChannelMessage {
2874        message_id,
2875        participant_connection_ids,
2876        channel_members,
2877        notifications,
2878    } = session
2879        .db()
2880        .await
2881        .create_channel_message(
2882            channel_id,
2883            session.user_id,
2884            &body,
2885            &request.mentions,
2886            timestamp,
2887            nonce.clone().into(),
2888        )
2889        .await?;
2890    let message = proto::ChannelMessage {
2891        sender_id: session.user_id.to_proto(),
2892        id: message_id.to_proto(),
2893        body,
2894        mentions: request.mentions,
2895        timestamp: timestamp.unix_timestamp() as u64,
2896        nonce: Some(nonce),
2897    };
2898    broadcast(
2899        Some(session.connection_id),
2900        participant_connection_ids,
2901        |connection| {
2902            session.peer.send(
2903                connection,
2904                proto::ChannelMessageSent {
2905                    channel_id: channel_id.to_proto(),
2906                    message: Some(message.clone()),
2907                },
2908            )
2909        },
2910    );
2911    response.send(proto::SendChannelMessageResponse {
2912        message: Some(message),
2913    })?;
2914
2915    let pool = &*session.connection_pool().await;
2916    broadcast(
2917        None,
2918        channel_members
2919            .iter()
2920            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2921        |peer_id| {
2922            session.peer.send(
2923                peer_id.into(),
2924                proto::UpdateChannels {
2925                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2926                        channel_id: channel_id.to_proto(),
2927                        message_id: message_id.to_proto(),
2928                    }],
2929                    ..Default::default()
2930                },
2931            )
2932        },
2933    );
2934    send_notifications(pool, &session.peer, notifications);
2935
2936    Ok(())
2937}
2938
2939async fn remove_channel_message(
2940    request: proto::RemoveChannelMessage,
2941    response: Response<proto::RemoveChannelMessage>,
2942    session: Session,
2943) -> Result<()> {
2944    let channel_id = ChannelId::from_proto(request.channel_id);
2945    let message_id = MessageId::from_proto(request.message_id);
2946    let connection_ids = session
2947        .db()
2948        .await
2949        .remove_channel_message(channel_id, message_id, session.user_id)
2950        .await?;
2951    broadcast(Some(session.connection_id), connection_ids, |connection| {
2952        session.peer.send(connection, request.clone())
2953    });
2954    response.send(proto::Ack {})?;
2955    Ok(())
2956}
2957
2958async fn acknowledge_channel_message(
2959    request: proto::AckChannelMessage,
2960    session: Session,
2961) -> Result<()> {
2962    let channel_id = ChannelId::from_proto(request.channel_id);
2963    let message_id = MessageId::from_proto(request.message_id);
2964    let notifications = session
2965        .db()
2966        .await
2967        .observe_channel_message(channel_id, session.user_id, message_id)
2968        .await?;
2969    send_notifications(
2970        &*session.connection_pool().await,
2971        &session.peer,
2972        notifications,
2973    );
2974    Ok(())
2975}
2976
2977async fn acknowledge_buffer_version(
2978    request: proto::AckBufferOperation,
2979    session: Session,
2980) -> Result<()> {
2981    let buffer_id = BufferId::from_proto(request.buffer_id);
2982    session
2983        .db()
2984        .await
2985        .observe_buffer_version(
2986            buffer_id,
2987            session.user_id,
2988            request.epoch as i32,
2989            &request.version,
2990        )
2991        .await?;
2992    Ok(())
2993}
2994
2995async fn join_channel_chat(
2996    request: proto::JoinChannelChat,
2997    response: Response<proto::JoinChannelChat>,
2998    session: Session,
2999) -> Result<()> {
3000    let channel_id = ChannelId::from_proto(request.channel_id);
3001
3002    let db = session.db().await;
3003    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3004        .await?;
3005    let messages = db
3006        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3007        .await?;
3008    response.send(proto::JoinChannelChatResponse {
3009        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3010        messages,
3011    })?;
3012    Ok(())
3013}
3014
3015async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3016    let channel_id = ChannelId::from_proto(request.channel_id);
3017    session
3018        .db()
3019        .await
3020        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3021        .await?;
3022    Ok(())
3023}
3024
3025async fn get_channel_messages(
3026    request: proto::GetChannelMessages,
3027    response: Response<proto::GetChannelMessages>,
3028    session: Session,
3029) -> Result<()> {
3030    let channel_id = ChannelId::from_proto(request.channel_id);
3031    let messages = session
3032        .db()
3033        .await
3034        .get_channel_messages(
3035            channel_id,
3036            session.user_id,
3037            MESSAGE_COUNT_PER_PAGE,
3038            Some(MessageId::from_proto(request.before_message_id)),
3039        )
3040        .await?;
3041    response.send(proto::GetChannelMessagesResponse {
3042        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3043        messages,
3044    })?;
3045    Ok(())
3046}
3047
3048async fn get_channel_messages_by_id(
3049    request: proto::GetChannelMessagesById,
3050    response: Response<proto::GetChannelMessagesById>,
3051    session: Session,
3052) -> Result<()> {
3053    let message_ids = request
3054        .message_ids
3055        .iter()
3056        .map(|id| MessageId::from_proto(*id))
3057        .collect::<Vec<_>>();
3058    let messages = session
3059        .db()
3060        .await
3061        .get_channel_messages_by_id(session.user_id, &message_ids)
3062        .await?;
3063    response.send(proto::GetChannelMessagesResponse {
3064        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3065        messages,
3066    })?;
3067    Ok(())
3068}
3069
3070async fn get_notifications(
3071    request: proto::GetNotifications,
3072    response: Response<proto::GetNotifications>,
3073    session: Session,
3074) -> Result<()> {
3075    let notifications = session
3076        .db()
3077        .await
3078        .get_notifications(
3079            session.user_id,
3080            NOTIFICATION_COUNT_PER_PAGE,
3081            request
3082                .before_id
3083                .map(|id| db::NotificationId::from_proto(id)),
3084        )
3085        .await?;
3086    response.send(proto::GetNotificationsResponse {
3087        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3088        notifications,
3089    })?;
3090    Ok(())
3091}
3092
3093async fn mark_notification_as_read(
3094    request: proto::MarkNotificationRead,
3095    response: Response<proto::MarkNotificationRead>,
3096    session: Session,
3097) -> Result<()> {
3098    let database = &session.db().await;
3099    let notifications = database
3100        .mark_notification_as_read_by_id(
3101            session.user_id,
3102            NotificationId::from_proto(request.notification_id),
3103        )
3104        .await?;
3105    send_notifications(
3106        &*session.connection_pool().await,
3107        &session.peer,
3108        notifications,
3109    );
3110    response.send(proto::Ack {})?;
3111    Ok(())
3112}
3113
3114async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3115    let project_id = ProjectId::from_proto(request.project_id);
3116    let project_connection_ids = session
3117        .db()
3118        .await
3119        .project_connection_ids(project_id, session.connection_id)
3120        .await?;
3121    broadcast(
3122        Some(session.connection_id),
3123        project_connection_ids.iter().copied(),
3124        |connection_id| {
3125            session
3126                .peer
3127                .forward_send(session.connection_id, connection_id, request.clone())
3128        },
3129    );
3130    Ok(())
3131}
3132
3133async fn get_private_user_info(
3134    _request: proto::GetPrivateUserInfo,
3135    response: Response<proto::GetPrivateUserInfo>,
3136    session: Session,
3137) -> Result<()> {
3138    let db = session.db().await;
3139
3140    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3141    let user = db
3142        .get_user_by_id(session.user_id)
3143        .await?
3144        .ok_or_else(|| anyhow!("user not found"))?;
3145    let flags = db.get_user_flags(session.user_id).await?;
3146
3147    response.send(proto::GetPrivateUserInfoResponse {
3148        metrics_id,
3149        staff: user.admin,
3150        flags,
3151    })?;
3152    Ok(())
3153}
3154
3155fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3156    match message {
3157        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3158        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3159        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3160        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3161        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3162            code: frame.code.into(),
3163            reason: frame.reason,
3164        })),
3165    }
3166}
3167
3168fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3169    match message {
3170        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3171        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3172        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3173        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3174        AxumMessage::Close(frame) => {
3175            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3176                code: frame.code.into(),
3177                reason: frame.reason,
3178            }))
3179        }
3180    }
3181}
3182
3183fn notify_membership_updated(
3184    connection_pool: &ConnectionPool,
3185    result: MembershipUpdated,
3186    user_id: UserId,
3187    peer: &Peer,
3188) {
3189    let mut update = build_channels_update(result.new_channels, vec![]);
3190    update.delete_channels = result
3191        .removed_channels
3192        .into_iter()
3193        .map(|id| id.to_proto())
3194        .collect();
3195    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3196
3197    for connection_id in connection_pool.user_connection_ids(user_id) {
3198        peer.send(connection_id, update.clone()).trace_err();
3199    }
3200}
3201
3202fn build_channels_update(
3203    channels: ChannelsForUser,
3204    channel_invites: Vec<db::Channel>,
3205) -> proto::UpdateChannels {
3206    let mut update = proto::UpdateChannels::default();
3207
3208    for channel in channels.channels {
3209        update.channels.push(channel.to_proto());
3210    }
3211
3212    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3213    update.unseen_channel_messages = channels.channel_messages;
3214
3215    for (channel_id, participants) in channels.channel_participants {
3216        update
3217            .channel_participants
3218            .push(proto::ChannelParticipants {
3219                channel_id: channel_id.to_proto(),
3220                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3221            });
3222    }
3223
3224    for channel in channel_invites {
3225        update.channel_invitations.push(channel.to_proto());
3226    }
3227
3228    update
3229}
3230
3231fn build_initial_contacts_update(
3232    contacts: Vec<db::Contact>,
3233    pool: &ConnectionPool,
3234) -> proto::UpdateContacts {
3235    let mut update = proto::UpdateContacts::default();
3236
3237    for contact in contacts {
3238        match contact {
3239            db::Contact::Accepted { user_id, busy } => {
3240                update.contacts.push(contact_for_user(user_id, busy, &pool));
3241            }
3242            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3243            db::Contact::Incoming { user_id } => {
3244                update
3245                    .incoming_requests
3246                    .push(proto::IncomingContactRequest {
3247                        requester_id: user_id.to_proto(),
3248                    })
3249            }
3250        }
3251    }
3252
3253    update
3254}
3255
3256fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3257    proto::Contact {
3258        user_id: user_id.to_proto(),
3259        online: pool.is_user_online(user_id),
3260        busy,
3261    }
3262}
3263
3264fn room_updated(room: &proto::Room, peer: &Peer) {
3265    broadcast(
3266        None,
3267        room.participants
3268            .iter()
3269            .filter_map(|participant| Some(participant.peer_id?.into())),
3270        |peer_id| {
3271            peer.send(
3272                peer_id.into(),
3273                proto::RoomUpdated {
3274                    room: Some(room.clone()),
3275                },
3276            )
3277        },
3278    );
3279}
3280
3281fn channel_updated(
3282    channel_id: ChannelId,
3283    room: &proto::Room,
3284    channel_members: &[UserId],
3285    peer: &Peer,
3286    pool: &ConnectionPool,
3287) {
3288    let participants = room
3289        .participants
3290        .iter()
3291        .map(|p| p.user_id)
3292        .collect::<Vec<_>>();
3293
3294    broadcast(
3295        None,
3296        channel_members
3297            .iter()
3298            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3299        |peer_id| {
3300            peer.send(
3301                peer_id.into(),
3302                proto::UpdateChannels {
3303                    channel_participants: vec![proto::ChannelParticipants {
3304                        channel_id: channel_id.to_proto(),
3305                        participant_user_ids: participants.clone(),
3306                    }],
3307                    ..Default::default()
3308                },
3309            )
3310        },
3311    );
3312}
3313
3314async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3315    let db = session.db().await;
3316
3317    let contacts = db.get_contacts(user_id).await?;
3318    let busy = db.is_user_busy(user_id).await?;
3319
3320    let pool = session.connection_pool().await;
3321    let updated_contact = contact_for_user(user_id, busy, &pool);
3322    for contact in contacts {
3323        if let db::Contact::Accepted {
3324            user_id: contact_user_id,
3325            ..
3326        } = contact
3327        {
3328            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3329                session
3330                    .peer
3331                    .send(
3332                        contact_conn_id,
3333                        proto::UpdateContacts {
3334                            contacts: vec![updated_contact.clone()],
3335                            remove_contacts: Default::default(),
3336                            incoming_requests: Default::default(),
3337                            remove_incoming_requests: Default::default(),
3338                            outgoing_requests: Default::default(),
3339                            remove_outgoing_requests: Default::default(),
3340                        },
3341                    )
3342                    .trace_err();
3343            }
3344        }
3345    }
3346    Ok(())
3347}
3348
/// Removes the session's connection from the room it is in and notifies
/// everyone affected.
///
/// Returns `Ok(())` without doing anything when the database's `leave_room`
/// returns `None`. Otherwise, in order:
/// 1. calls `project_left` for each project the connection left, and
///    broadcasts the updated room to its participants;
/// 2. if the room belongs to a channel, broadcasts the room's participant list
///    to the channel's members;
/// 3. sends `CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`;
/// 4. refreshes the contact status of this user and of each canceled user;
/// 5. if a LiveKit client is configured, removes the user from the LiveKit room
///    and, when the database marked the room as deleted, deletes it.
///
/// LiveKit and send errors are logged via `trace_err` rather than returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized: assigned only in the `if let` branch below; the
    // `else` branch returns early, so later code always sees them set.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): before the 2024 edition, the `session.db()` guard created in
    // this scrutinee stays alive until the end of the whole `if let`/`else`.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before the
        // room itself is taken, so the `room` broadcast below carries a defaulted
        // `live_kit_room`. Presumably intentional; confirm clients don't read it
        // from this update.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3425
3426async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3427    let left_channel_buffers = session
3428        .db()
3429        .await
3430        .leave_channel_buffers(session.connection_id)
3431        .await?;
3432
3433    for left_buffer in left_channel_buffers {
3434        channel_buffer_updated(
3435            session.connection_id,
3436            left_buffer.connections,
3437            &proto::UpdateChannelBufferCollaborators {
3438                channel_id: left_buffer.channel_id.to_proto(),
3439                collaborators: left_buffer.collaborators,
3440            },
3441            &session.peer,
3442        );
3443    }
3444
3445    Ok(())
3446}
3447
3448fn project_left(project: &db::LeftProject, session: &Session) {
3449    for connection_id in &project.connection_ids {
3450        if project.host_user_id == session.user_id {
3451            session
3452                .peer
3453                .send(
3454                    *connection_id,
3455                    proto::UnshareProject {
3456                        project_id: project.id.to_proto(),
3457                    },
3458                )
3459                .trace_err();
3460        } else {
3461            session
3462                .peer
3463                .send(
3464                    *connection_id,
3465                    proto::RemoveProjectCollaborator {
3466                        project_id: project.id.to_proto(),
3467                        peer_id: Some(session.connection_id.into()),
3468                    },
3469                )
3470                .trace_err();
3471        }
3472    }
3473}
3474
3475pub trait ResultExt {
3476    type Ok;
3477
3478    fn trace_err(self) -> Option<Self::Ok>;
3479}
3480
3481impl<T, E> ResultExt for Result<T, E>
3482where
3483    E: std::fmt::Debug,
3484{
3485    type Ok = T;
3486
3487    fn trace_err(self) -> Option<T> {
3488        match self {
3489            Ok(value) => Some(value),
3490            Err(error) => {
3491                tracing::error!("{:?}", error);
3492                None
3493            }
3494        }
3495    }
3496}