rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69
  70pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  71pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  72
  73const MESSAGE_COUNT_PER_PAGE: usize = 100;
  74const MAX_MESSAGE_LEN: usize = 1024;
  75const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
  77lazy_static! {
  78    static ref METRIC_CONNECTIONS: IntGauge =
  79        register_int_gauge!("connections", "number of connections").unwrap();
  80    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  81        "shared_projects",
  82        "number of open projects with one or more guests"
  83    )
  84    .unwrap();
  85}
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone)]
 105struct Session {
 106    zed_environment: Arc<str>,
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    _executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
    /// Creates a server with the given `id` and registers every RPC message handler.
    ///
    /// The peer is created with `id.0 as u32` as its epoch/id.
    /// NOTE(review): that cast truncates if `ServerId` ever exceeds `u32::MAX`.
    /// The `teardown` watch receiver is dropped immediately; connections obtain their own
    /// receivers later via `teardown.subscribe()` in `handle_connection`.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` is for messages that expect a reply; `add_message_handler`
        // is for fire-and-forget messages.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded via the generic read-only forwarder.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            // Project requests forwarded via the generic mutating forwarder.
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Messages relayed via `broadcast_project_message_from_host`.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 296
    /// Schedules cleanup of resources left behind by stale servers and returns `Ok(())`
    /// immediately.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, which starts counting when `start` is called.
    /// It then fetches the stale room and channel ids for this environment and server id, and:
    /// - clears stale channel-buffer collaborators, sending
    ///   `UpdateChannelBufferCollaborators` to each remaining connection;
    /// - clears stale room participants, broadcasting the room (and channel, if any) update,
    ///   sending `CallCanceled` to users whose calls were canceled, pushing updated contact
    ///   info to the accepted contacts of affected users, and deleting the LiveKit room once
    ///   no participants remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Each database or send failure is logged via `trace_err()` and skipped, so one failure
    /// does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, so the cleanup delay is measured from the call to `start`.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining connection about the new collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room succeeded; the defaults make the
                        // steps below no-ops otherwise.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries are awaited before taking the pool lock, so the lock is
                            // never held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 448
 449    pub fn teardown(&self) {
 450        self.peer.teardown();
 451        self.connection_pool.lock().reset();
 452        let _ = self.teardown.send(());
 453    }
 454
 455    #[cfg(test)]
 456    pub fn reset(&self, id: ServerId) {
 457        self.teardown();
 458        *self.id.lock() = id;
 459        self.peer.reset(id.0 as u32);
 460    }
 461
 462    #[cfg(test)]
 463    pub fn id(&self) -> ServerId {
 464        *self.id.lock()
 465    }
 466
 467    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 468    where
 469        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 470        Fut: 'static + Send + Future<Output = Result<()>>,
 471        M: EnvelopedMessage,
 472    {
 473        let prev_handler = self.handlers.insert(
 474            TypeId::of::<M>(),
 475            Box::new(move |envelope, session| {
 476                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 477                let span = info_span!(
 478                    "handle message",
 479                    payload_type = envelope.payload_type_name()
 480                );
 481                span.in_scope(|| {
 482                    tracing::info!(
 483                        payload_type = envelope.payload_type_name(),
 484                        "message received"
 485                    );
 486                });
 487                let start_time = Instant::now();
 488                let future = (handler)(*envelope, session);
 489                async move {
 490                    let result = future.await;
 491                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 492                    match result {
 493                        Err(error) => {
 494                            tracing::error!(%error, ?duration_ms, "error handling message")
 495                        }
 496                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 497                    }
 498                }
 499                .instrument(span)
 500                .boxed()
 501            }),
 502        );
 503        if prev_handler.is_some() {
 504            panic!("registered a handler for the same message twice");
 505        }
 506        self
 507    }
 508
 509    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 512        Fut: 'static + Send + Future<Output = Result<()>>,
 513        M: EnvelopedMessage,
 514    {
 515        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 516        self
 517    }
 518
 519    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 520    where
 521        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 522        Fut: Send + Future<Output = Result<()>>,
 523        M: RequestMessage,
 524    {
 525        let handler = Arc::new(handler);
 526        self.add_handler(move |envelope, session| {
 527            let receipt = envelope.receipt();
 528            let handler = handler.clone();
 529            async move {
 530                let peer = session.peer.clone();
 531                let responded = Arc::new(AtomicBool::default());
 532                let response = Response {
 533                    peer: peer.clone(),
 534                    responded: responded.clone(),
 535                    receipt,
 536                };
 537                match (handler)(envelope.payload, response, session).await {
 538                    Ok(()) => {
 539                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 540                            Ok(())
 541                        } else {
 542                            Err(anyhow!("handler did not send a response"))?
 543                        }
 544                    }
 545                    Err(error) => {
 546                        peer.respond_with_error(
 547                            receipt,
 548                            proto::Error {
 549                                message: error.to_string(),
 550                            },
 551                        )?;
 552                        Err(error)
 553                    }
 554                }
 555            }
 556        })
 557    }
 558
    /// Returns a future that serves one authenticated client connection until it ends.
    ///
    /// Setup runs in this order:
    /// - registers `connection` with the peer and sends `proto::Hello` carrying the new
    ///   connection id;
    /// - reports that id through `send_connection_id`, if one was provided;
    /// - on the user's first connection, sends `proto::ShowContacts` and records
    ///   `connected_once`;
    /// - loads contacts, channels and channel invites concurrently;
    /// - adds the connection to the pool and sends the initial contacts and channels updates;
    /// - forwards any pending incoming call;
    /// - builds the `Session` and runs `update_user_contacts`.
    ///
    /// The future then dispatches incoming messages to the registered handlers. It stops when
    /// one of three things happens:
    /// - the I/O future finishes or the incoming stream ends: it calls `connection_lost`,
    ///   logging (not returning) any error;
    /// - the server is torn down: it returns `Ok(())` immediately, without `connection_lost`.
    ///
    /// Setup failures are returned through `?`.
    /// NOTE(review): failures after `pool.add_connection` (sending the initial updates,
    /// `incoming_call_for_user`, `update_user_contacts`) return without calling
    /// `connection_lost`, so the pool entry may be left behind. Confirm whether it is
    /// cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future runs, so a `teardown()` issued later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Pool lock scoped to this block; nothing inside awaits.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps this connection at 256 in-flight handlers (foreground and background combined).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so at the cap no new
                // message is pulled off the stream until some handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are checked in order — teardown, then I/O, then progress on
                // in-flight foreground handlers, then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before handing the future off;
                                // `.instrument(span)` re-enters it on every poll.
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are cancelled here; detached background ones keep running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 694
 695    pub async fn invite_code_redeemed(
 696        self: &Arc<Self>,
 697        inviter_id: UserId,
 698        invitee_id: UserId,
 699    ) -> Result<()> {
 700        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 701            if let Some(code) = &user.invite_code {
 702                let pool = self.connection_pool.lock();
 703                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 704                for connection_id in pool.user_connection_ids(inviter_id) {
 705                    self.peer.send(
 706                        connection_id,
 707                        proto::UpdateContacts {
 708                            contacts: vec![invitee_contact.clone()],
 709                            ..Default::default()
 710                        },
 711                    )?;
 712                    self.peer.send(
 713                        connection_id,
 714                        proto::UpdateInviteInfo {
 715                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 716                            count: user.invite_count as u32,
 717                        },
 718                    )?;
 719                }
 720            }
 721        }
 722        Ok(())
 723    }
 724
 725    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 726        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 727            if let Some(invite_code) = &user.invite_code {
 728                let pool = self.connection_pool.lock();
 729                for connection_id in pool.user_connection_ids(user_id) {
 730                    self.peer.send(
 731                        connection_id,
 732                        proto::UpdateInviteInfo {
 733                            url: format!(
 734                                "{}{}",
 735                                self.app_state.config.invite_link_prefix, invite_code
 736                            ),
 737                            count: user.invite_count as u32,
 738                        },
 739                    )?;
 740                }
 741            }
 742        }
 743        Ok(())
 744    }
 745
 746    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 747        ServerSnapshot {
 748            connection_pool: ConnectionPoolGuard {
 749                guard: self.connection_pool.lock(),
 750                _not_send: PhantomData,
 751            },
 752            peer: &self.peer,
 753        }
 754    }
 755}
 756
 757impl<'a> Deref for ConnectionPoolGuard<'a> {
 758    type Target = ConnectionPool;
 759
 760    fn deref(&self) -> &Self::Target {
 761        &*self.guard
 762    }
 763}
 764
 765impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 766    fn deref_mut(&mut self) -> &mut Self::Target {
 767        &mut *self.guard
 768    }
 769}
 770
 771impl<'a> Drop for ConnectionPoolGuard<'a> {
 772    fn drop(&mut self) {
 773        #[cfg(test)]
 774        self.check_invariants();
 775    }
 776}
 777
 778fn broadcast<F>(
 779    sender_id: Option<ConnectionId>,
 780    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 781    mut f: F,
 782) where
 783    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 784{
 785    for receiver_id in receiver_ids {
 786        if Some(receiver_id) != sender_id {
 787            if let Err(error) = f(receiver_id) {
 788                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 789            }
 790        }
 791    }
 792}
 793
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version;
    /// returned by `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 797
 798pub struct ProtocolVersion(u32);
 799
 800impl Header for ProtocolVersion {
 801    fn name() -> &'static HeaderName {
 802        &ZED_PROTOCOL_VERSION
 803    }
 804
 805    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 806    where
 807        Self: Sized,
 808        I: Iterator<Item = &'i axum::http::HeaderValue>,
 809    {
 810        let version = values
 811            .next()
 812            .ok_or_else(axum::headers::Error::invalid)?
 813            .to_str()
 814            .map_err(|_| axum::headers::Error::invalid())?
 815            .parse()
 816            .map_err(|_| axum::headers::Error::invalid())?;
 817        Ok(Self(version))
 818    }
 819
 820    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 821        values.extend([self.0.to_string().parse().unwrap()]);
 822    }
 823}
 824
/// Builds the RPC router.
///
/// - `/rpc` performs the websocket upgrade (`handle_websocket_request`) and is
///   wrapped by the `auth::validate_header` middleware.
/// - `/metrics` serves Prometheus text output (`handle_metrics`).
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // NOTE(review): axum applies `.layer` only to routes registered before it,
        // so this stack presumably guards `/rpc` but not `/metrics` — keep this
        // ordering unless that is intentionally changing.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Both handlers extract `Extension<Arc<Server>>`.
        .layer(Extension(server))
}
 836
 837pub async fn handle_websocket_request(
 838    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 839    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 840    Extension(server): Extension<Arc<Server>>,
 841    Extension(user): Extension<User>,
 842    ws: WebSocketUpgrade,
 843) -> axum::response::Response {
 844    if protocol_version != rpc::PROTOCOL_VERSION {
 845        return (
 846            StatusCode::UPGRADE_REQUIRED,
 847            "client must be upgraded".to_string(),
 848        )
 849            .into_response();
 850    }
 851    let socket_address = socket_address.to_string();
 852    ws.on_upgrade(move |socket| {
 853        use util::ResultExt;
 854        let socket = socket
 855            .map_ok(to_tungstenite_message)
 856            .err_into()
 857            .with(|message| async move { Ok(to_axum_message(message)) });
 858        let connection = Connection::new(Box::pin(socket));
 859        async move {
 860            server
 861                .handle_connection(connection, socket_address, user, None, Executor::Production)
 862                .await
 863                .log_err();
 864        }
 865    })
 866}
 867
 868pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 869    let connections = server
 870        .connection_pool
 871        .lock()
 872        .connections()
 873        .filter(|connection| !connection.admin)
 874        .count();
 875
 876    METRIC_CONNECTIONS.set(connections as _);
 877
 878    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 879    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 880
 881    let encoder = prometheus::TextEncoder::new();
 882    let metric_families = prometheus::gather();
 883    let encoded_metrics = encoder
 884        .encode_to_string(&metric_families)
 885        .map_err(|err| anyhow!("{}", err))?;
 886    Ok(encoded_metrics)
 887}
 888
 889#[instrument(err, skip(executor))]
 890async fn connection_lost(
 891    session: Session,
 892    mut teardown: watch::Receiver<()>,
 893    executor: Executor,
 894) -> Result<()> {
 895    session.peer.disconnect(session.connection_id);
 896    session
 897        .connection_pool()
 898        .await
 899        .remove_connection(session.connection_id)?;
 900
 901    session
 902        .db()
 903        .await
 904        .connection_lost(session.connection_id)
 905        .await
 906        .trace_err();
 907
 908    futures::select_biased! {
 909        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 910            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 911            leave_room_for_session(&session).await.trace_err();
 912            leave_channel_buffers_for_session(&session)
 913                .await
 914                .trace_err();
 915
 916            if !session
 917                .connection_pool()
 918                .await
 919                .is_user_online(session.user_id)
 920            {
 921                let db = session.db().await;
 922                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 923                    room_updated(&room, &session.peer);
 924                }
 925            }
 926
 927            update_user_contacts(session.user_id, &session).await?;
 928        }
 929        _ = teardown.changed().fuse() => {}
 930    }
 931
 932    Ok(())
 933}
 934
 935async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 936    response.send(proto::Ack {})?;
 937    Ok(())
 938}
 939
 940async fn create_room(
 941    _request: proto::CreateRoom,
 942    response: Response<proto::CreateRoom>,
 943    session: Session,
 944) -> Result<()> {
 945    let live_kit_room = nanoid::nanoid!(30);
 946
 947    let live_kit_connection_info = {
 948        let live_kit_room = live_kit_room.clone();
 949        let live_kit = session.live_kit_client.as_ref();
 950
 951        util::async_maybe!({
 952            let live_kit = live_kit?;
 953
 954            let token = live_kit
 955                .room_token(&live_kit_room, &session.user_id.to_string())
 956                .trace_err()?;
 957
 958            Some(proto::LiveKitConnectionInfo {
 959                server_url: live_kit.url().into(),
 960                token,
 961                can_publish: true,
 962            })
 963        })
 964    }
 965    .await;
 966
 967    let room = session
 968        .db()
 969        .await
 970        .create_room(
 971            session.user_id,
 972            session.connection_id,
 973            &live_kit_room,
 974            &session.zed_environment,
 975        )
 976        .await?;
 977
 978    response.send(proto::CreateRoomResponse {
 979        room: Some(room.clone()),
 980        live_kit_connection_info,
 981    })?;
 982
 983    update_user_contacts(session.user_id, &session).await?;
 984    Ok(())
 985}
 986
 987async fn join_room(
 988    request: proto::JoinRoom,
 989    response: Response<proto::JoinRoom>,
 990    session: Session,
 991) -> Result<()> {
 992    let room_id = RoomId::from_proto(request.id);
 993
 994    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 995
 996    if let Some(channel_id) = channel_id {
 997        return join_channel_internal(channel_id, Box::new(response), session).await;
 998    }
 999
1000    let joined_room = {
1001        let room = session
1002            .db()
1003            .await
1004            .join_room(
1005                room_id,
1006                session.user_id,
1007                session.connection_id,
1008                session.zed_environment.as_ref(),
1009            )
1010            .await?;
1011        room_updated(&room.room, &session.peer);
1012        room.into_inner()
1013    };
1014
1015    for connection_id in session
1016        .connection_pool()
1017        .await
1018        .user_connection_ids(session.user_id)
1019    {
1020        session
1021            .peer
1022            .send(
1023                connection_id,
1024                proto::CallCanceled {
1025                    room_id: room_id.to_proto(),
1026                },
1027            )
1028            .trace_err();
1029    }
1030
1031    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1032        if let Some(token) = live_kit
1033            .room_token(
1034                &joined_room.room.live_kit_room,
1035                &session.user_id.to_string(),
1036            )
1037            .trace_err()
1038        {
1039            Some(proto::LiveKitConnectionInfo {
1040                server_url: live_kit.url().into(),
1041                token,
1042                can_publish: true,
1043            })
1044        } else {
1045            None
1046        }
1047    } else {
1048        None
1049    };
1050
1051    response.send(proto::JoinRoomResponse {
1052        room: Some(joined_room.room),
1053        channel_id: None,
1054        live_kit_connection_info,
1055    })?;
1056
1057    update_user_contacts(session.user_id, &session).await?;
1058    Ok(())
1059}
1060
1061async fn rejoin_room(
1062    request: proto::RejoinRoom,
1063    response: Response<proto::RejoinRoom>,
1064    session: Session,
1065) -> Result<()> {
1066    let room;
1067    let channel_id;
1068    let channel_members;
1069    {
1070        let mut rejoined_room = session
1071            .db()
1072            .await
1073            .rejoin_room(request, session.user_id, session.connection_id)
1074            .await?;
1075
1076        response.send(proto::RejoinRoomResponse {
1077            room: Some(rejoined_room.room.clone()),
1078            reshared_projects: rejoined_room
1079                .reshared_projects
1080                .iter()
1081                .map(|project| proto::ResharedProject {
1082                    id: project.id.to_proto(),
1083                    collaborators: project
1084                        .collaborators
1085                        .iter()
1086                        .map(|collaborator| collaborator.to_proto())
1087                        .collect(),
1088                })
1089                .collect(),
1090            rejoined_projects: rejoined_room
1091                .rejoined_projects
1092                .iter()
1093                .map(|rejoined_project| proto::RejoinedProject {
1094                    id: rejoined_project.id.to_proto(),
1095                    worktrees: rejoined_project
1096                        .worktrees
1097                        .iter()
1098                        .map(|worktree| proto::WorktreeMetadata {
1099                            id: worktree.id,
1100                            root_name: worktree.root_name.clone(),
1101                            visible: worktree.visible,
1102                            abs_path: worktree.abs_path.clone(),
1103                        })
1104                        .collect(),
1105                    collaborators: rejoined_project
1106                        .collaborators
1107                        .iter()
1108                        .map(|collaborator| collaborator.to_proto())
1109                        .collect(),
1110                    language_servers: rejoined_project.language_servers.clone(),
1111                })
1112                .collect(),
1113        })?;
1114        room_updated(&rejoined_room.room, &session.peer);
1115
1116        for project in &rejoined_room.reshared_projects {
1117            for collaborator in &project.collaborators {
1118                session
1119                    .peer
1120                    .send(
1121                        collaborator.connection_id,
1122                        proto::UpdateProjectCollaborator {
1123                            project_id: project.id.to_proto(),
1124                            old_peer_id: Some(project.old_connection_id.into()),
1125                            new_peer_id: Some(session.connection_id.into()),
1126                        },
1127                    )
1128                    .trace_err();
1129            }
1130
1131            broadcast(
1132                Some(session.connection_id),
1133                project
1134                    .collaborators
1135                    .iter()
1136                    .map(|collaborator| collaborator.connection_id),
1137                |connection_id| {
1138                    session.peer.forward_send(
1139                        session.connection_id,
1140                        connection_id,
1141                        proto::UpdateProject {
1142                            project_id: project.id.to_proto(),
1143                            worktrees: project.worktrees.clone(),
1144                        },
1145                    )
1146                },
1147            );
1148        }
1149
1150        for project in &rejoined_room.rejoined_projects {
1151            for collaborator in &project.collaborators {
1152                session
1153                    .peer
1154                    .send(
1155                        collaborator.connection_id,
1156                        proto::UpdateProjectCollaborator {
1157                            project_id: project.id.to_proto(),
1158                            old_peer_id: Some(project.old_connection_id.into()),
1159                            new_peer_id: Some(session.connection_id.into()),
1160                        },
1161                    )
1162                    .trace_err();
1163            }
1164        }
1165
1166        for project in &mut rejoined_room.rejoined_projects {
1167            for worktree in mem::take(&mut project.worktrees) {
1168                #[cfg(any(test, feature = "test-support"))]
1169                const MAX_CHUNK_SIZE: usize = 2;
1170                #[cfg(not(any(test, feature = "test-support")))]
1171                const MAX_CHUNK_SIZE: usize = 256;
1172
1173                // Stream this worktree's entries.
1174                let message = proto::UpdateWorktree {
1175                    project_id: project.id.to_proto(),
1176                    worktree_id: worktree.id,
1177                    abs_path: worktree.abs_path.clone(),
1178                    root_name: worktree.root_name,
1179                    updated_entries: worktree.updated_entries,
1180                    removed_entries: worktree.removed_entries,
1181                    scan_id: worktree.scan_id,
1182                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1183                    updated_repositories: worktree.updated_repositories,
1184                    removed_repositories: worktree.removed_repositories,
1185                };
1186                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1187                    session.peer.send(session.connection_id, update.clone())?;
1188                }
1189
1190                // Stream this worktree's diagnostics.
1191                for summary in worktree.diagnostic_summaries {
1192                    session.peer.send(
1193                        session.connection_id,
1194                        proto::UpdateDiagnosticSummary {
1195                            project_id: project.id.to_proto(),
1196                            worktree_id: worktree.id,
1197                            summary: Some(summary),
1198                        },
1199                    )?;
1200                }
1201
1202                for settings_file in worktree.settings_files {
1203                    session.peer.send(
1204                        session.connection_id,
1205                        proto::UpdateWorktreeSettings {
1206                            project_id: project.id.to_proto(),
1207                            worktree_id: worktree.id,
1208                            path: settings_file.path,
1209                            content: Some(settings_file.content),
1210                        },
1211                    )?;
1212                }
1213            }
1214
1215            for language_server in &project.language_servers {
1216                session.peer.send(
1217                    session.connection_id,
1218                    proto::UpdateLanguageServer {
1219                        project_id: project.id.to_proto(),
1220                        language_server_id: language_server.id,
1221                        variant: Some(
1222                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1223                                proto::LspDiskBasedDiagnosticsUpdated {},
1224                            ),
1225                        ),
1226                    },
1227                )?;
1228            }
1229        }
1230
1231        let rejoined_room = rejoined_room.into_inner();
1232
1233        room = rejoined_room.room;
1234        channel_id = rejoined_room.channel_id;
1235        channel_members = rejoined_room.channel_members;
1236    }
1237
1238    if let Some(channel_id) = channel_id {
1239        channel_updated(
1240            channel_id,
1241            &room,
1242            &channel_members,
1243            &session.peer,
1244            &*session.connection_pool().await,
1245        );
1246    }
1247
1248    update_user_contacts(session.user_id, &session).await?;
1249    Ok(())
1250}
1251
1252async fn leave_room(
1253    _: proto::LeaveRoom,
1254    response: Response<proto::LeaveRoom>,
1255    session: Session,
1256) -> Result<()> {
1257    leave_room_for_session(&session).await?;
1258    response.send(proto::Ack {})?;
1259    Ok(())
1260}
1261
1262async fn set_room_participant_role(
1263    request: proto::SetRoomParticipantRole,
1264    response: Response<proto::SetRoomParticipantRole>,
1265    session: Session,
1266) -> Result<()> {
1267    let (live_kit_room, can_publish) = {
1268        let room = session
1269            .db()
1270            .await
1271            .set_room_participant_role(
1272                session.user_id,
1273                RoomId::from_proto(request.room_id),
1274                UserId::from_proto(request.user_id),
1275                ChannelRole::from(request.role()),
1276            )
1277            .await?;
1278
1279        let live_kit_room = room.live_kit_room.clone();
1280        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1281        room_updated(&room, &session.peer);
1282        (live_kit_room, can_publish)
1283    };
1284
1285    if let Some(live_kit) = session.live_kit_client.as_ref() {
1286        live_kit
1287            .update_participant(
1288                live_kit_room.clone(),
1289                request.user_id.to_string(),
1290                live_kit_server::proto::ParticipantPermission {
1291                    can_subscribe: true,
1292                    can_publish,
1293                    can_publish_data: can_publish,
1294                    hidden: false,
1295                    recorder: false,
1296                },
1297            )
1298            .await
1299            .trace_err();
1300    }
1301
1302    response.send(proto::Ack {})?;
1303    Ok(())
1304}
1305
1306async fn call(
1307    request: proto::Call,
1308    response: Response<proto::Call>,
1309    session: Session,
1310) -> Result<()> {
1311    let room_id = RoomId::from_proto(request.room_id);
1312    let calling_user_id = session.user_id;
1313    let calling_connection_id = session.connection_id;
1314    let called_user_id = UserId::from_proto(request.called_user_id);
1315    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1316    if !session
1317        .db()
1318        .await
1319        .has_contact(calling_user_id, called_user_id)
1320        .await?
1321    {
1322        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1323    }
1324
1325    let incoming_call = {
1326        let (room, incoming_call) = &mut *session
1327            .db()
1328            .await
1329            .call(
1330                room_id,
1331                calling_user_id,
1332                calling_connection_id,
1333                called_user_id,
1334                initial_project_id,
1335            )
1336            .await?;
1337        room_updated(&room, &session.peer);
1338        mem::take(incoming_call)
1339    };
1340    update_user_contacts(called_user_id, &session).await?;
1341
1342    let mut calls = session
1343        .connection_pool()
1344        .await
1345        .user_connection_ids(called_user_id)
1346        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1347        .collect::<FuturesUnordered<_>>();
1348
1349    while let Some(call_response) = calls.next().await {
1350        match call_response.as_ref() {
1351            Ok(_) => {
1352                response.send(proto::Ack {})?;
1353                return Ok(());
1354            }
1355            Err(_) => {
1356                call_response.trace_err();
1357            }
1358        }
1359    }
1360
1361    {
1362        let room = session
1363            .db()
1364            .await
1365            .call_failed(room_id, called_user_id)
1366            .await?;
1367        room_updated(&room, &session.peer);
1368    }
1369    update_user_contacts(called_user_id, &session).await?;
1370
1371    Err(anyhow!("failed to ring user"))?
1372}
1373
1374async fn cancel_call(
1375    request: proto::CancelCall,
1376    response: Response<proto::CancelCall>,
1377    session: Session,
1378) -> Result<()> {
1379    let called_user_id = UserId::from_proto(request.called_user_id);
1380    let room_id = RoomId::from_proto(request.room_id);
1381    {
1382        let room = session
1383            .db()
1384            .await
1385            .cancel_call(room_id, session.connection_id, called_user_id)
1386            .await?;
1387        room_updated(&room, &session.peer);
1388    }
1389
1390    for connection_id in session
1391        .connection_pool()
1392        .await
1393        .user_connection_ids(called_user_id)
1394    {
1395        session
1396            .peer
1397            .send(
1398                connection_id,
1399                proto::CallCanceled {
1400                    room_id: room_id.to_proto(),
1401                },
1402            )
1403            .trace_err();
1404    }
1405    response.send(proto::Ack {})?;
1406
1407    update_user_contacts(called_user_id, &session).await?;
1408    Ok(())
1409}
1410
1411async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1412    let room_id = RoomId::from_proto(message.room_id);
1413    {
1414        let room = session
1415            .db()
1416            .await
1417            .decline_call(Some(room_id), session.user_id)
1418            .await?
1419            .ok_or_else(|| anyhow!("failed to decline call"))?;
1420        room_updated(&room, &session.peer);
1421    }
1422
1423    for connection_id in session
1424        .connection_pool()
1425        .await
1426        .user_connection_ids(session.user_id)
1427    {
1428        session
1429            .peer
1430            .send(
1431                connection_id,
1432                proto::CallCanceled {
1433                    room_id: room_id.to_proto(),
1434                },
1435            )
1436            .trace_err();
1437    }
1438    update_user_contacts(session.user_id, &session).await?;
1439    Ok(())
1440}
1441
1442async fn update_participant_location(
1443    request: proto::UpdateParticipantLocation,
1444    response: Response<proto::UpdateParticipantLocation>,
1445    session: Session,
1446) -> Result<()> {
1447    let room_id = RoomId::from_proto(request.room_id);
1448    let location = request
1449        .location
1450        .ok_or_else(|| anyhow!("invalid location"))?;
1451
1452    let db = session.db().await;
1453    let room = db
1454        .update_room_participant_location(room_id, session.connection_id, location)
1455        .await?;
1456
1457    room_updated(&room, &session.peer);
1458    response.send(proto::Ack {})?;
1459    Ok(())
1460}
1461
1462async fn share_project(
1463    request: proto::ShareProject,
1464    response: Response<proto::ShareProject>,
1465    session: Session,
1466) -> Result<()> {
1467    let (project_id, room) = &*session
1468        .db()
1469        .await
1470        .share_project(
1471            RoomId::from_proto(request.room_id),
1472            session.connection_id,
1473            &request.worktrees,
1474        )
1475        .await?;
1476    response.send(proto::ShareProjectResponse {
1477        project_id: project_id.to_proto(),
1478    })?;
1479    room_updated(&room, &session.peer);
1480
1481    Ok(())
1482}
1483
1484async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1485    let project_id = ProjectId::from_proto(message.project_id);
1486
1487    let (room, guest_connection_ids) = &*session
1488        .db()
1489        .await
1490        .unshare_project(project_id, session.connection_id)
1491        .await?;
1492
1493    broadcast(
1494        Some(session.connection_id),
1495        guest_connection_ids.iter().copied(),
1496        |conn_id| session.peer.send(conn_id, message.clone()),
1497    );
1498    room_updated(&room, &session.peer);
1499
1500    Ok(())
1501}
1502
1503async fn join_project(
1504    request: proto::JoinProject,
1505    response: Response<proto::JoinProject>,
1506    session: Session,
1507) -> Result<()> {
1508    let project_id = ProjectId::from_proto(request.project_id);
1509    let guest_user_id = session.user_id;
1510
1511    tracing::info!(%project_id, "join project");
1512
1513    let (project, replica_id) = &mut *session
1514        .db()
1515        .await
1516        .join_project(project_id, session.connection_id)
1517        .await?;
1518
1519    let collaborators = project
1520        .collaborators
1521        .iter()
1522        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1523        .map(|collaborator| collaborator.to_proto())
1524        .collect::<Vec<_>>();
1525
1526    let worktrees = project
1527        .worktrees
1528        .iter()
1529        .map(|(id, worktree)| proto::WorktreeMetadata {
1530            id: *id,
1531            root_name: worktree.root_name.clone(),
1532            visible: worktree.visible,
1533            abs_path: worktree.abs_path.clone(),
1534        })
1535        .collect::<Vec<_>>();
1536
1537    for collaborator in &collaborators {
1538        session
1539            .peer
1540            .send(
1541                collaborator.peer_id.unwrap().into(),
1542                proto::AddProjectCollaborator {
1543                    project_id: project_id.to_proto(),
1544                    collaborator: Some(proto::Collaborator {
1545                        peer_id: Some(session.connection_id.into()),
1546                        replica_id: replica_id.0 as u32,
1547                        user_id: guest_user_id.to_proto(),
1548                    }),
1549                },
1550            )
1551            .trace_err();
1552    }
1553
1554    // First, we send the metadata associated with each worktree.
1555    response.send(proto::JoinProjectResponse {
1556        worktrees: worktrees.clone(),
1557        replica_id: replica_id.0 as u32,
1558        collaborators: collaborators.clone(),
1559        language_servers: project.language_servers.clone(),
1560    })?;
1561
1562    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1563        #[cfg(any(test, feature = "test-support"))]
1564        const MAX_CHUNK_SIZE: usize = 2;
1565        #[cfg(not(any(test, feature = "test-support")))]
1566        const MAX_CHUNK_SIZE: usize = 256;
1567
1568        // Stream this worktree's entries.
1569        let message = proto::UpdateWorktree {
1570            project_id: project_id.to_proto(),
1571            worktree_id,
1572            abs_path: worktree.abs_path.clone(),
1573            root_name: worktree.root_name,
1574            updated_entries: worktree.entries,
1575            removed_entries: Default::default(),
1576            scan_id: worktree.scan_id,
1577            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1578            updated_repositories: worktree.repository_entries.into_values().collect(),
1579            removed_repositories: Default::default(),
1580        };
1581        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1582            session.peer.send(session.connection_id, update.clone())?;
1583        }
1584
1585        // Stream this worktree's diagnostics.
1586        for summary in worktree.diagnostic_summaries {
1587            session.peer.send(
1588                session.connection_id,
1589                proto::UpdateDiagnosticSummary {
1590                    project_id: project_id.to_proto(),
1591                    worktree_id: worktree.id,
1592                    summary: Some(summary),
1593                },
1594            )?;
1595        }
1596
1597        for settings_file in worktree.settings_files {
1598            session.peer.send(
1599                session.connection_id,
1600                proto::UpdateWorktreeSettings {
1601                    project_id: project_id.to_proto(),
1602                    worktree_id: worktree.id,
1603                    path: settings_file.path,
1604                    content: Some(settings_file.content),
1605                },
1606            )?;
1607        }
1608    }
1609
1610    for language_server in &project.language_servers {
1611        session.peer.send(
1612            session.connection_id,
1613            proto::UpdateLanguageServer {
1614                project_id: project_id.to_proto(),
1615                language_server_id: language_server.id,
1616                variant: Some(
1617                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1618                        proto::LspDiskBasedDiagnosticsUpdated {},
1619                    ),
1620                ),
1621            },
1622        )?;
1623    }
1624
1625    Ok(())
1626}
1627
1628async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1629    let sender_id = session.connection_id;
1630    let project_id = ProjectId::from_proto(request.project_id);
1631
1632    let (room, project) = &*session
1633        .db()
1634        .await
1635        .leave_project(project_id, sender_id)
1636        .await?;
1637    tracing::info!(
1638        %project_id,
1639        host_user_id = %project.host_user_id,
1640        host_connection_id = %project.host_connection_id,
1641        "leave project"
1642    );
1643
1644    project_left(&project, &session);
1645    room_updated(&room, &session.peer);
1646
1647    Ok(())
1648}
1649
1650async fn update_project(
1651    request: proto::UpdateProject,
1652    response: Response<proto::UpdateProject>,
1653    session: Session,
1654) -> Result<()> {
1655    let project_id = ProjectId::from_proto(request.project_id);
1656    let (room, guest_connection_ids) = &*session
1657        .db()
1658        .await
1659        .update_project(project_id, session.connection_id, &request.worktrees)
1660        .await?;
1661    broadcast(
1662        Some(session.connection_id),
1663        guest_connection_ids.iter().copied(),
1664        |connection_id| {
1665            session
1666                .peer
1667                .forward_send(session.connection_id, connection_id, request.clone())
1668        },
1669    );
1670    room_updated(&room, &session.peer);
1671    response.send(proto::Ack {})?;
1672
1673    Ok(())
1674}
1675
1676async fn update_worktree(
1677    request: proto::UpdateWorktree,
1678    response: Response<proto::UpdateWorktree>,
1679    session: Session,
1680) -> Result<()> {
1681    let guest_connection_ids = session
1682        .db()
1683        .await
1684        .update_worktree(&request, session.connection_id)
1685        .await?;
1686
1687    broadcast(
1688        Some(session.connection_id),
1689        guest_connection_ids.iter().copied(),
1690        |connection_id| {
1691            session
1692                .peer
1693                .forward_send(session.connection_id, connection_id, request.clone())
1694        },
1695    );
1696    response.send(proto::Ack {})?;
1697    Ok(())
1698}
1699
1700async fn update_diagnostic_summary(
1701    message: proto::UpdateDiagnosticSummary,
1702    session: Session,
1703) -> Result<()> {
1704    let guest_connection_ids = session
1705        .db()
1706        .await
1707        .update_diagnostic_summary(&message, session.connection_id)
1708        .await?;
1709
1710    broadcast(
1711        Some(session.connection_id),
1712        guest_connection_ids.iter().copied(),
1713        |connection_id| {
1714            session
1715                .peer
1716                .forward_send(session.connection_id, connection_id, message.clone())
1717        },
1718    );
1719
1720    Ok(())
1721}
1722
1723async fn update_worktree_settings(
1724    message: proto::UpdateWorktreeSettings,
1725    session: Session,
1726) -> Result<()> {
1727    let guest_connection_ids = session
1728        .db()
1729        .await
1730        .update_worktree_settings(&message, session.connection_id)
1731        .await?;
1732
1733    broadcast(
1734        Some(session.connection_id),
1735        guest_connection_ids.iter().copied(),
1736        |connection_id| {
1737            session
1738                .peer
1739                .forward_send(session.connection_id, connection_id, message.clone())
1740        },
1741    );
1742
1743    Ok(())
1744}
1745
1746async fn start_language_server(
1747    request: proto::StartLanguageServer,
1748    session: Session,
1749) -> Result<()> {
1750    let guest_connection_ids = session
1751        .db()
1752        .await
1753        .start_language_server(&request, session.connection_id)
1754        .await?;
1755
1756    broadcast(
1757        Some(session.connection_id),
1758        guest_connection_ids.iter().copied(),
1759        |connection_id| {
1760            session
1761                .peer
1762                .forward_send(session.connection_id, connection_id, request.clone())
1763        },
1764    );
1765    Ok(())
1766}
1767
1768async fn update_language_server(
1769    request: proto::UpdateLanguageServer,
1770    session: Session,
1771) -> Result<()> {
1772    let project_id = ProjectId::from_proto(request.project_id);
1773    let project_connection_ids = session
1774        .db()
1775        .await
1776        .project_connection_ids(project_id, session.connection_id)
1777        .await?;
1778    broadcast(
1779        Some(session.connection_id),
1780        project_connection_ids.iter().copied(),
1781        |connection_id| {
1782            session
1783                .peer
1784                .forward_send(session.connection_id, connection_id, request.clone())
1785        },
1786    );
1787    Ok(())
1788}
1789
1790async fn forward_read_only_project_request<T>(
1791    request: T,
1792    response: Response<T>,
1793    session: Session,
1794) -> Result<()>
1795where
1796    T: EntityMessage + RequestMessage,
1797{
1798    let project_id = ProjectId::from_proto(request.remote_entity_id());
1799    let host_connection_id = session
1800        .db()
1801        .await
1802        .host_for_read_only_project_request(project_id, session.connection_id)
1803        .await?;
1804    let payload = session
1805        .peer
1806        .forward_request(session.connection_id, host_connection_id, request)
1807        .await?;
1808    response.send(payload)?;
1809    Ok(())
1810}
1811
1812async fn forward_mutating_project_request<T>(
1813    request: T,
1814    response: Response<T>,
1815    session: Session,
1816) -> Result<()>
1817where
1818    T: EntityMessage + RequestMessage,
1819{
1820    let project_id = ProjectId::from_proto(request.remote_entity_id());
1821    let host_connection_id = session
1822        .db()
1823        .await
1824        .host_for_mutating_project_request(project_id, session.connection_id)
1825        .await?;
1826    let payload = session
1827        .peer
1828        .forward_request(session.connection_id, host_connection_id, request)
1829        .await?;
1830    response.send(payload)?;
1831    Ok(())
1832}
1833
1834async fn create_buffer_for_peer(
1835    request: proto::CreateBufferForPeer,
1836    session: Session,
1837) -> Result<()> {
1838    session
1839        .db()
1840        .await
1841        .check_user_is_project_host(
1842            ProjectId::from_proto(request.project_id),
1843            session.connection_id,
1844        )
1845        .await?;
1846    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1847    session
1848        .peer
1849        .forward_send(session.connection_id, peer_id.into(), request)?;
1850    Ok(())
1851}
1852
1853async fn update_buffer(
1854    request: proto::UpdateBuffer,
1855    response: Response<proto::UpdateBuffer>,
1856    session: Session,
1857) -> Result<()> {
1858    let project_id = ProjectId::from_proto(request.project_id);
1859    let mut guest_connection_ids;
1860    let mut host_connection_id = None;
1861
1862    {
1863        let collaborators = session
1864            .db()
1865            .await
1866            .project_collaborators_for_buffer_update(project_id, session.connection_id)
1867            .await?;
1868        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1869        for collaborator in collaborators.iter() {
1870            if collaborator.is_host {
1871                host_connection_id = Some(collaborator.connection_id);
1872            } else {
1873                guest_connection_ids.push(collaborator.connection_id);
1874            }
1875        }
1876    }
1877    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1878
1879    broadcast(
1880        Some(session.connection_id),
1881        guest_connection_ids,
1882        |connection_id| {
1883            session
1884                .peer
1885                .forward_send(session.connection_id, connection_id, request.clone())
1886        },
1887    );
1888    if host_connection_id != session.connection_id {
1889        session
1890            .peer
1891            .forward_request(session.connection_id, host_connection_id, request.clone())
1892            .await?;
1893    }
1894
1895    response.send(proto::Ack {})?;
1896    Ok(())
1897}
1898
1899async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1900    request: T,
1901    session: Session,
1902) -> Result<()> {
1903    let project_id = ProjectId::from_proto(request.remote_entity_id());
1904    let project_connection_ids = session
1905        .db()
1906        .await
1907        .project_connection_ids(project_id, session.connection_id)
1908        .await?;
1909
1910    broadcast(
1911        Some(session.connection_id),
1912        project_connection_ids.iter().copied(),
1913        |connection_id| {
1914            session
1915                .peer
1916                .forward_send(session.connection_id, connection_id, request.clone())
1917        },
1918    );
1919    Ok(())
1920}
1921
1922async fn follow(
1923    request: proto::Follow,
1924    response: Response<proto::Follow>,
1925    session: Session,
1926) -> Result<()> {
1927    let room_id = RoomId::from_proto(request.room_id);
1928    let project_id = request.project_id.map(ProjectId::from_proto);
1929    let leader_id = request
1930        .leader_id
1931        .ok_or_else(|| anyhow!("invalid leader id"))?
1932        .into();
1933    let follower_id = session.connection_id;
1934
1935    session
1936        .db()
1937        .await
1938        .check_room_participants(room_id, leader_id, session.connection_id)
1939        .await?;
1940
1941    let response_payload = session
1942        .peer
1943        .forward_request(session.connection_id, leader_id, request)
1944        .await?;
1945    response.send(response_payload)?;
1946
1947    if let Some(project_id) = project_id {
1948        let room = session
1949            .db()
1950            .await
1951            .follow(room_id, project_id, leader_id, follower_id)
1952            .await?;
1953        room_updated(&room, &session.peer);
1954    }
1955
1956    Ok(())
1957}
1958
1959async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1960    let room_id = RoomId::from_proto(request.room_id);
1961    let project_id = request.project_id.map(ProjectId::from_proto);
1962    let leader_id = request
1963        .leader_id
1964        .ok_or_else(|| anyhow!("invalid leader id"))?
1965        .into();
1966    let follower_id = session.connection_id;
1967
1968    session
1969        .db()
1970        .await
1971        .check_room_participants(room_id, leader_id, session.connection_id)
1972        .await?;
1973
1974    session
1975        .peer
1976        .forward_send(session.connection_id, leader_id, request)?;
1977
1978    if let Some(project_id) = project_id {
1979        let room = session
1980            .db()
1981            .await
1982            .unfollow(room_id, project_id, leader_id, follower_id)
1983            .await?;
1984        room_updated(&room, &session.peer);
1985    }
1986
1987    Ok(())
1988}
1989
1990async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1991    let room_id = RoomId::from_proto(request.room_id);
1992    let database = session.db.lock().await;
1993
1994    let connection_ids = if let Some(project_id) = request.project_id {
1995        let project_id = ProjectId::from_proto(project_id);
1996        database
1997            .project_connection_ids(project_id, session.connection_id)
1998            .await?
1999    } else {
2000        database
2001            .room_connection_ids(room_id, session.connection_id)
2002            .await?
2003    };
2004
2005    // For now, don't send view update messages back to that view's current leader.
2006    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2007        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2008        _ => None,
2009    });
2010
2011    for follower_peer_id in request.follower_ids.iter().copied() {
2012        let follower_connection_id = follower_peer_id.into();
2013        if Some(follower_peer_id) != connection_id_to_omit
2014            && connection_ids.contains(&follower_connection_id)
2015        {
2016            session.peer.forward_send(
2017                session.connection_id,
2018                follower_connection_id,
2019                request.clone(),
2020            )?;
2021        }
2022    }
2023    Ok(())
2024}
2025
2026async fn get_users(
2027    request: proto::GetUsers,
2028    response: Response<proto::GetUsers>,
2029    session: Session,
2030) -> Result<()> {
2031    let user_ids = request
2032        .user_ids
2033        .into_iter()
2034        .map(UserId::from_proto)
2035        .collect();
2036    let users = session
2037        .db()
2038        .await
2039        .get_users_by_ids(user_ids)
2040        .await?
2041        .into_iter()
2042        .map(|user| proto::User {
2043            id: user.id.to_proto(),
2044            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2045            github_login: user.github_login,
2046        })
2047        .collect();
2048    response.send(proto::UsersResponse { users })?;
2049    Ok(())
2050}
2051
2052async fn fuzzy_search_users(
2053    request: proto::FuzzySearchUsers,
2054    response: Response<proto::FuzzySearchUsers>,
2055    session: Session,
2056) -> Result<()> {
2057    let query = request.query;
2058    let users = match query.len() {
2059        0 => vec![],
2060        1 | 2 => session
2061            .db()
2062            .await
2063            .get_user_by_github_login(&query)
2064            .await?
2065            .into_iter()
2066            .collect(),
2067        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2068    };
2069    let users = users
2070        .into_iter()
2071        .filter(|user| user.id != session.user_id)
2072        .map(|user| proto::User {
2073            id: user.id.to_proto(),
2074            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2075            github_login: user.github_login,
2076        })
2077        .collect();
2078    response.send(proto::UsersResponse { users })?;
2079    Ok(())
2080}
2081
2082async fn request_contact(
2083    request: proto::RequestContact,
2084    response: Response<proto::RequestContact>,
2085    session: Session,
2086) -> Result<()> {
2087    let requester_id = session.user_id;
2088    let responder_id = UserId::from_proto(request.responder_id);
2089    if requester_id == responder_id {
2090        return Err(anyhow!("cannot add yourself as a contact"))?;
2091    }
2092
2093    let notifications = session
2094        .db()
2095        .await
2096        .send_contact_request(requester_id, responder_id)
2097        .await?;
2098
2099    // Update outgoing contact requests of requester
2100    let mut update = proto::UpdateContacts::default();
2101    update.outgoing_requests.push(responder_id.to_proto());
2102    for connection_id in session
2103        .connection_pool()
2104        .await
2105        .user_connection_ids(requester_id)
2106    {
2107        session.peer.send(connection_id, update.clone())?;
2108    }
2109
2110    // Update incoming contact requests of responder
2111    let mut update = proto::UpdateContacts::default();
2112    update
2113        .incoming_requests
2114        .push(proto::IncomingContactRequest {
2115            requester_id: requester_id.to_proto(),
2116        });
2117    let connection_pool = session.connection_pool().await;
2118    for connection_id in connection_pool.user_connection_ids(responder_id) {
2119        session.peer.send(connection_id, update.clone())?;
2120    }
2121
2122    send_notifications(&*connection_pool, &session.peer, notifications);
2123
2124    response.send(proto::Ack {})?;
2125    Ok(())
2126}
2127
2128async fn respond_to_contact_request(
2129    request: proto::RespondToContactRequest,
2130    response: Response<proto::RespondToContactRequest>,
2131    session: Session,
2132) -> Result<()> {
2133    let responder_id = session.user_id;
2134    let requester_id = UserId::from_proto(request.requester_id);
2135    let db = session.db().await;
2136    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2137        db.dismiss_contact_notification(responder_id, requester_id)
2138            .await?;
2139    } else {
2140        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2141
2142        let notifications = db
2143            .respond_to_contact_request(responder_id, requester_id, accept)
2144            .await?;
2145        let requester_busy = db.is_user_busy(requester_id).await?;
2146        let responder_busy = db.is_user_busy(responder_id).await?;
2147
2148        let pool = session.connection_pool().await;
2149        // Update responder with new contact
2150        let mut update = proto::UpdateContacts::default();
2151        if accept {
2152            update
2153                .contacts
2154                .push(contact_for_user(requester_id, requester_busy, &pool));
2155        }
2156        update
2157            .remove_incoming_requests
2158            .push(requester_id.to_proto());
2159        for connection_id in pool.user_connection_ids(responder_id) {
2160            session.peer.send(connection_id, update.clone())?;
2161        }
2162
2163        // Update requester with new contact
2164        let mut update = proto::UpdateContacts::default();
2165        if accept {
2166            update
2167                .contacts
2168                .push(contact_for_user(responder_id, responder_busy, &pool));
2169        }
2170        update
2171            .remove_outgoing_requests
2172            .push(responder_id.to_proto());
2173
2174        for connection_id in pool.user_connection_ids(requester_id) {
2175            session.peer.send(connection_id, update.clone())?;
2176        }
2177
2178        send_notifications(&*pool, &session.peer, notifications);
2179    }
2180
2181    response.send(proto::Ack {})?;
2182    Ok(())
2183}
2184
2185async fn remove_contact(
2186    request: proto::RemoveContact,
2187    response: Response<proto::RemoveContact>,
2188    session: Session,
2189) -> Result<()> {
2190    let requester_id = session.user_id;
2191    let responder_id = UserId::from_proto(request.user_id);
2192    let db = session.db().await;
2193    let (contact_accepted, deleted_notification_id) =
2194        db.remove_contact(requester_id, responder_id).await?;
2195
2196    let pool = session.connection_pool().await;
2197    // Update outgoing contact requests of requester
2198    let mut update = proto::UpdateContacts::default();
2199    if contact_accepted {
2200        update.remove_contacts.push(responder_id.to_proto());
2201    } else {
2202        update
2203            .remove_outgoing_requests
2204            .push(responder_id.to_proto());
2205    }
2206    for connection_id in pool.user_connection_ids(requester_id) {
2207        session.peer.send(connection_id, update.clone())?;
2208    }
2209
2210    // Update incoming contact requests of responder
2211    let mut update = proto::UpdateContacts::default();
2212    if contact_accepted {
2213        update.remove_contacts.push(requester_id.to_proto());
2214    } else {
2215        update
2216            .remove_incoming_requests
2217            .push(requester_id.to_proto());
2218    }
2219    for connection_id in pool.user_connection_ids(responder_id) {
2220        session.peer.send(connection_id, update.clone())?;
2221        if let Some(notification_id) = deleted_notification_id {
2222            session.peer.send(
2223                connection_id,
2224                proto::DeleteNotification {
2225                    notification_id: notification_id.to_proto(),
2226                },
2227            )?;
2228        }
2229    }
2230
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235async fn create_channel(
2236    request: proto::CreateChannel,
2237    response: Response<proto::CreateChannel>,
2238    session: Session,
2239) -> Result<()> {
2240    let db = session.db().await;
2241
2242    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2243    let CreateChannelResult {
2244        channel,
2245        participants_to_update,
2246    } = db
2247        .create_channel(&request.name, parent_id, session.user_id)
2248        .await?;
2249
2250    response.send(proto::CreateChannelResponse {
2251        channel: Some(channel.to_proto()),
2252        parent_id: request.parent_id,
2253    })?;
2254
2255    let connection_pool = session.connection_pool().await;
2256    for (user_id, channels) in participants_to_update {
2257        let update = build_channels_update(channels, vec![]);
2258        for connection_id in connection_pool.user_connection_ids(user_id) {
2259            if user_id == session.user_id {
2260                continue;
2261            }
2262            session.peer.send(connection_id, update.clone())?;
2263        }
2264    }
2265
2266    Ok(())
2267}
2268
2269async fn delete_channel(
2270    request: proto::DeleteChannel,
2271    response: Response<proto::DeleteChannel>,
2272    session: Session,
2273) -> Result<()> {
2274    let db = session.db().await;
2275
2276    let channel_id = request.channel_id;
2277    let (removed_channels, member_ids) = db
2278        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2279        .await?;
2280    response.send(proto::Ack {})?;
2281
2282    // Notify members of removed channels
2283    let mut update = proto::UpdateChannels::default();
2284    update
2285        .delete_channels
2286        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2287
2288    let connection_pool = session.connection_pool().await;
2289    for member_id in member_ids {
2290        for connection_id in connection_pool.user_connection_ids(member_id) {
2291            session.peer.send(connection_id, update.clone())?;
2292        }
2293    }
2294
2295    Ok(())
2296}
2297
2298async fn invite_channel_member(
2299    request: proto::InviteChannelMember,
2300    response: Response<proto::InviteChannelMember>,
2301    session: Session,
2302) -> Result<()> {
2303    let db = session.db().await;
2304    let channel_id = ChannelId::from_proto(request.channel_id);
2305    let invitee_id = UserId::from_proto(request.user_id);
2306    let InviteMemberResult {
2307        channel,
2308        notifications,
2309    } = db
2310        .invite_channel_member(
2311            channel_id,
2312            invitee_id,
2313            session.user_id,
2314            request.role().into(),
2315        )
2316        .await?;
2317
2318    let update = proto::UpdateChannels {
2319        channel_invitations: vec![channel.to_proto()],
2320        ..Default::default()
2321    };
2322
2323    let connection_pool = session.connection_pool().await;
2324    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2325        session.peer.send(connection_id, update.clone())?;
2326    }
2327
2328    send_notifications(&*connection_pool, &session.peer, notifications);
2329
2330    response.send(proto::Ack {})?;
2331    Ok(())
2332}
2333
2334async fn remove_channel_member(
2335    request: proto::RemoveChannelMember,
2336    response: Response<proto::RemoveChannelMember>,
2337    session: Session,
2338) -> Result<()> {
2339    let db = session.db().await;
2340    let channel_id = ChannelId::from_proto(request.channel_id);
2341    let member_id = UserId::from_proto(request.user_id);
2342
2343    let RemoveChannelMemberResult {
2344        membership_update,
2345        notification_id,
2346    } = db
2347        .remove_channel_member(channel_id, member_id, session.user_id)
2348        .await?;
2349
2350    let connection_pool = &session.connection_pool().await;
2351    notify_membership_updated(
2352        &connection_pool,
2353        membership_update,
2354        member_id,
2355        &session.peer,
2356    );
2357    for connection_id in connection_pool.user_connection_ids(member_id) {
2358        if let Some(notification_id) = notification_id {
2359            session
2360                .peer
2361                .send(
2362                    connection_id,
2363                    proto::DeleteNotification {
2364                        notification_id: notification_id.to_proto(),
2365                    },
2366                )
2367                .trace_err();
2368        }
2369    }
2370
2371    response.send(proto::Ack {})?;
2372    Ok(())
2373}
2374
2375async fn set_channel_visibility(
2376    request: proto::SetChannelVisibility,
2377    response: Response<proto::SetChannelVisibility>,
2378    session: Session,
2379) -> Result<()> {
2380    let db = session.db().await;
2381    let channel_id = ChannelId::from_proto(request.channel_id);
2382    let visibility = request.visibility().into();
2383
2384    let SetChannelVisibilityResult {
2385        participants_to_update,
2386        participants_to_remove,
2387        channels_to_remove,
2388    } = db
2389        .set_channel_visibility(channel_id, visibility, session.user_id)
2390        .await?;
2391
2392    let connection_pool = session.connection_pool().await;
2393    for (user_id, channels) in participants_to_update {
2394        let update = build_channels_update(channels, vec![]);
2395        for connection_id in connection_pool.user_connection_ids(user_id) {
2396            session.peer.send(connection_id, update.clone())?;
2397        }
2398    }
2399    for user_id in participants_to_remove {
2400        let update = proto::UpdateChannels {
2401            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2402            ..Default::default()
2403        };
2404        for connection_id in connection_pool.user_connection_ids(user_id) {
2405            session.peer.send(connection_id, update.clone())?;
2406        }
2407    }
2408
2409    response.send(proto::Ack {})?;
2410    Ok(())
2411}
2412
2413async fn set_channel_member_role(
2414    request: proto::SetChannelMemberRole,
2415    response: Response<proto::SetChannelMemberRole>,
2416    session: Session,
2417) -> Result<()> {
2418    let db = session.db().await;
2419    let channel_id = ChannelId::from_proto(request.channel_id);
2420    let member_id = UserId::from_proto(request.user_id);
2421    let result = db
2422        .set_channel_member_role(
2423            channel_id,
2424            session.user_id,
2425            member_id,
2426            request.role().into(),
2427        )
2428        .await?;
2429
2430    match result {
2431        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2432            let connection_pool = session.connection_pool().await;
2433            notify_membership_updated(
2434                &connection_pool,
2435                membership_update,
2436                member_id,
2437                &session.peer,
2438            )
2439        }
2440        db::SetMemberRoleResult::InviteUpdated(channel) => {
2441            let update = proto::UpdateChannels {
2442                channel_invitations: vec![channel.to_proto()],
2443                ..Default::default()
2444            };
2445
2446            for connection_id in session
2447                .connection_pool()
2448                .await
2449                .user_connection_ids(member_id)
2450            {
2451                session.peer.send(connection_id, update.clone())?;
2452            }
2453        }
2454    }
2455
2456    response.send(proto::Ack {})?;
2457    Ok(())
2458}
2459
2460async fn rename_channel(
2461    request: proto::RenameChannel,
2462    response: Response<proto::RenameChannel>,
2463    session: Session,
2464) -> Result<()> {
2465    let db = session.db().await;
2466    let channel_id = ChannelId::from_proto(request.channel_id);
2467    let RenameChannelResult {
2468        channel,
2469        participants_to_update,
2470    } = db
2471        .rename_channel(channel_id, session.user_id, &request.name)
2472        .await?;
2473
2474    response.send(proto::RenameChannelResponse {
2475        channel: Some(channel.to_proto()),
2476    })?;
2477
2478    let connection_pool = session.connection_pool().await;
2479    for (user_id, channel) in participants_to_update {
2480        for connection_id in connection_pool.user_connection_ids(user_id) {
2481            let update = proto::UpdateChannels {
2482                channels: vec![channel.to_proto()],
2483                ..Default::default()
2484            };
2485
2486            session.peer.send(connection_id, update.clone())?;
2487        }
2488    }
2489
2490    Ok(())
2491}
2492
2493async fn move_channel(
2494    request: proto::MoveChannel,
2495    response: Response<proto::MoveChannel>,
2496    session: Session,
2497) -> Result<()> {
2498    let channel_id = ChannelId::from_proto(request.channel_id);
2499    let to = request.to.map(ChannelId::from_proto);
2500
2501    let result = session
2502        .db()
2503        .await
2504        .move_channel(channel_id, to, session.user_id)
2505        .await?;
2506
2507    notify_channel_moved(result, session).await?;
2508
2509    response.send(Ack {})?;
2510    Ok(())
2511}
2512
2513async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2514    let Some(MoveChannelResult {
2515        participants_to_remove,
2516        participants_to_update,
2517        moved_channels,
2518    }) = result
2519    else {
2520        return Ok(());
2521    };
2522    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2523
2524    let connection_pool = session.connection_pool().await;
2525    for (user_id, channels) in participants_to_update {
2526        let mut update = build_channels_update(channels, vec![]);
2527        update.delete_channels = moved_channels.clone();
2528        for connection_id in connection_pool.user_connection_ids(user_id) {
2529            session.peer.send(connection_id, update.clone())?;
2530        }
2531    }
2532
2533    for user_id in participants_to_remove {
2534        let update = proto::UpdateChannels {
2535            delete_channels: moved_channels.clone(),
2536            ..Default::default()
2537        };
2538        for connection_id in connection_pool.user_connection_ids(user_id) {
2539            session.peer.send(connection_id, update.clone())?;
2540        }
2541    }
2542    Ok(())
2543}
2544
2545async fn get_channel_members(
2546    request: proto::GetChannelMembers,
2547    response: Response<proto::GetChannelMembers>,
2548    session: Session,
2549) -> Result<()> {
2550    let db = session.db().await;
2551    let channel_id = ChannelId::from_proto(request.channel_id);
2552    let members = db
2553        .get_channel_participant_details(channel_id, session.user_id)
2554        .await?;
2555    response.send(proto::GetChannelMembersResponse { members })?;
2556    Ok(())
2557}
2558
2559async fn respond_to_channel_invite(
2560    request: proto::RespondToChannelInvite,
2561    response: Response<proto::RespondToChannelInvite>,
2562    session: Session,
2563) -> Result<()> {
2564    let db = session.db().await;
2565    let channel_id = ChannelId::from_proto(request.channel_id);
2566    let RespondToChannelInvite {
2567        membership_update,
2568        notifications,
2569    } = db
2570        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2571        .await?;
2572
2573    let connection_pool = session.connection_pool().await;
2574    if let Some(membership_update) = membership_update {
2575        notify_membership_updated(
2576            &connection_pool,
2577            membership_update,
2578            session.user_id,
2579            &session.peer,
2580        );
2581    } else {
2582        let update = proto::UpdateChannels {
2583            remove_channel_invitations: vec![channel_id.to_proto()],
2584            ..Default::default()
2585        };
2586
2587        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2588            session.peer.send(connection_id, update.clone())?;
2589        }
2590    };
2591
2592    send_notifications(&*connection_pool, &session.peer, notifications);
2593
2594    response.send(proto::Ack {})?;
2595
2596    Ok(())
2597}
2598
2599async fn join_channel(
2600    request: proto::JoinChannel,
2601    response: Response<proto::JoinChannel>,
2602    session: Session,
2603) -> Result<()> {
2604    let channel_id = ChannelId::from_proto(request.channel_id);
2605    join_channel_internal(channel_id, Box::new(response), session).await
2606}
2607
/// A reply handle that accepts a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>` so that `join_channel_internal` can answer
/// either request type with the same payload.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the underlying request, consuming the
    /// handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type-qualified path: intended to reach `Response`'s own
        // `send` (defined elsewhere), not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type-qualified path: intended to reach `Response`'s own
        // `send` (defined elsewhere), not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2621
/// Moves the session's connection into the room of channel `channel_id`.
///
/// Steps, in order:
/// 1. Leave whatever room this connection is currently in
///    (`leave_room_for_session`).
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership change, and the user's `ChannelRole`.
/// 3. If a LiveKit client is configured, mint a token. `ChannelRole::Guest`
///    gets a guest token and `can_publish = false`; any other role gets a room
///    token and `can_publish = true`. A token error is logged by `trace_err`,
///    and the reply then carries no LiveKit connection info; the join itself
///    does not fail.
/// 4. Reply through `response`, push any membership change to the user, and
///    broadcast the updated room to its participants.
/// 5. After the inner scope ends (dropping the db handle and the first pool
///    guard), broadcast the channel's participant list to channel members and
///    refresh the user's contacts.
///
/// `response` is boxed so the same code can answer both `JoinChannel` and
/// `JoinRoom` requests.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that needs `db` and the first pool guard lives in this block;
    // both are dropped when it ends.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` if there is no LiveKit client or token creation failed (the
        // `?` on `trace_err()` short-circuits the closure to `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may not publish and receive a guest token; every other
            // role gets a full room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Note: `db` is still alive here, so it is held while the connection
        // pool is locked.
        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Re-acquires the connection pool; the guard from the block above has been dropped.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2702
2703async fn join_channel_buffer(
2704    request: proto::JoinChannelBuffer,
2705    response: Response<proto::JoinChannelBuffer>,
2706    session: Session,
2707) -> Result<()> {
2708    let db = session.db().await;
2709    let channel_id = ChannelId::from_proto(request.channel_id);
2710
2711    let open_response = db
2712        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2713        .await?;
2714
2715    let collaborators = open_response.collaborators.clone();
2716    response.send(open_response)?;
2717
2718    let update = UpdateChannelBufferCollaborators {
2719        channel_id: channel_id.to_proto(),
2720        collaborators: collaborators.clone(),
2721    };
2722    channel_buffer_updated(
2723        session.connection_id,
2724        collaborators
2725            .iter()
2726            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2727        &update,
2728        &session.peer,
2729    );
2730
2731    Ok(())
2732}
2733
2734async fn update_channel_buffer(
2735    request: proto::UpdateChannelBuffer,
2736    session: Session,
2737) -> Result<()> {
2738    let db = session.db().await;
2739    let channel_id = ChannelId::from_proto(request.channel_id);
2740
2741    let (collaborators, non_collaborators, epoch, version) = db
2742        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2743        .await?;
2744
2745    channel_buffer_updated(
2746        session.connection_id,
2747        collaborators,
2748        &proto::UpdateChannelBuffer {
2749            channel_id: channel_id.to_proto(),
2750            operations: request.operations,
2751        },
2752        &session.peer,
2753    );
2754
2755    let pool = &*session.connection_pool().await;
2756
2757    broadcast(
2758        None,
2759        non_collaborators
2760            .iter()
2761            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2762        |peer_id| {
2763            session.peer.send(
2764                peer_id.into(),
2765                proto::UpdateChannels {
2766                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2767                        channel_id: channel_id.to_proto(),
2768                        epoch: epoch as u64,
2769                        version: version.clone(),
2770                    }],
2771                    ..Default::default()
2772                },
2773            )
2774        },
2775    );
2776
2777    Ok(())
2778}
2779
2780async fn rejoin_channel_buffers(
2781    request: proto::RejoinChannelBuffers,
2782    response: Response<proto::RejoinChannelBuffers>,
2783    session: Session,
2784) -> Result<()> {
2785    let db = session.db().await;
2786    let buffers = db
2787        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2788        .await?;
2789
2790    for rejoined_buffer in &buffers {
2791        let collaborators_to_notify = rejoined_buffer
2792            .buffer
2793            .collaborators
2794            .iter()
2795            .filter_map(|c| Some(c.peer_id?.into()));
2796        channel_buffer_updated(
2797            session.connection_id,
2798            collaborators_to_notify,
2799            &proto::UpdateChannelBufferCollaborators {
2800                channel_id: rejoined_buffer.buffer.channel_id,
2801                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2802            },
2803            &session.peer,
2804        );
2805    }
2806
2807    response.send(proto::RejoinChannelBuffersResponse {
2808        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2809    })?;
2810
2811    Ok(())
2812}
2813
2814async fn leave_channel_buffer(
2815    request: proto::LeaveChannelBuffer,
2816    response: Response<proto::LeaveChannelBuffer>,
2817    session: Session,
2818) -> Result<()> {
2819    let db = session.db().await;
2820    let channel_id = ChannelId::from_proto(request.channel_id);
2821
2822    let left_buffer = db
2823        .leave_channel_buffer(channel_id, session.connection_id)
2824        .await?;
2825
2826    response.send(Ack {})?;
2827
2828    channel_buffer_updated(
2829        session.connection_id,
2830        left_buffer.connections,
2831        &proto::UpdateChannelBufferCollaborators {
2832            channel_id: channel_id.to_proto(),
2833            collaborators: left_buffer.collaborators,
2834        },
2835        &session.peer,
2836    );
2837
2838    Ok(())
2839}
2840
2841fn channel_buffer_updated<T: EnvelopedMessage>(
2842    sender_id: ConnectionId,
2843    collaborators: impl IntoIterator<Item = ConnectionId>,
2844    message: &T,
2845    peer: &Peer,
2846) {
2847    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2848        peer.send(peer_id.into(), message.clone())
2849    });
2850}
2851
2852fn send_notifications(
2853    connection_pool: &ConnectionPool,
2854    peer: &Peer,
2855    notifications: db::NotificationBatch,
2856) {
2857    for (user_id, notification) in notifications {
2858        for connection_id in connection_pool.user_connection_ids(user_id) {
2859            if let Err(error) = peer.send(
2860                connection_id,
2861                proto::AddNotification {
2862                    notification: Some(notification.clone()),
2863                },
2864            ) {
2865                tracing::error!(
2866                    "failed to send notification to {:?} {}",
2867                    connection_id,
2868                    error
2869                );
2870            }
2871        }
2872    }
2873}
2874
2875async fn send_channel_message(
2876    request: proto::SendChannelMessage,
2877    response: Response<proto::SendChannelMessage>,
2878    session: Session,
2879) -> Result<()> {
2880    // Validate the message body.
2881    let body = request.body.trim().to_string();
2882    if body.len() > MAX_MESSAGE_LEN {
2883        return Err(anyhow!("message is too long"))?;
2884    }
2885    if body.is_empty() {
2886        return Err(anyhow!("message can't be blank"))?;
2887    }
2888
2889    // TODO: adjust mentions if body is trimmed
2890
2891    let timestamp = OffsetDateTime::now_utc();
2892    let nonce = request
2893        .nonce
2894        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2895
2896    let channel_id = ChannelId::from_proto(request.channel_id);
2897    let CreatedChannelMessage {
2898        message_id,
2899        participant_connection_ids,
2900        channel_members,
2901        notifications,
2902    } = session
2903        .db()
2904        .await
2905        .create_channel_message(
2906            channel_id,
2907            session.user_id,
2908            &body,
2909            &request.mentions,
2910            timestamp,
2911            nonce.clone().into(),
2912        )
2913        .await?;
2914    let message = proto::ChannelMessage {
2915        sender_id: session.user_id.to_proto(),
2916        id: message_id.to_proto(),
2917        body,
2918        mentions: request.mentions,
2919        timestamp: timestamp.unix_timestamp() as u64,
2920        nonce: Some(nonce),
2921    };
2922    broadcast(
2923        Some(session.connection_id),
2924        participant_connection_ids,
2925        |connection| {
2926            session.peer.send(
2927                connection,
2928                proto::ChannelMessageSent {
2929                    channel_id: channel_id.to_proto(),
2930                    message: Some(message.clone()),
2931                },
2932            )
2933        },
2934    );
2935    response.send(proto::SendChannelMessageResponse {
2936        message: Some(message),
2937    })?;
2938
2939    let pool = &*session.connection_pool().await;
2940    broadcast(
2941        None,
2942        channel_members
2943            .iter()
2944            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2945        |peer_id| {
2946            session.peer.send(
2947                peer_id.into(),
2948                proto::UpdateChannels {
2949                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2950                        channel_id: channel_id.to_proto(),
2951                        message_id: message_id.to_proto(),
2952                    }],
2953                    ..Default::default()
2954                },
2955            )
2956        },
2957    );
2958    send_notifications(pool, &session.peer, notifications);
2959
2960    Ok(())
2961}
2962
2963async fn remove_channel_message(
2964    request: proto::RemoveChannelMessage,
2965    response: Response<proto::RemoveChannelMessage>,
2966    session: Session,
2967) -> Result<()> {
2968    let channel_id = ChannelId::from_proto(request.channel_id);
2969    let message_id = MessageId::from_proto(request.message_id);
2970    let connection_ids = session
2971        .db()
2972        .await
2973        .remove_channel_message(channel_id, message_id, session.user_id)
2974        .await?;
2975    broadcast(Some(session.connection_id), connection_ids, |connection| {
2976        session.peer.send(connection, request.clone())
2977    });
2978    response.send(proto::Ack {})?;
2979    Ok(())
2980}
2981
2982async fn acknowledge_channel_message(
2983    request: proto::AckChannelMessage,
2984    session: Session,
2985) -> Result<()> {
2986    let channel_id = ChannelId::from_proto(request.channel_id);
2987    let message_id = MessageId::from_proto(request.message_id);
2988    let notifications = session
2989        .db()
2990        .await
2991        .observe_channel_message(channel_id, session.user_id, message_id)
2992        .await?;
2993    send_notifications(
2994        &*session.connection_pool().await,
2995        &session.peer,
2996        notifications,
2997    );
2998    Ok(())
2999}
3000
3001async fn acknowledge_buffer_version(
3002    request: proto::AckBufferOperation,
3003    session: Session,
3004) -> Result<()> {
3005    let buffer_id = BufferId::from_proto(request.buffer_id);
3006    session
3007        .db()
3008        .await
3009        .observe_buffer_version(
3010            buffer_id,
3011            session.user_id,
3012            request.epoch as i32,
3013            &request.version,
3014        )
3015        .await?;
3016    Ok(())
3017}
3018
3019async fn join_channel_chat(
3020    request: proto::JoinChannelChat,
3021    response: Response<proto::JoinChannelChat>,
3022    session: Session,
3023) -> Result<()> {
3024    let channel_id = ChannelId::from_proto(request.channel_id);
3025
3026    let db = session.db().await;
3027    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3028        .await?;
3029    let messages = db
3030        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3031        .await?;
3032    response.send(proto::JoinChannelChatResponse {
3033        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3034        messages,
3035    })?;
3036    Ok(())
3037}
3038
3039async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3040    let channel_id = ChannelId::from_proto(request.channel_id);
3041    session
3042        .db()
3043        .await
3044        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3045        .await?;
3046    Ok(())
3047}
3048
3049async fn get_channel_messages(
3050    request: proto::GetChannelMessages,
3051    response: Response<proto::GetChannelMessages>,
3052    session: Session,
3053) -> Result<()> {
3054    let channel_id = ChannelId::from_proto(request.channel_id);
3055    let messages = session
3056        .db()
3057        .await
3058        .get_channel_messages(
3059            channel_id,
3060            session.user_id,
3061            MESSAGE_COUNT_PER_PAGE,
3062            Some(MessageId::from_proto(request.before_message_id)),
3063        )
3064        .await?;
3065    response.send(proto::GetChannelMessagesResponse {
3066        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3067        messages,
3068    })?;
3069    Ok(())
3070}
3071
3072async fn get_channel_messages_by_id(
3073    request: proto::GetChannelMessagesById,
3074    response: Response<proto::GetChannelMessagesById>,
3075    session: Session,
3076) -> Result<()> {
3077    let message_ids = request
3078        .message_ids
3079        .iter()
3080        .map(|id| MessageId::from_proto(*id))
3081        .collect::<Vec<_>>();
3082    let messages = session
3083        .db()
3084        .await
3085        .get_channel_messages_by_id(session.user_id, &message_ids)
3086        .await?;
3087    response.send(proto::GetChannelMessagesResponse {
3088        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3089        messages,
3090    })?;
3091    Ok(())
3092}
3093
3094async fn get_notifications(
3095    request: proto::GetNotifications,
3096    response: Response<proto::GetNotifications>,
3097    session: Session,
3098) -> Result<()> {
3099    let notifications = session
3100        .db()
3101        .await
3102        .get_notifications(
3103            session.user_id,
3104            NOTIFICATION_COUNT_PER_PAGE,
3105            request
3106                .before_id
3107                .map(|id| db::NotificationId::from_proto(id)),
3108        )
3109        .await?;
3110    response.send(proto::GetNotificationsResponse {
3111        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3112        notifications,
3113    })?;
3114    Ok(())
3115}
3116
3117async fn mark_notification_as_read(
3118    request: proto::MarkNotificationRead,
3119    response: Response<proto::MarkNotificationRead>,
3120    session: Session,
3121) -> Result<()> {
3122    let database = &session.db().await;
3123    let notifications = database
3124        .mark_notification_as_read_by_id(
3125            session.user_id,
3126            NotificationId::from_proto(request.notification_id),
3127        )
3128        .await?;
3129    send_notifications(
3130        &*session.connection_pool().await,
3131        &session.peer,
3132        notifications,
3133    );
3134    response.send(proto::Ack {})?;
3135    Ok(())
3136}
3137
3138async fn get_private_user_info(
3139    _request: proto::GetPrivateUserInfo,
3140    response: Response<proto::GetPrivateUserInfo>,
3141    session: Session,
3142) -> Result<()> {
3143    let db = session.db().await;
3144
3145    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3146    let user = db
3147        .get_user_by_id(session.user_id)
3148        .await?
3149        .ok_or_else(|| anyhow!("user not found"))?;
3150    let flags = db.get_user_flags(session.user_id).await?;
3151
3152    response.send(proto::GetPrivateUserInfoResponse {
3153        metrics_id,
3154        staff: user.admin,
3155        flags,
3156    })?;
3157    Ok(())
3158}
3159
3160fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3161    match message {
3162        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3163        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3164        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3165        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3166        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3167            code: frame.code.into(),
3168            reason: frame.reason,
3169        })),
3170    }
3171}
3172
3173fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3174    match message {
3175        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3176        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3177        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3178        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3179        AxumMessage::Close(frame) => {
3180            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3181                code: frame.code.into(),
3182                reason: frame.reason,
3183            }))
3184        }
3185    }
3186}
3187
3188fn notify_membership_updated(
3189    connection_pool: &ConnectionPool,
3190    result: MembershipUpdated,
3191    user_id: UserId,
3192    peer: &Peer,
3193) {
3194    let mut update = build_channels_update(result.new_channels, vec![]);
3195    update.delete_channels = result
3196        .removed_channels
3197        .into_iter()
3198        .map(|id| id.to_proto())
3199        .collect();
3200    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3201
3202    for connection_id in connection_pool.user_connection_ids(user_id) {
3203        peer.send(connection_id, update.clone()).trace_err();
3204    }
3205}
3206
3207fn build_channels_update(
3208    channels: ChannelsForUser,
3209    channel_invites: Vec<db::Channel>,
3210) -> proto::UpdateChannels {
3211    let mut update = proto::UpdateChannels::default();
3212
3213    for channel in channels.channels {
3214        update.channels.push(channel.to_proto());
3215    }
3216
3217    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3218    update.unseen_channel_messages = channels.channel_messages;
3219
3220    for (channel_id, participants) in channels.channel_participants {
3221        update
3222            .channel_participants
3223            .push(proto::ChannelParticipants {
3224                channel_id: channel_id.to_proto(),
3225                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3226            });
3227    }
3228
3229    for channel in channel_invites {
3230        update.channel_invitations.push(channel.to_proto());
3231    }
3232
3233    update
3234}
3235
3236fn build_initial_contacts_update(
3237    contacts: Vec<db::Contact>,
3238    pool: &ConnectionPool,
3239) -> proto::UpdateContacts {
3240    let mut update = proto::UpdateContacts::default();
3241
3242    for contact in contacts {
3243        match contact {
3244            db::Contact::Accepted { user_id, busy } => {
3245                update.contacts.push(contact_for_user(user_id, busy, &pool));
3246            }
3247            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3248            db::Contact::Incoming { user_id } => {
3249                update
3250                    .incoming_requests
3251                    .push(proto::IncomingContactRequest {
3252                        requester_id: user_id.to_proto(),
3253                    })
3254            }
3255        }
3256    }
3257
3258    update
3259}
3260
3261fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3262    proto::Contact {
3263        user_id: user_id.to_proto(),
3264        online: pool.is_user_online(user_id),
3265        busy,
3266    }
3267}
3268
3269fn room_updated(room: &proto::Room, peer: &Peer) {
3270    broadcast(
3271        None,
3272        room.participants
3273            .iter()
3274            .filter_map(|participant| Some(participant.peer_id?.into())),
3275        |peer_id| {
3276            peer.send(
3277                peer_id.into(),
3278                proto::RoomUpdated {
3279                    room: Some(room.clone()),
3280                },
3281            )
3282        },
3283    );
3284}
3285
3286fn channel_updated(
3287    channel_id: ChannelId,
3288    room: &proto::Room,
3289    channel_members: &[UserId],
3290    peer: &Peer,
3291    pool: &ConnectionPool,
3292) {
3293    let participants = room
3294        .participants
3295        .iter()
3296        .map(|p| p.user_id)
3297        .collect::<Vec<_>>();
3298
3299    broadcast(
3300        None,
3301        channel_members
3302            .iter()
3303            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3304        |peer_id| {
3305            peer.send(
3306                peer_id.into(),
3307                proto::UpdateChannels {
3308                    channel_participants: vec![proto::ChannelParticipants {
3309                        channel_id: channel_id.to_proto(),
3310                        participant_user_ids: participants.clone(),
3311                    }],
3312                    ..Default::default()
3313                },
3314            )
3315        },
3316    );
3317}
3318
3319async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3320    let db = session.db().await;
3321
3322    let contacts = db.get_contacts(user_id).await?;
3323    let busy = db.is_user_busy(user_id).await?;
3324
3325    let pool = session.connection_pool().await;
3326    let updated_contact = contact_for_user(user_id, busy, &pool);
3327    for contact in contacts {
3328        if let db::Contact::Accepted {
3329            user_id: contact_user_id,
3330            ..
3331        } = contact
3332        {
3333            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3334                session
3335                    .peer
3336                    .send(
3337                        contact_conn_id,
3338                        proto::UpdateContacts {
3339                            contacts: vec![updated_contact.clone()],
3340                            remove_contacts: Default::default(),
3341                            incoming_requests: Default::default(),
3342                            remove_incoming_requests: Default::default(),
3343                            outgoing_requests: Default::default(),
3344                            remove_outgoing_requests: Default::default(),
3345                        },
3346                    )
3347                    .trace_err();
3348            }
3349        }
3350    }
3351    Ok(())
3352}
3353
/// Removes the session's connection from whatever room it is in, and fans out
/// the consequences.
///
/// Returns `Ok(())` immediately when the database reports no room was left.
/// Otherwise, in order:
/// - notifies project peers via `project_left` for each project left, and
///   broadcasts the updated room;
/// - if the room belonged to a channel, broadcasts that channel's participants;
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - refreshes contacts for this user and for each of those users;
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room, and deletes the LiveKit room when `left_room.deleted` is set.
///
/// Send and LiveKit failures are logged via `trace_err` and do not abort.
/// Database and contact-refresh errors are propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned only inside the `if let` below; the
    // `else` branch returns early, so all of them are initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the database handle produced in this `if let` scrutinee is
    // a temporary, so under Rust's temporary-scope rules it stays alive for
    // the whole `if let` body. Keep that in mind before restructuring this
    // into `let ... else`.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // `mem::take` moves fields out of `left_room`, leaving defaults behind.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` below, so `room` (and the `RoomUpdated`
        // broadcast) carries an empty/default `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3430
3431async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3432    let left_channel_buffers = session
3433        .db()
3434        .await
3435        .leave_channel_buffers(session.connection_id)
3436        .await?;
3437
3438    for left_buffer in left_channel_buffers {
3439        channel_buffer_updated(
3440            session.connection_id,
3441            left_buffer.connections,
3442            &proto::UpdateChannelBufferCollaborators {
3443                channel_id: left_buffer.channel_id.to_proto(),
3444                collaborators: left_buffer.collaborators,
3445            },
3446            &session.peer,
3447        );
3448    }
3449
3450    Ok(())
3451}
3452
3453fn project_left(project: &db::LeftProject, session: &Session) {
3454    for connection_id in &project.connection_ids {
3455        if project.host_user_id == session.user_id {
3456            session
3457                .peer
3458                .send(
3459                    *connection_id,
3460                    proto::UnshareProject {
3461                        project_id: project.id.to_proto(),
3462                    },
3463                )
3464                .trace_err();
3465        } else {
3466            session
3467                .peer
3468                .send(
3469                    *connection_id,
3470                    proto::RemoveProjectCollaborator {
3471                        project_id: project.id.to_proto(),
3472                        peer_id: Some(session.connection_id.into()),
3473                    },
3474                )
3475                .trace_err();
3476        }
3477    }
3478}
3479
/// Extension for turning a `Result` into an `Option` while recording the
/// error instead of propagating it.
///
/// The blanket implementation for `Result<T, E>` in this file logs the error's
/// `Debug` form with `tracing::error!`.
pub trait ResultExt {
    /// The success type of the extended value.
    type Ok;

    /// Returns the success value as `Some`; on error, records it and returns
    /// `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3485
3486impl<T, E> ResultExt for Result<T, E>
3487where
3488    E: std::fmt::Debug,
3489{
3490    type Ok = T;
3491
3492    fn trace_err(self) -> Option<T> {
3493        match self {
3494            Ok(value) => Some(value),
3495            Err(error) => {
3496                tracing::error!("{:?}", error);
3497                None
3498            }
3499        }
3500    }
3501}