rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69
/// NOTE(review): not referenced in the visible part of this file; presumably how long a
/// disconnected client may take to reconnect before its state is released — confirm at the
/// use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the background task spawned by [`Server::start`] sleeps before it looks up and
/// clears stale server resources (rooms and channel-buffer collaborators) and deletes stale
/// server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

/// NOTE(review): used outside the visible part of this file; presumably the page size for
/// channel-chat message history — confirm at the use site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// NOTE(review): presumably the maximum accepted length of a chat message body (bytes vs.
/// chars is not visible here) — confirm at the use site.
const MAX_MESSAGE_LEN: usize = 1024;
/// NOTE(review): presumably the page size for notification listings — confirm at the use site.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
// Process-wide Prometheus gauges. Each is created and registered (via `register_int_gauge!`)
// lazily on first access; the `unwrap`s panic at that point if registration fails
// (presumably e.g. a metric with the same name was already registered).
lazy_static! {
    /// Gauge `connections`: "number of connections". Updated outside the visible part of
    /// this file.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge `shared_projects`: "number of open projects with one or more guests". Updated
    /// outside the visible part of this file.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
/// Per-connection state. `handle_connection` builds one per client connection and hands a
/// clone of it to every message handler it dispatches.
#[derive(Clone)]
struct Session {
    /// Deployment environment name, cloned from the app config's `zed_environment`.
    zed_environment: Arc<str>,
    /// User that opened this connection.
    user_id: UserId,
    /// Peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex that is created once per connection. Every clone
    /// of the session shares it, so handlers of the same connection take turns; acquire it
    /// through [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer.
    peer: Arc<Peer>,
    /// The server-wide connection pool (the same `Arc` held by [`Server`]); lock it through
    /// [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client cloned from the app state; `None` when the app state has none.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor the connection is handled on; not read within the visible part of this file.
    _executor: Executor,
}
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
/// The collaboration RPC server: owns the peer, the registered message handlers and the
/// connection pool shared with every [`Session`].
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to talk to every client connection.
    peer: Arc<Peer>,
    /// Live connections per user; the same `Arc` is cloned into each [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, optional LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and for spawning detached tasks.
    executor: Executor,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each `handle_connection` loop subscribes and returns as soon
    /// as the value changes.
    teardown: watch::Sender<()>,
}
 165
/// Lock guard for the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Any future holding it across an
/// `.await` therefore fails the `Send` bound on handler futures, which presumably keeps the
/// synchronous lock from being held while a task is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying synchronous `parking_lot` lock guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Marker that makes this guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 170
/// Serializable view of the server: the peer plus the connection pool.
///
/// The pool is serialized through its guard via [`serialize_deref`]. Because the snapshot
/// owns a [`ConnectionPoolGuard`], the pool stays locked for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
    /// Creates a server with the given `id` and registers a handler for every RPC message it
    /// understands, returning it wrapped in an [`Arc`].
    ///
    /// The peer is created with `id.0 as u32` as its id. Requests, which must be answered, are
    /// registered with `add_request_handler`; one-way messages with `add_message_handler`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections subscribe to it in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            // Liveness, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project/worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed by `forward_read_only_project_request`.
            // NOTE(review): presumably forwarded to the project host and allowed for
            // read-only guests — confirm in the helper.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            // Project requests relayed by `forward_mutating_project_request`.
            // NOTE(review): presumably restricted to participants allowed to edit — confirm
            // in the helper.
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffer creation/updates and host-originated buffer broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            // Channel hierarchy.
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Private user info and read acknowledgements.
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 295
    /// Spawns a detached background task that cleans up stale server resources, then returns
    /// `Ok(())` immediately.
    ///
    /// The task sleeps for [`CLEANUP_TIMEOUT`] on this server's executor. It then asks the
    /// database for the room and channel ids reported as stale for this environment and
    /// `server_id` (presumably resources left behind by earlier server instances — confirm in
    /// `stale_server_resource_ids`), and handles them as follows:
    ///
    /// - For each stale channel buffer, stale collaborators are cleared and the refreshed
    ///   collaborator list is sent to the buffer's remaining connections.
    /// - For each stale room:
    ///   - stale participants are cleared;
    ///   - the room (and its channel, if any) is re-broadcast;
    ///   - users whose calls were canceled get `CallCanceled` on each of their connections;
    ///   - the contacts of affected users get an updated contact entry;
    ///   - the LiveKit room is deleted once the room has no participants left.
    ///
    /// Finally, stale server records are deleted. That step runs even when the initial lookup
    /// failed. Every failure inside the task is logged via `trace_err` and skipped.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room's stale participants succeeds;
                        // otherwise the steps below are no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the synchronous pool lock is released before the
                            // `.await`s below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups are awaited before the `parking_lot` pool lock is
                            // taken; that guard must not be held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are told about the change.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when no participants remain.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 447
 448    pub fn teardown(&self) {
 449        self.peer.teardown();
 450        self.connection_pool.lock().reset();
 451        let _ = self.teardown.send(());
 452    }
 453
 454    #[cfg(test)]
 455    pub fn reset(&self, id: ServerId) {
 456        self.teardown();
 457        *self.id.lock() = id;
 458        self.peer.reset(id.0 as u32);
 459    }
 460
 461    #[cfg(test)]
 462    pub fn id(&self) -> ServerId {
 463        *self.id.lock()
 464    }
 465
 466    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 467    where
 468        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 469        Fut: 'static + Send + Future<Output = Result<()>>,
 470        M: EnvelopedMessage,
 471    {
 472        let prev_handler = self.handlers.insert(
 473            TypeId::of::<M>(),
 474            Box::new(move |envelope, session| {
 475                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 476                let span = info_span!(
 477                    "handle message",
 478                    payload_type = envelope.payload_type_name()
 479                );
 480                span.in_scope(|| {
 481                    tracing::info!(
 482                        payload_type = envelope.payload_type_name(),
 483                        "message received"
 484                    );
 485                });
 486                let start_time = Instant::now();
 487                let future = (handler)(*envelope, session);
 488                async move {
 489                    let result = future.await;
 490                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 491                    match result {
 492                        Err(error) => {
 493                            tracing::error!(%error, ?duration_ms, "error handling message")
 494                        }
 495                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 496                    }
 497                }
 498                .instrument(span)
 499                .boxed()
 500            }),
 501        );
 502        if prev_handler.is_some() {
 503            panic!("registered a handler for the same message twice");
 504        }
 505        self
 506    }
 507
 508    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 511        Fut: 'static + Send + Future<Output = Result<()>>,
 512        M: EnvelopedMessage,
 513    {
 514        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 515        self
 516    }
 517
 518    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 519    where
 520        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 521        Fut: Send + Future<Output = Result<()>>,
 522        M: RequestMessage,
 523    {
 524        let handler = Arc::new(handler);
 525        self.add_handler(move |envelope, session| {
 526            let receipt = envelope.receipt();
 527            let handler = handler.clone();
 528            async move {
 529                let peer = session.peer.clone();
 530                let responded = Arc::new(AtomicBool::default());
 531                let response = Response {
 532                    peer: peer.clone(),
 533                    responded: responded.clone(),
 534                    receipt,
 535                };
 536                match (handler)(envelope.payload, response, session).await {
 537                    Ok(()) => {
 538                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 539                            Ok(())
 540                        } else {
 541                            Err(anyhow!("handler did not send a response"))?
 542                        }
 543                    }
 544                    Err(error) => {
 545                        peer.respond_with_error(
 546                            receipt,
 547                            proto::Error {
 548                                message: error.to_string(),
 549                            },
 550                        )?;
 551                        Err(error)
 552                    }
 553                }
 554            }
 555        })
 556    }
 557
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Setup, in order:
    ///
    /// - register `connection` with the peer and send `Hello`;
    /// - report the new connection id through `send_connection_id`, if given;
    /// - on the user's first ever connection, send `ShowContacts` and record that in the
    ///   database;
    /// - add the connection to the pool and send the initial contacts and channels state,
    ///   plus any pending incoming call.
    ///
    /// After setup, incoming messages are dispatched to the registered handlers until the
    /// connection closes, I/O fails, or the server is torn down. After a close or an I/O
    /// failure, `connection_lost` runs for the session.
    ///
    /// `address` is only used in log/trace fields. `executor` supplies the sleep function
    /// passed to `Peer::add_connection` and runs background message handlers.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future runs, so a teardown signalled after this call is seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error for this connection.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // NOTE(review): from here until the dispatch loop, any `?` early return leaves this
            // connection registered in the pool without running `connection_lost` — confirm
            // whether cleanup happens elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each in-flight message holds a permit from the moment it is read until its handler
            // finishes, so at most 256 messages of this connection are in flight at once.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The permit is acquired *before* reading, so reading stalls while the limit is
                // reached.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, I/O, progress
                // on in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Returns without `connection_lost`; presumably intentional because
                    // `Server::teardown` already resets the connection pool.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span before any `.await`; the handler future is
                                // instrumented with it below instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached, outside this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 693
 694    pub async fn invite_code_redeemed(
 695        self: &Arc<Self>,
 696        inviter_id: UserId,
 697        invitee_id: UserId,
 698    ) -> Result<()> {
 699        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 700            if let Some(code) = &user.invite_code {
 701                let pool = self.connection_pool.lock();
 702                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 703                for connection_id in pool.user_connection_ids(inviter_id) {
 704                    self.peer.send(
 705                        connection_id,
 706                        proto::UpdateContacts {
 707                            contacts: vec![invitee_contact.clone()],
 708                            ..Default::default()
 709                        },
 710                    )?;
 711                    self.peer.send(
 712                        connection_id,
 713                        proto::UpdateInviteInfo {
 714                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 715                            count: user.invite_count as u32,
 716                        },
 717                    )?;
 718                }
 719            }
 720        }
 721        Ok(())
 722    }
 723
 724    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 725        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 726            if let Some(invite_code) = &user.invite_code {
 727                let pool = self.connection_pool.lock();
 728                for connection_id in pool.user_connection_ids(user_id) {
 729                    self.peer.send(
 730                        connection_id,
 731                        proto::UpdateInviteInfo {
 732                            url: format!(
 733                                "{}{}",
 734                                self.app_state.config.invite_link_prefix, invite_code
 735                            ),
 736                            count: user.invite_count as u32,
 737                        },
 738                    )?;
 739                }
 740            }
 741        }
 742        Ok(())
 743    }
 744
 745    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 746        ServerSnapshot {
 747            connection_pool: ConnectionPoolGuard {
 748                guard: self.connection_pool.lock(),
 749                _not_send: PhantomData,
 750            },
 751            peer: &self.peer,
 752        }
 753    }
 754}
 755
 756impl<'a> Deref for ConnectionPoolGuard<'a> {
 757    type Target = ConnectionPool;
 758
 759    fn deref(&self) -> &Self::Target {
 760        &*self.guard
 761    }
 762}
 763
 764impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 765    fn deref_mut(&mut self) -> &mut Self::Target {
 766        &mut *self.guard
 767    }
 768}
 769
/// In test builds, calls `check_invariants` when the guard is dropped. `Drop::drop`
/// runs before the `guard` field itself is dropped, so the check happens while the
/// pool is still locked. In non-test builds the body is empty.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` is reached through `Deref`; presumably it is
        // `ConnectionPool::check_invariants` (defined elsewhere) — confirm.
        #[cfg(test)]
        self.check_invariants();
    }
}
 776
 777fn broadcast<F>(
 778    sender_id: Option<ConnectionId>,
 779    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 780    mut f: F,
 781) where
 782    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 783{
 784    for receiver_id in receiver_ids {
 785        if Some(receiver_id) != sender_id {
 786            if let Err(error) = f(receiver_id) {
 787                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 788            }
 789        }
 790    }
 791}
 792
// Name of the request header through which clients announce their RPC protocol
// version; `ProtocolVersion::name` returns a reference to it.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 796
/// Typed `x-zed-protocol-version` header value, extracted from `/rpc` websocket requests
/// and compared against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
 798
 799impl Header for ProtocolVersion {
 800    fn name() -> &'static HeaderName {
 801        &ZED_PROTOCOL_VERSION
 802    }
 803
 804    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 805    where
 806        Self: Sized,
 807        I: Iterator<Item = &'i axum::http::HeaderValue>,
 808    {
 809        let version = values
 810            .next()
 811            .ok_or_else(axum::headers::Error::invalid)?
 812            .to_str()
 813            .map_err(|_| axum::headers::Error::invalid())?
 814            .parse()
 815            .map_err(|_| axum::headers::Error::invalid())?;
 816        Ok(Self(version))
 817    }
 818
 819    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 820        values.extend([self.0.to_string().parse().unwrap()]);
 821    }
 822}
 823
 824pub fn routes(server: Arc<Server>) -> Router<Body> {
 825    Router::new()
 826        .route("/rpc", get(handle_websocket_request))
 827        .layer(
 828            ServiceBuilder::new()
 829                .layer(Extension(server.app_state.clone()))
 830                .layer(middleware::from_fn(auth::validate_header)),
 831        )
 832        .route("/metrics", get(handle_metrics))
 833        .layer(Extension(server))
 834}
 835
 836pub async fn handle_websocket_request(
 837    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 838    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 839    Extension(server): Extension<Arc<Server>>,
 840    Extension(user): Extension<User>,
 841    ws: WebSocketUpgrade,
 842) -> axum::response::Response {
 843    if protocol_version != rpc::PROTOCOL_VERSION {
 844        return (
 845            StatusCode::UPGRADE_REQUIRED,
 846            "client must be upgraded".to_string(),
 847        )
 848            .into_response();
 849    }
 850    let socket_address = socket_address.to_string();
 851    ws.on_upgrade(move |socket| {
 852        use util::ResultExt;
 853        let socket = socket
 854            .map_ok(to_tungstenite_message)
 855            .err_into()
 856            .with(|message| async move { Ok(to_axum_message(message)) });
 857        let connection = Connection::new(Box::pin(socket));
 858        async move {
 859            server
 860                .handle_connection(connection, socket_address, user, None, Executor::Production)
 861                .await
 862                .log_err();
 863        }
 864    })
 865}
 866
 867pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 868    let connections = server
 869        .connection_pool
 870        .lock()
 871        .connections()
 872        .filter(|connection| !connection.admin)
 873        .count();
 874
 875    METRIC_CONNECTIONS.set(connections as _);
 876
 877    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 878    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 879
 880    let encoder = prometheus::TextEncoder::new();
 881    let metric_families = prometheus::gather();
 882    let encoded_metrics = encoder
 883        .encode_to_string(&metric_families)
 884        .map_err(|err| anyhow!("{}", err))?;
 885    Ok(encoded_metrics)
 886}
 887
/// Cleans up server and database state after a connection has gone away.
///
/// First, it does three things:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - reports the loss to the database. A failure here is only logged.
///
/// It then races a `RECONNECT_TIMEOUT` sleep against `teardown`:
/// - If the timeout fires first, it releases the user's room and channel-buffer
///   participation. If the user has no other online connection, it declines any
///   call for the user and notifies that room. Finally, it refreshes the user's
///   contacts.
/// - If `teardown` changes first, it returns without that cleanup.
///
/// # Errors
/// Propagates errors from `remove_connection` and `update_user_contacts`. Failures
/// of the other cleanup steps are logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in declaration order, so the timeout branch
    // wins when both are ready at once.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard is a temporary of the condition and is released before the
            // body runs; only decline calls when no other connection of the user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 933
 934async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 935    response.send(proto::Ack {})?;
 936    Ok(())
 937}
 938
 939async fn create_room(
 940    _request: proto::CreateRoom,
 941    response: Response<proto::CreateRoom>,
 942    session: Session,
 943) -> Result<()> {
 944    let live_kit_room = nanoid::nanoid!(30);
 945
 946    let live_kit_connection_info = {
 947        let live_kit_room = live_kit_room.clone();
 948        let live_kit = session.live_kit_client.as_ref();
 949
 950        util::async_maybe!({
 951            let live_kit = live_kit?;
 952
 953            let token = live_kit
 954                .room_token(&live_kit_room, &session.user_id.to_string())
 955                .trace_err()?;
 956
 957            Some(proto::LiveKitConnectionInfo {
 958                server_url: live_kit.url().into(),
 959                token,
 960                can_publish: true,
 961            })
 962        })
 963    }
 964    .await;
 965
 966    let room = session
 967        .db()
 968        .await
 969        .create_room(
 970            session.user_id,
 971            session.connection_id,
 972            &live_kit_room,
 973            &session.zed_environment,
 974        )
 975        .await?;
 976
 977    response.send(proto::CreateRoomResponse {
 978        room: Some(room.clone()),
 979        live_kit_connection_info,
 980    })?;
 981
 982    update_user_contacts(session.user_id, &session).await?;
 983    Ok(())
 984}
 985
 986async fn join_room(
 987    request: proto::JoinRoom,
 988    response: Response<proto::JoinRoom>,
 989    session: Session,
 990) -> Result<()> {
 991    let room_id = RoomId::from_proto(request.id);
 992
 993    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 994
 995    if let Some(channel_id) = channel_id {
 996        return join_channel_internal(channel_id, Box::new(response), session).await;
 997    }
 998
 999    let joined_room = {
1000        let room = session
1001            .db()
1002            .await
1003            .join_room(
1004                room_id,
1005                session.user_id,
1006                session.connection_id,
1007                session.zed_environment.as_ref(),
1008            )
1009            .await?;
1010        room_updated(&room.room, &session.peer);
1011        room.into_inner()
1012    };
1013
1014    for connection_id in session
1015        .connection_pool()
1016        .await
1017        .user_connection_ids(session.user_id)
1018    {
1019        session
1020            .peer
1021            .send(
1022                connection_id,
1023                proto::CallCanceled {
1024                    room_id: room_id.to_proto(),
1025                },
1026            )
1027            .trace_err();
1028    }
1029
1030    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1031        if let Some(token) = live_kit
1032            .room_token(
1033                &joined_room.room.live_kit_room,
1034                &session.user_id.to_string(),
1035            )
1036            .trace_err()
1037        {
1038            Some(proto::LiveKitConnectionInfo {
1039                server_url: live_kit.url().into(),
1040                token,
1041                can_publish: true,
1042            })
1043        } else {
1044            None
1045        }
1046    } else {
1047        None
1048    };
1049
1050    response.send(proto::JoinRoomResponse {
1051        room: Some(joined_room.room),
1052        channel_id: None,
1053        live_kit_connection_info,
1054    })?;
1055
1056    update_user_contacts(session.user_id, &session).await?;
1057    Ok(())
1058}
1059
1060async fn rejoin_room(
1061    request: proto::RejoinRoom,
1062    response: Response<proto::RejoinRoom>,
1063    session: Session,
1064) -> Result<()> {
1065    let room;
1066    let channel_id;
1067    let channel_members;
1068    {
1069        let mut rejoined_room = session
1070            .db()
1071            .await
1072            .rejoin_room(request, session.user_id, session.connection_id)
1073            .await?;
1074
1075        response.send(proto::RejoinRoomResponse {
1076            room: Some(rejoined_room.room.clone()),
1077            reshared_projects: rejoined_room
1078                .reshared_projects
1079                .iter()
1080                .map(|project| proto::ResharedProject {
1081                    id: project.id.to_proto(),
1082                    collaborators: project
1083                        .collaborators
1084                        .iter()
1085                        .map(|collaborator| collaborator.to_proto())
1086                        .collect(),
1087                })
1088                .collect(),
1089            rejoined_projects: rejoined_room
1090                .rejoined_projects
1091                .iter()
1092                .map(|rejoined_project| proto::RejoinedProject {
1093                    id: rejoined_project.id.to_proto(),
1094                    worktrees: rejoined_project
1095                        .worktrees
1096                        .iter()
1097                        .map(|worktree| proto::WorktreeMetadata {
1098                            id: worktree.id,
1099                            root_name: worktree.root_name.clone(),
1100                            visible: worktree.visible,
1101                            abs_path: worktree.abs_path.clone(),
1102                        })
1103                        .collect(),
1104                    collaborators: rejoined_project
1105                        .collaborators
1106                        .iter()
1107                        .map(|collaborator| collaborator.to_proto())
1108                        .collect(),
1109                    language_servers: rejoined_project.language_servers.clone(),
1110                })
1111                .collect(),
1112        })?;
1113        room_updated(&rejoined_room.room, &session.peer);
1114
1115        for project in &rejoined_room.reshared_projects {
1116            for collaborator in &project.collaborators {
1117                session
1118                    .peer
1119                    .send(
1120                        collaborator.connection_id,
1121                        proto::UpdateProjectCollaborator {
1122                            project_id: project.id.to_proto(),
1123                            old_peer_id: Some(project.old_connection_id.into()),
1124                            new_peer_id: Some(session.connection_id.into()),
1125                        },
1126                    )
1127                    .trace_err();
1128            }
1129
1130            broadcast(
1131                Some(session.connection_id),
1132                project
1133                    .collaborators
1134                    .iter()
1135                    .map(|collaborator| collaborator.connection_id),
1136                |connection_id| {
1137                    session.peer.forward_send(
1138                        session.connection_id,
1139                        connection_id,
1140                        proto::UpdateProject {
1141                            project_id: project.id.to_proto(),
1142                            worktrees: project.worktrees.clone(),
1143                        },
1144                    )
1145                },
1146            );
1147        }
1148
1149        for project in &rejoined_room.rejoined_projects {
1150            for collaborator in &project.collaborators {
1151                session
1152                    .peer
1153                    .send(
1154                        collaborator.connection_id,
1155                        proto::UpdateProjectCollaborator {
1156                            project_id: project.id.to_proto(),
1157                            old_peer_id: Some(project.old_connection_id.into()),
1158                            new_peer_id: Some(session.connection_id.into()),
1159                        },
1160                    )
1161                    .trace_err();
1162            }
1163        }
1164
1165        for project in &mut rejoined_room.rejoined_projects {
1166            for worktree in mem::take(&mut project.worktrees) {
1167                #[cfg(any(test, feature = "test-support"))]
1168                const MAX_CHUNK_SIZE: usize = 2;
1169                #[cfg(not(any(test, feature = "test-support")))]
1170                const MAX_CHUNK_SIZE: usize = 256;
1171
1172                // Stream this worktree's entries.
1173                let message = proto::UpdateWorktree {
1174                    project_id: project.id.to_proto(),
1175                    worktree_id: worktree.id,
1176                    abs_path: worktree.abs_path.clone(),
1177                    root_name: worktree.root_name,
1178                    updated_entries: worktree.updated_entries,
1179                    removed_entries: worktree.removed_entries,
1180                    scan_id: worktree.scan_id,
1181                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1182                    updated_repositories: worktree.updated_repositories,
1183                    removed_repositories: worktree.removed_repositories,
1184                };
1185                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1186                    session.peer.send(session.connection_id, update.clone())?;
1187                }
1188
1189                // Stream this worktree's diagnostics.
1190                for summary in worktree.diagnostic_summaries {
1191                    session.peer.send(
1192                        session.connection_id,
1193                        proto::UpdateDiagnosticSummary {
1194                            project_id: project.id.to_proto(),
1195                            worktree_id: worktree.id,
1196                            summary: Some(summary),
1197                        },
1198                    )?;
1199                }
1200
1201                for settings_file in worktree.settings_files {
1202                    session.peer.send(
1203                        session.connection_id,
1204                        proto::UpdateWorktreeSettings {
1205                            project_id: project.id.to_proto(),
1206                            worktree_id: worktree.id,
1207                            path: settings_file.path,
1208                            content: Some(settings_file.content),
1209                        },
1210                    )?;
1211                }
1212            }
1213
1214            for language_server in &project.language_servers {
1215                session.peer.send(
1216                    session.connection_id,
1217                    proto::UpdateLanguageServer {
1218                        project_id: project.id.to_proto(),
1219                        language_server_id: language_server.id,
1220                        variant: Some(
1221                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1222                                proto::LspDiskBasedDiagnosticsUpdated {},
1223                            ),
1224                        ),
1225                    },
1226                )?;
1227            }
1228        }
1229
1230        let rejoined_room = rejoined_room.into_inner();
1231
1232        room = rejoined_room.room;
1233        channel_id = rejoined_room.channel_id;
1234        channel_members = rejoined_room.channel_members;
1235    }
1236
1237    if let Some(channel_id) = channel_id {
1238        channel_updated(
1239            channel_id,
1240            &room,
1241            &channel_members,
1242            &session.peer,
1243            &*session.connection_pool().await,
1244        );
1245    }
1246
1247    update_user_contacts(session.user_id, &session).await?;
1248    Ok(())
1249}
1250
1251async fn leave_room(
1252    _: proto::LeaveRoom,
1253    response: Response<proto::LeaveRoom>,
1254    session: Session,
1255) -> Result<()> {
1256    leave_room_for_session(&session).await?;
1257    response.send(proto::Ack {})?;
1258    Ok(())
1259}
1260
1261async fn call(
1262    request: proto::Call,
1263    response: Response<proto::Call>,
1264    session: Session,
1265) -> Result<()> {
1266    let room_id = RoomId::from_proto(request.room_id);
1267    let calling_user_id = session.user_id;
1268    let calling_connection_id = session.connection_id;
1269    let called_user_id = UserId::from_proto(request.called_user_id);
1270    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1271    if !session
1272        .db()
1273        .await
1274        .has_contact(calling_user_id, called_user_id)
1275        .await?
1276    {
1277        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1278    }
1279
1280    let incoming_call = {
1281        let (room, incoming_call) = &mut *session
1282            .db()
1283            .await
1284            .call(
1285                room_id,
1286                calling_user_id,
1287                calling_connection_id,
1288                called_user_id,
1289                initial_project_id,
1290            )
1291            .await?;
1292        room_updated(&room, &session.peer);
1293        mem::take(incoming_call)
1294    };
1295    update_user_contacts(called_user_id, &session).await?;
1296
1297    let mut calls = session
1298        .connection_pool()
1299        .await
1300        .user_connection_ids(called_user_id)
1301        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1302        .collect::<FuturesUnordered<_>>();
1303
1304    while let Some(call_response) = calls.next().await {
1305        match call_response.as_ref() {
1306            Ok(_) => {
1307                response.send(proto::Ack {})?;
1308                return Ok(());
1309            }
1310            Err(_) => {
1311                call_response.trace_err();
1312            }
1313        }
1314    }
1315
1316    {
1317        let room = session
1318            .db()
1319            .await
1320            .call_failed(room_id, called_user_id)
1321            .await?;
1322        room_updated(&room, &session.peer);
1323    }
1324    update_user_contacts(called_user_id, &session).await?;
1325
1326    Err(anyhow!("failed to ring user"))?
1327}
1328
1329async fn cancel_call(
1330    request: proto::CancelCall,
1331    response: Response<proto::CancelCall>,
1332    session: Session,
1333) -> Result<()> {
1334    let called_user_id = UserId::from_proto(request.called_user_id);
1335    let room_id = RoomId::from_proto(request.room_id);
1336    {
1337        let room = session
1338            .db()
1339            .await
1340            .cancel_call(room_id, session.connection_id, called_user_id)
1341            .await?;
1342        room_updated(&room, &session.peer);
1343    }
1344
1345    for connection_id in session
1346        .connection_pool()
1347        .await
1348        .user_connection_ids(called_user_id)
1349    {
1350        session
1351            .peer
1352            .send(
1353                connection_id,
1354                proto::CallCanceled {
1355                    room_id: room_id.to_proto(),
1356                },
1357            )
1358            .trace_err();
1359    }
1360    response.send(proto::Ack {})?;
1361
1362    update_user_contacts(called_user_id, &session).await?;
1363    Ok(())
1364}
1365
1366async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1367    let room_id = RoomId::from_proto(message.room_id);
1368    {
1369        let room = session
1370            .db()
1371            .await
1372            .decline_call(Some(room_id), session.user_id)
1373            .await?
1374            .ok_or_else(|| anyhow!("failed to decline call"))?;
1375        room_updated(&room, &session.peer);
1376    }
1377
1378    for connection_id in session
1379        .connection_pool()
1380        .await
1381        .user_connection_ids(session.user_id)
1382    {
1383        session
1384            .peer
1385            .send(
1386                connection_id,
1387                proto::CallCanceled {
1388                    room_id: room_id.to_proto(),
1389                },
1390            )
1391            .trace_err();
1392    }
1393    update_user_contacts(session.user_id, &session).await?;
1394    Ok(())
1395}
1396
1397async fn update_participant_location(
1398    request: proto::UpdateParticipantLocation,
1399    response: Response<proto::UpdateParticipantLocation>,
1400    session: Session,
1401) -> Result<()> {
1402    let room_id = RoomId::from_proto(request.room_id);
1403    let location = request
1404        .location
1405        .ok_or_else(|| anyhow!("invalid location"))?;
1406
1407    let db = session.db().await;
1408    let room = db
1409        .update_room_participant_location(room_id, session.connection_id, location)
1410        .await?;
1411
1412    room_updated(&room, &session.peer);
1413    response.send(proto::Ack {})?;
1414    Ok(())
1415}
1416
1417async fn share_project(
1418    request: proto::ShareProject,
1419    response: Response<proto::ShareProject>,
1420    session: Session,
1421) -> Result<()> {
1422    let (project_id, room) = &*session
1423        .db()
1424        .await
1425        .share_project(
1426            RoomId::from_proto(request.room_id),
1427            session.connection_id,
1428            &request.worktrees,
1429        )
1430        .await?;
1431    response.send(proto::ShareProjectResponse {
1432        project_id: project_id.to_proto(),
1433    })?;
1434    room_updated(&room, &session.peer);
1435
1436    Ok(())
1437}
1438
1439async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1440    let project_id = ProjectId::from_proto(message.project_id);
1441
1442    let (room, guest_connection_ids) = &*session
1443        .db()
1444        .await
1445        .unshare_project(project_id, session.connection_id)
1446        .await?;
1447
1448    broadcast(
1449        Some(session.connection_id),
1450        guest_connection_ids.iter().copied(),
1451        |conn_id| session.peer.send(conn_id, message.clone()),
1452    );
1453    room_updated(&room, &session.peer);
1454
1455    Ok(())
1456}
1457
1458async fn join_project(
1459    request: proto::JoinProject,
1460    response: Response<proto::JoinProject>,
1461    session: Session,
1462) -> Result<()> {
1463    let project_id = ProjectId::from_proto(request.project_id);
1464    let guest_user_id = session.user_id;
1465
1466    tracing::info!(%project_id, "join project");
1467
1468    let (project, replica_id) = &mut *session
1469        .db()
1470        .await
1471        .join_project(project_id, session.connection_id)
1472        .await?;
1473
1474    let collaborators = project
1475        .collaborators
1476        .iter()
1477        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1478        .map(|collaborator| collaborator.to_proto())
1479        .collect::<Vec<_>>();
1480
1481    let worktrees = project
1482        .worktrees
1483        .iter()
1484        .map(|(id, worktree)| proto::WorktreeMetadata {
1485            id: *id,
1486            root_name: worktree.root_name.clone(),
1487            visible: worktree.visible,
1488            abs_path: worktree.abs_path.clone(),
1489        })
1490        .collect::<Vec<_>>();
1491
1492    for collaborator in &collaborators {
1493        session
1494            .peer
1495            .send(
1496                collaborator.peer_id.unwrap().into(),
1497                proto::AddProjectCollaborator {
1498                    project_id: project_id.to_proto(),
1499                    collaborator: Some(proto::Collaborator {
1500                        peer_id: Some(session.connection_id.into()),
1501                        replica_id: replica_id.0 as u32,
1502                        user_id: guest_user_id.to_proto(),
1503                    }),
1504                },
1505            )
1506            .trace_err();
1507    }
1508
1509    // First, we send the metadata associated with each worktree.
1510    response.send(proto::JoinProjectResponse {
1511        worktrees: worktrees.clone(),
1512        replica_id: replica_id.0 as u32,
1513        collaborators: collaborators.clone(),
1514        language_servers: project.language_servers.clone(),
1515    })?;
1516
1517    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1518        #[cfg(any(test, feature = "test-support"))]
1519        const MAX_CHUNK_SIZE: usize = 2;
1520        #[cfg(not(any(test, feature = "test-support")))]
1521        const MAX_CHUNK_SIZE: usize = 256;
1522
1523        // Stream this worktree's entries.
1524        let message = proto::UpdateWorktree {
1525            project_id: project_id.to_proto(),
1526            worktree_id,
1527            abs_path: worktree.abs_path.clone(),
1528            root_name: worktree.root_name,
1529            updated_entries: worktree.entries,
1530            removed_entries: Default::default(),
1531            scan_id: worktree.scan_id,
1532            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1533            updated_repositories: worktree.repository_entries.into_values().collect(),
1534            removed_repositories: Default::default(),
1535        };
1536        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1537            session.peer.send(session.connection_id, update.clone())?;
1538        }
1539
1540        // Stream this worktree's diagnostics.
1541        for summary in worktree.diagnostic_summaries {
1542            session.peer.send(
1543                session.connection_id,
1544                proto::UpdateDiagnosticSummary {
1545                    project_id: project_id.to_proto(),
1546                    worktree_id: worktree.id,
1547                    summary: Some(summary),
1548                },
1549            )?;
1550        }
1551
1552        for settings_file in worktree.settings_files {
1553            session.peer.send(
1554                session.connection_id,
1555                proto::UpdateWorktreeSettings {
1556                    project_id: project_id.to_proto(),
1557                    worktree_id: worktree.id,
1558                    path: settings_file.path,
1559                    content: Some(settings_file.content),
1560                },
1561            )?;
1562        }
1563    }
1564
1565    for language_server in &project.language_servers {
1566        session.peer.send(
1567            session.connection_id,
1568            proto::UpdateLanguageServer {
1569                project_id: project_id.to_proto(),
1570                language_server_id: language_server.id,
1571                variant: Some(
1572                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1573                        proto::LspDiskBasedDiagnosticsUpdated {},
1574                    ),
1575                ),
1576            },
1577        )?;
1578    }
1579
1580    Ok(())
1581}
1582
1583async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1584    let sender_id = session.connection_id;
1585    let project_id = ProjectId::from_proto(request.project_id);
1586
1587    let (room, project) = &*session
1588        .db()
1589        .await
1590        .leave_project(project_id, sender_id)
1591        .await?;
1592    tracing::info!(
1593        %project_id,
1594        host_user_id = %project.host_user_id,
1595        host_connection_id = %project.host_connection_id,
1596        "leave project"
1597    );
1598
1599    project_left(&project, &session);
1600    room_updated(&room, &session.peer);
1601
1602    Ok(())
1603}
1604
1605async fn update_project(
1606    request: proto::UpdateProject,
1607    response: Response<proto::UpdateProject>,
1608    session: Session,
1609) -> Result<()> {
1610    let project_id = ProjectId::from_proto(request.project_id);
1611    let (room, guest_connection_ids) = &*session
1612        .db()
1613        .await
1614        .update_project(project_id, session.connection_id, &request.worktrees)
1615        .await?;
1616    broadcast(
1617        Some(session.connection_id),
1618        guest_connection_ids.iter().copied(),
1619        |connection_id| {
1620            session
1621                .peer
1622                .forward_send(session.connection_id, connection_id, request.clone())
1623        },
1624    );
1625    room_updated(&room, &session.peer);
1626    response.send(proto::Ack {})?;
1627
1628    Ok(())
1629}
1630
1631async fn update_worktree(
1632    request: proto::UpdateWorktree,
1633    response: Response<proto::UpdateWorktree>,
1634    session: Session,
1635) -> Result<()> {
1636    let guest_connection_ids = session
1637        .db()
1638        .await
1639        .update_worktree(&request, session.connection_id)
1640        .await?;
1641
1642    broadcast(
1643        Some(session.connection_id),
1644        guest_connection_ids.iter().copied(),
1645        |connection_id| {
1646            session
1647                .peer
1648                .forward_send(session.connection_id, connection_id, request.clone())
1649        },
1650    );
1651    response.send(proto::Ack {})?;
1652    Ok(())
1653}
1654
1655async fn update_diagnostic_summary(
1656    message: proto::UpdateDiagnosticSummary,
1657    session: Session,
1658) -> Result<()> {
1659    let guest_connection_ids = session
1660        .db()
1661        .await
1662        .update_diagnostic_summary(&message, session.connection_id)
1663        .await?;
1664
1665    broadcast(
1666        Some(session.connection_id),
1667        guest_connection_ids.iter().copied(),
1668        |connection_id| {
1669            session
1670                .peer
1671                .forward_send(session.connection_id, connection_id, message.clone())
1672        },
1673    );
1674
1675    Ok(())
1676}
1677
1678async fn update_worktree_settings(
1679    message: proto::UpdateWorktreeSettings,
1680    session: Session,
1681) -> Result<()> {
1682    let guest_connection_ids = session
1683        .db()
1684        .await
1685        .update_worktree_settings(&message, session.connection_id)
1686        .await?;
1687
1688    broadcast(
1689        Some(session.connection_id),
1690        guest_connection_ids.iter().copied(),
1691        |connection_id| {
1692            session
1693                .peer
1694                .forward_send(session.connection_id, connection_id, message.clone())
1695        },
1696    );
1697
1698    Ok(())
1699}
1700
1701async fn start_language_server(
1702    request: proto::StartLanguageServer,
1703    session: Session,
1704) -> Result<()> {
1705    let guest_connection_ids = session
1706        .db()
1707        .await
1708        .start_language_server(&request, session.connection_id)
1709        .await?;
1710
1711    broadcast(
1712        Some(session.connection_id),
1713        guest_connection_ids.iter().copied(),
1714        |connection_id| {
1715            session
1716                .peer
1717                .forward_send(session.connection_id, connection_id, request.clone())
1718        },
1719    );
1720    Ok(())
1721}
1722
1723async fn update_language_server(
1724    request: proto::UpdateLanguageServer,
1725    session: Session,
1726) -> Result<()> {
1727    let project_id = ProjectId::from_proto(request.project_id);
1728    let project_connection_ids = session
1729        .db()
1730        .await
1731        .project_connection_ids(project_id, session.connection_id)
1732        .await?;
1733    broadcast(
1734        Some(session.connection_id),
1735        project_connection_ids.iter().copied(),
1736        |connection_id| {
1737            session
1738                .peer
1739                .forward_send(session.connection_id, connection_id, request.clone())
1740        },
1741    );
1742    Ok(())
1743}
1744
1745async fn forward_read_only_project_request<T>(
1746    request: T,
1747    response: Response<T>,
1748    session: Session,
1749) -> Result<()>
1750where
1751    T: EntityMessage + RequestMessage,
1752{
1753    let project_id = ProjectId::from_proto(request.remote_entity_id());
1754    let host_connection_id = session
1755        .db()
1756        .await
1757        .host_for_read_only_project_request(project_id, session.connection_id)
1758        .await?;
1759    let payload = session
1760        .peer
1761        .forward_request(session.connection_id, host_connection_id, request)
1762        .await?;
1763    response.send(payload)?;
1764    Ok(())
1765}
1766
1767async fn forward_mutating_project_request<T>(
1768    request: T,
1769    response: Response<T>,
1770    session: Session,
1771) -> Result<()>
1772where
1773    T: EntityMessage + RequestMessage,
1774{
1775    let project_id = ProjectId::from_proto(request.remote_entity_id());
1776    let host_connection_id = session
1777        .db()
1778        .await
1779        .host_for_mutating_project_request(project_id, session.connection_id)
1780        .await?;
1781    let payload = session
1782        .peer
1783        .forward_request(session.connection_id, host_connection_id, request)
1784        .await?;
1785    response.send(payload)?;
1786    Ok(())
1787}
1788
1789async fn create_buffer_for_peer(
1790    request: proto::CreateBufferForPeer,
1791    session: Session,
1792) -> Result<()> {
1793    session
1794        .db()
1795        .await
1796        .check_user_is_project_host(
1797            ProjectId::from_proto(request.project_id),
1798            session.connection_id,
1799        )
1800        .await?;
1801    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1802    session
1803        .peer
1804        .forward_send(session.connection_id, peer_id.into(), request)?;
1805    Ok(())
1806}
1807
1808async fn update_buffer(
1809    request: proto::UpdateBuffer,
1810    response: Response<proto::UpdateBuffer>,
1811    session: Session,
1812) -> Result<()> {
1813    let project_id = ProjectId::from_proto(request.project_id);
1814    let mut guest_connection_ids;
1815    let mut host_connection_id = None;
1816
1817    let mut requires_write_permission = false;
1818
1819    for op in request.operations.iter() {
1820        match op.variant {
1821            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1822            Some(_) => requires_write_permission = true,
1823        }
1824    }
1825
1826    {
1827        let collaborators = session
1828            .db()
1829            .await
1830            .project_collaborators_for_buffer_update(
1831                project_id,
1832                session.connection_id,
1833                requires_write_permission,
1834            )
1835            .await?;
1836        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1837        for collaborator in collaborators.iter() {
1838            if collaborator.is_host {
1839                host_connection_id = Some(collaborator.connection_id);
1840            } else {
1841                guest_connection_ids.push(collaborator.connection_id);
1842            }
1843        }
1844    }
1845    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1846
1847    broadcast(
1848        Some(session.connection_id),
1849        guest_connection_ids,
1850        |connection_id| {
1851            session
1852                .peer
1853                .forward_send(session.connection_id, connection_id, request.clone())
1854        },
1855    );
1856    if host_connection_id != session.connection_id {
1857        session
1858            .peer
1859            .forward_request(session.connection_id, host_connection_id, request.clone())
1860            .await?;
1861    }
1862
1863    response.send(proto::Ack {})?;
1864    Ok(())
1865}
1866
1867async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1868    request: T,
1869    session: Session,
1870) -> Result<()> {
1871    let project_id = ProjectId::from_proto(request.remote_entity_id());
1872    let project_connection_ids = session
1873        .db()
1874        .await
1875        .project_connection_ids(project_id, session.connection_id)
1876        .await?;
1877
1878    broadcast(
1879        Some(session.connection_id),
1880        project_connection_ids.iter().copied(),
1881        |connection_id| {
1882            session
1883                .peer
1884                .forward_send(session.connection_id, connection_id, request.clone())
1885        },
1886    );
1887    Ok(())
1888}
1889
1890async fn follow(
1891    request: proto::Follow,
1892    response: Response<proto::Follow>,
1893    session: Session,
1894) -> Result<()> {
1895    let room_id = RoomId::from_proto(request.room_id);
1896    let project_id = request.project_id.map(ProjectId::from_proto);
1897    let leader_id = request
1898        .leader_id
1899        .ok_or_else(|| anyhow!("invalid leader id"))?
1900        .into();
1901    let follower_id = session.connection_id;
1902
1903    session
1904        .db()
1905        .await
1906        .check_room_participants(room_id, leader_id, session.connection_id)
1907        .await?;
1908
1909    let response_payload = session
1910        .peer
1911        .forward_request(session.connection_id, leader_id, request)
1912        .await?;
1913    response.send(response_payload)?;
1914
1915    if let Some(project_id) = project_id {
1916        let room = session
1917            .db()
1918            .await
1919            .follow(room_id, project_id, leader_id, follower_id)
1920            .await?;
1921        room_updated(&room, &session.peer);
1922    }
1923
1924    Ok(())
1925}
1926
1927async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1928    let room_id = RoomId::from_proto(request.room_id);
1929    let project_id = request.project_id.map(ProjectId::from_proto);
1930    let leader_id = request
1931        .leader_id
1932        .ok_or_else(|| anyhow!("invalid leader id"))?
1933        .into();
1934    let follower_id = session.connection_id;
1935
1936    session
1937        .db()
1938        .await
1939        .check_room_participants(room_id, leader_id, session.connection_id)
1940        .await?;
1941
1942    session
1943        .peer
1944        .forward_send(session.connection_id, leader_id, request)?;
1945
1946    if let Some(project_id) = project_id {
1947        let room = session
1948            .db()
1949            .await
1950            .unfollow(room_id, project_id, leader_id, follower_id)
1951            .await?;
1952        room_updated(&room, &session.peer);
1953    }
1954
1955    Ok(())
1956}
1957
1958async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1959    let room_id = RoomId::from_proto(request.room_id);
1960    let database = session.db.lock().await;
1961
1962    let connection_ids = if let Some(project_id) = request.project_id {
1963        let project_id = ProjectId::from_proto(project_id);
1964        database
1965            .project_connection_ids(project_id, session.connection_id)
1966            .await?
1967    } else {
1968        database
1969            .room_connection_ids(room_id, session.connection_id)
1970            .await?
1971    };
1972
1973    // For now, don't send view update messages back to that view's current leader.
1974    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1975        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1976        _ => None,
1977    });
1978
1979    for follower_peer_id in request.follower_ids.iter().copied() {
1980        let follower_connection_id = follower_peer_id.into();
1981        if Some(follower_peer_id) != connection_id_to_omit
1982            && connection_ids.contains(&follower_connection_id)
1983        {
1984            session.peer.forward_send(
1985                session.connection_id,
1986                follower_connection_id,
1987                request.clone(),
1988            )?;
1989        }
1990    }
1991    Ok(())
1992}
1993
1994async fn get_users(
1995    request: proto::GetUsers,
1996    response: Response<proto::GetUsers>,
1997    session: Session,
1998) -> Result<()> {
1999    let user_ids = request
2000        .user_ids
2001        .into_iter()
2002        .map(UserId::from_proto)
2003        .collect();
2004    let users = session
2005        .db()
2006        .await
2007        .get_users_by_ids(user_ids)
2008        .await?
2009        .into_iter()
2010        .map(|user| proto::User {
2011            id: user.id.to_proto(),
2012            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2013            github_login: user.github_login,
2014        })
2015        .collect();
2016    response.send(proto::UsersResponse { users })?;
2017    Ok(())
2018}
2019
2020async fn fuzzy_search_users(
2021    request: proto::FuzzySearchUsers,
2022    response: Response<proto::FuzzySearchUsers>,
2023    session: Session,
2024) -> Result<()> {
2025    let query = request.query;
2026    let users = match query.len() {
2027        0 => vec![],
2028        1 | 2 => session
2029            .db()
2030            .await
2031            .get_user_by_github_login(&query)
2032            .await?
2033            .into_iter()
2034            .collect(),
2035        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2036    };
2037    let users = users
2038        .into_iter()
2039        .filter(|user| user.id != session.user_id)
2040        .map(|user| proto::User {
2041            id: user.id.to_proto(),
2042            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2043            github_login: user.github_login,
2044        })
2045        .collect();
2046    response.send(proto::UsersResponse { users })?;
2047    Ok(())
2048}
2049
2050async fn request_contact(
2051    request: proto::RequestContact,
2052    response: Response<proto::RequestContact>,
2053    session: Session,
2054) -> Result<()> {
2055    let requester_id = session.user_id;
2056    let responder_id = UserId::from_proto(request.responder_id);
2057    if requester_id == responder_id {
2058        return Err(anyhow!("cannot add yourself as a contact"))?;
2059    }
2060
2061    let notifications = session
2062        .db()
2063        .await
2064        .send_contact_request(requester_id, responder_id)
2065        .await?;
2066
2067    // Update outgoing contact requests of requester
2068    let mut update = proto::UpdateContacts::default();
2069    update.outgoing_requests.push(responder_id.to_proto());
2070    for connection_id in session
2071        .connection_pool()
2072        .await
2073        .user_connection_ids(requester_id)
2074    {
2075        session.peer.send(connection_id, update.clone())?;
2076    }
2077
2078    // Update incoming contact requests of responder
2079    let mut update = proto::UpdateContacts::default();
2080    update
2081        .incoming_requests
2082        .push(proto::IncomingContactRequest {
2083            requester_id: requester_id.to_proto(),
2084        });
2085    let connection_pool = session.connection_pool().await;
2086    for connection_id in connection_pool.user_connection_ids(responder_id) {
2087        session.peer.send(connection_id, update.clone())?;
2088    }
2089
2090    send_notifications(&*connection_pool, &session.peer, notifications);
2091
2092    response.send(proto::Ack {})?;
2093    Ok(())
2094}
2095
2096async fn respond_to_contact_request(
2097    request: proto::RespondToContactRequest,
2098    response: Response<proto::RespondToContactRequest>,
2099    session: Session,
2100) -> Result<()> {
2101    let responder_id = session.user_id;
2102    let requester_id = UserId::from_proto(request.requester_id);
2103    let db = session.db().await;
2104    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2105        db.dismiss_contact_notification(responder_id, requester_id)
2106            .await?;
2107    } else {
2108        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2109
2110        let notifications = db
2111            .respond_to_contact_request(responder_id, requester_id, accept)
2112            .await?;
2113        let requester_busy = db.is_user_busy(requester_id).await?;
2114        let responder_busy = db.is_user_busy(responder_id).await?;
2115
2116        let pool = session.connection_pool().await;
2117        // Update responder with new contact
2118        let mut update = proto::UpdateContacts::default();
2119        if accept {
2120            update
2121                .contacts
2122                .push(contact_for_user(requester_id, requester_busy, &pool));
2123        }
2124        update
2125            .remove_incoming_requests
2126            .push(requester_id.to_proto());
2127        for connection_id in pool.user_connection_ids(responder_id) {
2128            session.peer.send(connection_id, update.clone())?;
2129        }
2130
2131        // Update requester with new contact
2132        let mut update = proto::UpdateContacts::default();
2133        if accept {
2134            update
2135                .contacts
2136                .push(contact_for_user(responder_id, responder_busy, &pool));
2137        }
2138        update
2139            .remove_outgoing_requests
2140            .push(responder_id.to_proto());
2141
2142        for connection_id in pool.user_connection_ids(requester_id) {
2143            session.peer.send(connection_id, update.clone())?;
2144        }
2145
2146        send_notifications(&*pool, &session.peer, notifications);
2147    }
2148
2149    response.send(proto::Ack {})?;
2150    Ok(())
2151}
2152
2153async fn remove_contact(
2154    request: proto::RemoveContact,
2155    response: Response<proto::RemoveContact>,
2156    session: Session,
2157) -> Result<()> {
2158    let requester_id = session.user_id;
2159    let responder_id = UserId::from_proto(request.user_id);
2160    let db = session.db().await;
2161    let (contact_accepted, deleted_notification_id) =
2162        db.remove_contact(requester_id, responder_id).await?;
2163
2164    let pool = session.connection_pool().await;
2165    // Update outgoing contact requests of requester
2166    let mut update = proto::UpdateContacts::default();
2167    if contact_accepted {
2168        update.remove_contacts.push(responder_id.to_proto());
2169    } else {
2170        update
2171            .remove_outgoing_requests
2172            .push(responder_id.to_proto());
2173    }
2174    for connection_id in pool.user_connection_ids(requester_id) {
2175        session.peer.send(connection_id, update.clone())?;
2176    }
2177
2178    // Update incoming contact requests of responder
2179    let mut update = proto::UpdateContacts::default();
2180    if contact_accepted {
2181        update.remove_contacts.push(requester_id.to_proto());
2182    } else {
2183        update
2184            .remove_incoming_requests
2185            .push(requester_id.to_proto());
2186    }
2187    for connection_id in pool.user_connection_ids(responder_id) {
2188        session.peer.send(connection_id, update.clone())?;
2189        if let Some(notification_id) = deleted_notification_id {
2190            session.peer.send(
2191                connection_id,
2192                proto::DeleteNotification {
2193                    notification_id: notification_id.to_proto(),
2194                },
2195            )?;
2196        }
2197    }
2198
2199    response.send(proto::Ack {})?;
2200    Ok(())
2201}
2202
2203async fn create_channel(
2204    request: proto::CreateChannel,
2205    response: Response<proto::CreateChannel>,
2206    session: Session,
2207) -> Result<()> {
2208    let db = session.db().await;
2209
2210    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2211    let CreateChannelResult {
2212        channel,
2213        participants_to_update,
2214    } = db
2215        .create_channel(&request.name, parent_id, session.user_id)
2216        .await?;
2217
2218    response.send(proto::CreateChannelResponse {
2219        channel: Some(channel.to_proto()),
2220        parent_id: request.parent_id,
2221    })?;
2222
2223    let connection_pool = session.connection_pool().await;
2224    for (user_id, channels) in participants_to_update {
2225        let update = build_channels_update(channels, vec![]);
2226        for connection_id in connection_pool.user_connection_ids(user_id) {
2227            if user_id == session.user_id {
2228                continue;
2229            }
2230            session.peer.send(connection_id, update.clone())?;
2231        }
2232    }
2233
2234    Ok(())
2235}
2236
2237async fn delete_channel(
2238    request: proto::DeleteChannel,
2239    response: Response<proto::DeleteChannel>,
2240    session: Session,
2241) -> Result<()> {
2242    let db = session.db().await;
2243
2244    let channel_id = request.channel_id;
2245    let (removed_channels, member_ids) = db
2246        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2247        .await?;
2248    response.send(proto::Ack {})?;
2249
2250    // Notify members of removed channels
2251    let mut update = proto::UpdateChannels::default();
2252    update
2253        .delete_channels
2254        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2255
2256    let connection_pool = session.connection_pool().await;
2257    for member_id in member_ids {
2258        for connection_id in connection_pool.user_connection_ids(member_id) {
2259            session.peer.send(connection_id, update.clone())?;
2260        }
2261    }
2262
2263    Ok(())
2264}
2265
2266async fn invite_channel_member(
2267    request: proto::InviteChannelMember,
2268    response: Response<proto::InviteChannelMember>,
2269    session: Session,
2270) -> Result<()> {
2271    let db = session.db().await;
2272    let channel_id = ChannelId::from_proto(request.channel_id);
2273    let invitee_id = UserId::from_proto(request.user_id);
2274    let InviteMemberResult {
2275        channel,
2276        notifications,
2277    } = db
2278        .invite_channel_member(
2279            channel_id,
2280            invitee_id,
2281            session.user_id,
2282            request.role().into(),
2283        )
2284        .await?;
2285
2286    let update = proto::UpdateChannels {
2287        channel_invitations: vec![channel.to_proto()],
2288        ..Default::default()
2289    };
2290
2291    let connection_pool = session.connection_pool().await;
2292    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2293        session.peer.send(connection_id, update.clone())?;
2294    }
2295
2296    send_notifications(&*connection_pool, &session.peer, notifications);
2297
2298    response.send(proto::Ack {})?;
2299    Ok(())
2300}
2301
2302async fn remove_channel_member(
2303    request: proto::RemoveChannelMember,
2304    response: Response<proto::RemoveChannelMember>,
2305    session: Session,
2306) -> Result<()> {
2307    let db = session.db().await;
2308    let channel_id = ChannelId::from_proto(request.channel_id);
2309    let member_id = UserId::from_proto(request.user_id);
2310
2311    let RemoveChannelMemberResult {
2312        membership_update,
2313        notification_id,
2314    } = db
2315        .remove_channel_member(channel_id, member_id, session.user_id)
2316        .await?;
2317
2318    let connection_pool = &session.connection_pool().await;
2319    notify_membership_updated(
2320        &connection_pool,
2321        membership_update,
2322        member_id,
2323        &session.peer,
2324    );
2325    for connection_id in connection_pool.user_connection_ids(member_id) {
2326        if let Some(notification_id) = notification_id {
2327            session
2328                .peer
2329                .send(
2330                    connection_id,
2331                    proto::DeleteNotification {
2332                        notification_id: notification_id.to_proto(),
2333                    },
2334                )
2335                .trace_err();
2336        }
2337    }
2338
2339    response.send(proto::Ack {})?;
2340    Ok(())
2341}
2342
2343async fn set_channel_visibility(
2344    request: proto::SetChannelVisibility,
2345    response: Response<proto::SetChannelVisibility>,
2346    session: Session,
2347) -> Result<()> {
2348    let db = session.db().await;
2349    let channel_id = ChannelId::from_proto(request.channel_id);
2350    let visibility = request.visibility().into();
2351
2352    let SetChannelVisibilityResult {
2353        participants_to_update,
2354        participants_to_remove,
2355        channels_to_remove,
2356    } = db
2357        .set_channel_visibility(channel_id, visibility, session.user_id)
2358        .await?;
2359
2360    let connection_pool = session.connection_pool().await;
2361    for (user_id, channels) in participants_to_update {
2362        let update = build_channels_update(channels, vec![]);
2363        for connection_id in connection_pool.user_connection_ids(user_id) {
2364            session.peer.send(connection_id, update.clone())?;
2365        }
2366    }
2367    for user_id in participants_to_remove {
2368        let update = proto::UpdateChannels {
2369            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2370            ..Default::default()
2371        };
2372        for connection_id in connection_pool.user_connection_ids(user_id) {
2373            session.peer.send(connection_id, update.clone())?;
2374        }
2375    }
2376
2377    response.send(proto::Ack {})?;
2378    Ok(())
2379}
2380
2381async fn set_channel_member_role(
2382    request: proto::SetChannelMemberRole,
2383    response: Response<proto::SetChannelMemberRole>,
2384    session: Session,
2385) -> Result<()> {
2386    let db = session.db().await;
2387    let channel_id = ChannelId::from_proto(request.channel_id);
2388    let member_id = UserId::from_proto(request.user_id);
2389    let result = db
2390        .set_channel_member_role(
2391            channel_id,
2392            session.user_id,
2393            member_id,
2394            request.role().into(),
2395        )
2396        .await?;
2397
2398    match result {
2399        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2400            let connection_pool = session.connection_pool().await;
2401            notify_membership_updated(
2402                &connection_pool,
2403                membership_update,
2404                member_id,
2405                &session.peer,
2406            )
2407        }
2408        db::SetMemberRoleResult::InviteUpdated(channel) => {
2409            let update = proto::UpdateChannels {
2410                channel_invitations: vec![channel.to_proto()],
2411                ..Default::default()
2412            };
2413
2414            for connection_id in session
2415                .connection_pool()
2416                .await
2417                .user_connection_ids(member_id)
2418            {
2419                session.peer.send(connection_id, update.clone())?;
2420            }
2421        }
2422    }
2423
2424    response.send(proto::Ack {})?;
2425    Ok(())
2426}
2427
2428async fn rename_channel(
2429    request: proto::RenameChannel,
2430    response: Response<proto::RenameChannel>,
2431    session: Session,
2432) -> Result<()> {
2433    let db = session.db().await;
2434    let channel_id = ChannelId::from_proto(request.channel_id);
2435    let RenameChannelResult {
2436        channel,
2437        participants_to_update,
2438    } = db
2439        .rename_channel(channel_id, session.user_id, &request.name)
2440        .await?;
2441
2442    response.send(proto::RenameChannelResponse {
2443        channel: Some(channel.to_proto()),
2444    })?;
2445
2446    let connection_pool = session.connection_pool().await;
2447    for (user_id, channel) in participants_to_update {
2448        for connection_id in connection_pool.user_connection_ids(user_id) {
2449            let update = proto::UpdateChannels {
2450                channels: vec![channel.to_proto()],
2451                ..Default::default()
2452            };
2453
2454            session.peer.send(connection_id, update.clone())?;
2455        }
2456    }
2457
2458    Ok(())
2459}
2460
2461async fn move_channel(
2462    request: proto::MoveChannel,
2463    response: Response<proto::MoveChannel>,
2464    session: Session,
2465) -> Result<()> {
2466    let channel_id = ChannelId::from_proto(request.channel_id);
2467    let to = request.to.map(ChannelId::from_proto);
2468
2469    let result = session
2470        .db()
2471        .await
2472        .move_channel(channel_id, to, session.user_id)
2473        .await?;
2474
2475    notify_channel_moved(result, session).await?;
2476
2477    response.send(Ack {})?;
2478    Ok(())
2479}
2480
2481async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2482    let Some(MoveChannelResult {
2483        participants_to_remove,
2484        participants_to_update,
2485        moved_channels,
2486    }) = result
2487    else {
2488        return Ok(());
2489    };
2490    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2491
2492    let connection_pool = session.connection_pool().await;
2493    for (user_id, channels) in participants_to_update {
2494        let mut update = build_channels_update(channels, vec![]);
2495        update.delete_channels = moved_channels.clone();
2496        for connection_id in connection_pool.user_connection_ids(user_id) {
2497            session.peer.send(connection_id, update.clone())?;
2498        }
2499    }
2500
2501    for user_id in participants_to_remove {
2502        let update = proto::UpdateChannels {
2503            delete_channels: moved_channels.clone(),
2504            ..Default::default()
2505        };
2506        for connection_id in connection_pool.user_connection_ids(user_id) {
2507            session.peer.send(connection_id, update.clone())?;
2508        }
2509    }
2510    Ok(())
2511}
2512
2513async fn get_channel_members(
2514    request: proto::GetChannelMembers,
2515    response: Response<proto::GetChannelMembers>,
2516    session: Session,
2517) -> Result<()> {
2518    let db = session.db().await;
2519    let channel_id = ChannelId::from_proto(request.channel_id);
2520    let members = db
2521        .get_channel_participant_details(channel_id, session.user_id)
2522        .await?;
2523    response.send(proto::GetChannelMembersResponse { members })?;
2524    Ok(())
2525}
2526
2527async fn respond_to_channel_invite(
2528    request: proto::RespondToChannelInvite,
2529    response: Response<proto::RespondToChannelInvite>,
2530    session: Session,
2531) -> Result<()> {
2532    let db = session.db().await;
2533    let channel_id = ChannelId::from_proto(request.channel_id);
2534    let RespondToChannelInvite {
2535        membership_update,
2536        notifications,
2537    } = db
2538        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2539        .await?;
2540
2541    let connection_pool = session.connection_pool().await;
2542    if let Some(membership_update) = membership_update {
2543        notify_membership_updated(
2544            &connection_pool,
2545            membership_update,
2546            session.user_id,
2547            &session.peer,
2548        );
2549    } else {
2550        let update = proto::UpdateChannels {
2551            remove_channel_invitations: vec![channel_id.to_proto()],
2552            ..Default::default()
2553        };
2554
2555        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2556            session.peer.send(connection_id, update.clone())?;
2557        }
2558    };
2559
2560    send_notifications(&*connection_pool, &session.peer, notifications);
2561
2562    response.send(proto::Ack {})?;
2563
2564    Ok(())
2565}
2566
2567async fn join_channel(
2568    request: proto::JoinChannel,
2569    response: Response<proto::JoinChannel>,
2570    session: Session,
2571) -> Result<()> {
2572    let channel_id = ChannelId::from_proto(request.channel_id);
2573    join_channel_internal(channel_id, Box::new(response), session).await
2574}
2575
2576trait JoinChannelInternalResponse {
2577    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2578}
2579impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2580    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2581        Response::<proto::JoinChannel>::send(self, result)
2582    }
2583}
2584impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2585    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2586        Response::<proto::JoinRoom>::send(self, result)
2587    }
2588}
2589
/// Moves the session into the room of channel `channel_id` and answers `response`.
///
/// In order, this:
/// 1. leaves any room the session is currently in,
/// 2. joins the channel in the database,
/// 3. replies with the room and the channel id, plus LiveKit connection info when a
///    LiveKit client is configured,
/// 4. pushes any membership change to this user's connections,
/// 5. broadcasts the updated room and the channel's participant list,
/// 6. refreshes this user's contact entry for their contacts.
///
/// `JoinChannelInternalResponse` has impls for both `JoinChannel` and `JoinRoom`
/// responses, so either kind of handle can be passed in.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` handles acquired inside this block are dropped when
    // it ends. The pool is acquired again afterwards for `channel_updated`.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Guests get a guest token and `can_publish: false`; every other role gets a room
        // token and `can_publish: true`. If token creation fails, the error is logged by
        // `trace_err` and the response carries no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2670
2671async fn join_channel_buffer(
2672    request: proto::JoinChannelBuffer,
2673    response: Response<proto::JoinChannelBuffer>,
2674    session: Session,
2675) -> Result<()> {
2676    let db = session.db().await;
2677    let channel_id = ChannelId::from_proto(request.channel_id);
2678
2679    let open_response = db
2680        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2681        .await?;
2682
2683    let collaborators = open_response.collaborators.clone();
2684    response.send(open_response)?;
2685
2686    let update = UpdateChannelBufferCollaborators {
2687        channel_id: channel_id.to_proto(),
2688        collaborators: collaborators.clone(),
2689    };
2690    channel_buffer_updated(
2691        session.connection_id,
2692        collaborators
2693            .iter()
2694            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2695        &update,
2696        &session.peer,
2697    );
2698
2699    Ok(())
2700}
2701
2702async fn update_channel_buffer(
2703    request: proto::UpdateChannelBuffer,
2704    session: Session,
2705) -> Result<()> {
2706    let db = session.db().await;
2707    let channel_id = ChannelId::from_proto(request.channel_id);
2708
2709    let (collaborators, non_collaborators, epoch, version) = db
2710        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2711        .await?;
2712
2713    channel_buffer_updated(
2714        session.connection_id,
2715        collaborators,
2716        &proto::UpdateChannelBuffer {
2717            channel_id: channel_id.to_proto(),
2718            operations: request.operations,
2719        },
2720        &session.peer,
2721    );
2722
2723    let pool = &*session.connection_pool().await;
2724
2725    broadcast(
2726        None,
2727        non_collaborators
2728            .iter()
2729            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2730        |peer_id| {
2731            session.peer.send(
2732                peer_id.into(),
2733                proto::UpdateChannels {
2734                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2735                        channel_id: channel_id.to_proto(),
2736                        epoch: epoch as u64,
2737                        version: version.clone(),
2738                    }],
2739                    ..Default::default()
2740                },
2741            )
2742        },
2743    );
2744
2745    Ok(())
2746}
2747
2748async fn rejoin_channel_buffers(
2749    request: proto::RejoinChannelBuffers,
2750    response: Response<proto::RejoinChannelBuffers>,
2751    session: Session,
2752) -> Result<()> {
2753    let db = session.db().await;
2754    let buffers = db
2755        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2756        .await?;
2757
2758    for rejoined_buffer in &buffers {
2759        let collaborators_to_notify = rejoined_buffer
2760            .buffer
2761            .collaborators
2762            .iter()
2763            .filter_map(|c| Some(c.peer_id?.into()));
2764        channel_buffer_updated(
2765            session.connection_id,
2766            collaborators_to_notify,
2767            &proto::UpdateChannelBufferCollaborators {
2768                channel_id: rejoined_buffer.buffer.channel_id,
2769                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2770            },
2771            &session.peer,
2772        );
2773    }
2774
2775    response.send(proto::RejoinChannelBuffersResponse {
2776        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2777    })?;
2778
2779    Ok(())
2780}
2781
2782async fn leave_channel_buffer(
2783    request: proto::LeaveChannelBuffer,
2784    response: Response<proto::LeaveChannelBuffer>,
2785    session: Session,
2786) -> Result<()> {
2787    let db = session.db().await;
2788    let channel_id = ChannelId::from_proto(request.channel_id);
2789
2790    let left_buffer = db
2791        .leave_channel_buffer(channel_id, session.connection_id)
2792        .await?;
2793
2794    response.send(Ack {})?;
2795
2796    channel_buffer_updated(
2797        session.connection_id,
2798        left_buffer.connections,
2799        &proto::UpdateChannelBufferCollaborators {
2800            channel_id: channel_id.to_proto(),
2801            collaborators: left_buffer.collaborators,
2802        },
2803        &session.peer,
2804    );
2805
2806    Ok(())
2807}
2808
2809fn channel_buffer_updated<T: EnvelopedMessage>(
2810    sender_id: ConnectionId,
2811    collaborators: impl IntoIterator<Item = ConnectionId>,
2812    message: &T,
2813    peer: &Peer,
2814) {
2815    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2816        peer.send(peer_id.into(), message.clone())
2817    });
2818}
2819
2820fn send_notifications(
2821    connection_pool: &ConnectionPool,
2822    peer: &Peer,
2823    notifications: db::NotificationBatch,
2824) {
2825    for (user_id, notification) in notifications {
2826        for connection_id in connection_pool.user_connection_ids(user_id) {
2827            if let Err(error) = peer.send(
2828                connection_id,
2829                proto::AddNotification {
2830                    notification: Some(notification.clone()),
2831                },
2832            ) {
2833                tracing::error!(
2834                    "failed to send notification to {:?} {}",
2835                    connection_id,
2836                    error
2837                );
2838            }
2839        }
2840    }
2841}
2842
2843async fn send_channel_message(
2844    request: proto::SendChannelMessage,
2845    response: Response<proto::SendChannelMessage>,
2846    session: Session,
2847) -> Result<()> {
2848    // Validate the message body.
2849    let body = request.body.trim().to_string();
2850    if body.len() > MAX_MESSAGE_LEN {
2851        return Err(anyhow!("message is too long"))?;
2852    }
2853    if body.is_empty() {
2854        return Err(anyhow!("message can't be blank"))?;
2855    }
2856
2857    // TODO: adjust mentions if body is trimmed
2858
2859    let timestamp = OffsetDateTime::now_utc();
2860    let nonce = request
2861        .nonce
2862        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2863
2864    let channel_id = ChannelId::from_proto(request.channel_id);
2865    let CreatedChannelMessage {
2866        message_id,
2867        participant_connection_ids,
2868        channel_members,
2869        notifications,
2870    } = session
2871        .db()
2872        .await
2873        .create_channel_message(
2874            channel_id,
2875            session.user_id,
2876            &body,
2877            &request.mentions,
2878            timestamp,
2879            nonce.clone().into(),
2880        )
2881        .await?;
2882    let message = proto::ChannelMessage {
2883        sender_id: session.user_id.to_proto(),
2884        id: message_id.to_proto(),
2885        body,
2886        mentions: request.mentions,
2887        timestamp: timestamp.unix_timestamp() as u64,
2888        nonce: Some(nonce),
2889    };
2890    broadcast(
2891        Some(session.connection_id),
2892        participant_connection_ids,
2893        |connection| {
2894            session.peer.send(
2895                connection,
2896                proto::ChannelMessageSent {
2897                    channel_id: channel_id.to_proto(),
2898                    message: Some(message.clone()),
2899                },
2900            )
2901        },
2902    );
2903    response.send(proto::SendChannelMessageResponse {
2904        message: Some(message),
2905    })?;
2906
2907    let pool = &*session.connection_pool().await;
2908    broadcast(
2909        None,
2910        channel_members
2911            .iter()
2912            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2913        |peer_id| {
2914            session.peer.send(
2915                peer_id.into(),
2916                proto::UpdateChannels {
2917                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2918                        channel_id: channel_id.to_proto(),
2919                        message_id: message_id.to_proto(),
2920                    }],
2921                    ..Default::default()
2922                },
2923            )
2924        },
2925    );
2926    send_notifications(pool, &session.peer, notifications);
2927
2928    Ok(())
2929}
2930
2931async fn remove_channel_message(
2932    request: proto::RemoveChannelMessage,
2933    response: Response<proto::RemoveChannelMessage>,
2934    session: Session,
2935) -> Result<()> {
2936    let channel_id = ChannelId::from_proto(request.channel_id);
2937    let message_id = MessageId::from_proto(request.message_id);
2938    let connection_ids = session
2939        .db()
2940        .await
2941        .remove_channel_message(channel_id, message_id, session.user_id)
2942        .await?;
2943    broadcast(Some(session.connection_id), connection_ids, |connection| {
2944        session.peer.send(connection, request.clone())
2945    });
2946    response.send(proto::Ack {})?;
2947    Ok(())
2948}
2949
2950async fn acknowledge_channel_message(
2951    request: proto::AckChannelMessage,
2952    session: Session,
2953) -> Result<()> {
2954    let channel_id = ChannelId::from_proto(request.channel_id);
2955    let message_id = MessageId::from_proto(request.message_id);
2956    let notifications = session
2957        .db()
2958        .await
2959        .observe_channel_message(channel_id, session.user_id, message_id)
2960        .await?;
2961    send_notifications(
2962        &*session.connection_pool().await,
2963        &session.peer,
2964        notifications,
2965    );
2966    Ok(())
2967}
2968
2969async fn acknowledge_buffer_version(
2970    request: proto::AckBufferOperation,
2971    session: Session,
2972) -> Result<()> {
2973    let buffer_id = BufferId::from_proto(request.buffer_id);
2974    session
2975        .db()
2976        .await
2977        .observe_buffer_version(
2978            buffer_id,
2979            session.user_id,
2980            request.epoch as i32,
2981            &request.version,
2982        )
2983        .await?;
2984    Ok(())
2985}
2986
2987async fn join_channel_chat(
2988    request: proto::JoinChannelChat,
2989    response: Response<proto::JoinChannelChat>,
2990    session: Session,
2991) -> Result<()> {
2992    let channel_id = ChannelId::from_proto(request.channel_id);
2993
2994    let db = session.db().await;
2995    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2996        .await?;
2997    let messages = db
2998        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2999        .await?;
3000    response.send(proto::JoinChannelChatResponse {
3001        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3002        messages,
3003    })?;
3004    Ok(())
3005}
3006
3007async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3008    let channel_id = ChannelId::from_proto(request.channel_id);
3009    session
3010        .db()
3011        .await
3012        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3013        .await?;
3014    Ok(())
3015}
3016
3017async fn get_channel_messages(
3018    request: proto::GetChannelMessages,
3019    response: Response<proto::GetChannelMessages>,
3020    session: Session,
3021) -> Result<()> {
3022    let channel_id = ChannelId::from_proto(request.channel_id);
3023    let messages = session
3024        .db()
3025        .await
3026        .get_channel_messages(
3027            channel_id,
3028            session.user_id,
3029            MESSAGE_COUNT_PER_PAGE,
3030            Some(MessageId::from_proto(request.before_message_id)),
3031        )
3032        .await?;
3033    response.send(proto::GetChannelMessagesResponse {
3034        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3035        messages,
3036    })?;
3037    Ok(())
3038}
3039
3040async fn get_channel_messages_by_id(
3041    request: proto::GetChannelMessagesById,
3042    response: Response<proto::GetChannelMessagesById>,
3043    session: Session,
3044) -> Result<()> {
3045    let message_ids = request
3046        .message_ids
3047        .iter()
3048        .map(|id| MessageId::from_proto(*id))
3049        .collect::<Vec<_>>();
3050    let messages = session
3051        .db()
3052        .await
3053        .get_channel_messages_by_id(session.user_id, &message_ids)
3054        .await?;
3055    response.send(proto::GetChannelMessagesResponse {
3056        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3057        messages,
3058    })?;
3059    Ok(())
3060}
3061
3062async fn get_notifications(
3063    request: proto::GetNotifications,
3064    response: Response<proto::GetNotifications>,
3065    session: Session,
3066) -> Result<()> {
3067    let notifications = session
3068        .db()
3069        .await
3070        .get_notifications(
3071            session.user_id,
3072            NOTIFICATION_COUNT_PER_PAGE,
3073            request
3074                .before_id
3075                .map(|id| db::NotificationId::from_proto(id)),
3076        )
3077        .await?;
3078    response.send(proto::GetNotificationsResponse {
3079        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3080        notifications,
3081    })?;
3082    Ok(())
3083}
3084
3085async fn mark_notification_as_read(
3086    request: proto::MarkNotificationRead,
3087    response: Response<proto::MarkNotificationRead>,
3088    session: Session,
3089) -> Result<()> {
3090    let database = &session.db().await;
3091    let notifications = database
3092        .mark_notification_as_read_by_id(
3093            session.user_id,
3094            NotificationId::from_proto(request.notification_id),
3095        )
3096        .await?;
3097    send_notifications(
3098        &*session.connection_pool().await,
3099        &session.peer,
3100        notifications,
3101    );
3102    response.send(proto::Ack {})?;
3103    Ok(())
3104}
3105
3106async fn get_private_user_info(
3107    _request: proto::GetPrivateUserInfo,
3108    response: Response<proto::GetPrivateUserInfo>,
3109    session: Session,
3110) -> Result<()> {
3111    let db = session.db().await;
3112
3113    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3114    let user = db
3115        .get_user_by_id(session.user_id)
3116        .await?
3117        .ok_or_else(|| anyhow!("user not found"))?;
3118    let flags = db.get_user_flags(session.user_id).await?;
3119
3120    response.send(proto::GetPrivateUserInfoResponse {
3121        metrics_id,
3122        staff: user.admin,
3123        flags,
3124    })?;
3125    Ok(())
3126}
3127
3128fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3129    match message {
3130        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3131        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3132        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3133        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3134        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3135            code: frame.code.into(),
3136            reason: frame.reason,
3137        })),
3138    }
3139}
3140
3141fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3142    match message {
3143        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3144        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3145        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3146        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3147        AxumMessage::Close(frame) => {
3148            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3149                code: frame.code.into(),
3150                reason: frame.reason,
3151            }))
3152        }
3153    }
3154}
3155
3156fn notify_membership_updated(
3157    connection_pool: &ConnectionPool,
3158    result: MembershipUpdated,
3159    user_id: UserId,
3160    peer: &Peer,
3161) {
3162    let mut update = build_channels_update(result.new_channels, vec![]);
3163    update.delete_channels = result
3164        .removed_channels
3165        .into_iter()
3166        .map(|id| id.to_proto())
3167        .collect();
3168    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3169
3170    for connection_id in connection_pool.user_connection_ids(user_id) {
3171        peer.send(connection_id, update.clone()).trace_err();
3172    }
3173}
3174
3175fn build_channels_update(
3176    channels: ChannelsForUser,
3177    channel_invites: Vec<db::Channel>,
3178) -> proto::UpdateChannels {
3179    let mut update = proto::UpdateChannels::default();
3180
3181    for channel in channels.channels {
3182        update.channels.push(channel.to_proto());
3183    }
3184
3185    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3186    update.unseen_channel_messages = channels.channel_messages;
3187
3188    for (channel_id, participants) in channels.channel_participants {
3189        update
3190            .channel_participants
3191            .push(proto::ChannelParticipants {
3192                channel_id: channel_id.to_proto(),
3193                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3194            });
3195    }
3196
3197    for channel in channel_invites {
3198        update.channel_invitations.push(channel.to_proto());
3199    }
3200
3201    update
3202}
3203
3204fn build_initial_contacts_update(
3205    contacts: Vec<db::Contact>,
3206    pool: &ConnectionPool,
3207) -> proto::UpdateContacts {
3208    let mut update = proto::UpdateContacts::default();
3209
3210    for contact in contacts {
3211        match contact {
3212            db::Contact::Accepted { user_id, busy } => {
3213                update.contacts.push(contact_for_user(user_id, busy, &pool));
3214            }
3215            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3216            db::Contact::Incoming { user_id } => {
3217                update
3218                    .incoming_requests
3219                    .push(proto::IncomingContactRequest {
3220                        requester_id: user_id.to_proto(),
3221                    })
3222            }
3223        }
3224    }
3225
3226    update
3227}
3228
3229fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3230    proto::Contact {
3231        user_id: user_id.to_proto(),
3232        online: pool.is_user_online(user_id),
3233        busy,
3234    }
3235}
3236
3237fn room_updated(room: &proto::Room, peer: &Peer) {
3238    broadcast(
3239        None,
3240        room.participants
3241            .iter()
3242            .filter_map(|participant| Some(participant.peer_id?.into())),
3243        |peer_id| {
3244            peer.send(
3245                peer_id.into(),
3246                proto::RoomUpdated {
3247                    room: Some(room.clone()),
3248                },
3249            )
3250        },
3251    );
3252}
3253
3254fn channel_updated(
3255    channel_id: ChannelId,
3256    room: &proto::Room,
3257    channel_members: &[UserId],
3258    peer: &Peer,
3259    pool: &ConnectionPool,
3260) {
3261    let participants = room
3262        .participants
3263        .iter()
3264        .map(|p| p.user_id)
3265        .collect::<Vec<_>>();
3266
3267    broadcast(
3268        None,
3269        channel_members
3270            .iter()
3271            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3272        |peer_id| {
3273            peer.send(
3274                peer_id.into(),
3275                proto::UpdateChannels {
3276                    channel_participants: vec![proto::ChannelParticipants {
3277                        channel_id: channel_id.to_proto(),
3278                        participant_user_ids: participants.clone(),
3279                    }],
3280                    ..Default::default()
3281                },
3282            )
3283        },
3284    );
3285}
3286
3287async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3288    let db = session.db().await;
3289
3290    let contacts = db.get_contacts(user_id).await?;
3291    let busy = db.is_user_busy(user_id).await?;
3292
3293    let pool = session.connection_pool().await;
3294    let updated_contact = contact_for_user(user_id, busy, &pool);
3295    for contact in contacts {
3296        if let db::Contact::Accepted {
3297            user_id: contact_user_id,
3298            ..
3299        } = contact
3300        {
3301            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3302                session
3303                    .peer
3304                    .send(
3305                        contact_conn_id,
3306                        proto::UpdateContacts {
3307                            contacts: vec![updated_contact.clone()],
3308                            remove_contacts: Default::default(),
3309                            incoming_requests: Default::default(),
3310                            remove_incoming_requests: Default::default(),
3311                            outgoing_requests: Default::default(),
3312                            remove_outgoing_requests: Default::default(),
3313                        },
3314                    )
3315                    .trace_err();
3316            }
3317        }
3318    }
3319    Ok(())
3320}
3321
/// Removes the session's connection from the room it is in, if any, and broadcasts the
/// consequences.
///
/// When the database reports that a room was left, this:
/// - notifies the other connections of every project the session left (`project_left`),
/// - broadcasts the updated room to its participants,
/// - broadcasts the channel's participant list when the room belonged to a channel,
/// - sends `CallCanceled` to every connection of each user whose call was canceled,
/// - refreshes the contact entry of this user and of each of those users,
/// - when a LiveKit client is configured, removes the user from the LiveKit room. It also
///   deletes that room if the database reports the room as deleted.
///
/// Send and LiveKit failures are logged via `trace_err`. Returns `Ok(())` without doing
/// anything when the connection was not in a room.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Each binding below is assigned exactly once inside the `if let`. The `else` branch
    // returns early, so all of them are initialized before they are used later on.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // The value returned by `session.db().await` is a temporary of the `if let`
    // scrutinee, so it stays alive for the whole `if let` body.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before `room`
        // itself is taken. The `room` broadcast by `room_updated` below therefore carries a
        // default (empty) `live_kit_room`. Confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool handle is scoped to this block so it is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3398
3399async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3400    let left_channel_buffers = session
3401        .db()
3402        .await
3403        .leave_channel_buffers(session.connection_id)
3404        .await?;
3405
3406    for left_buffer in left_channel_buffers {
3407        channel_buffer_updated(
3408            session.connection_id,
3409            left_buffer.connections,
3410            &proto::UpdateChannelBufferCollaborators {
3411                channel_id: left_buffer.channel_id.to_proto(),
3412                collaborators: left_buffer.collaborators,
3413            },
3414            &session.peer,
3415        );
3416    }
3417
3418    Ok(())
3419}
3420
3421fn project_left(project: &db::LeftProject, session: &Session) {
3422    for connection_id in &project.connection_ids {
3423        if project.host_user_id == session.user_id {
3424            session
3425                .peer
3426                .send(
3427                    *connection_id,
3428                    proto::UnshareProject {
3429                        project_id: project.id.to_proto(),
3430                    },
3431                )
3432                .trace_err();
3433        } else {
3434            session
3435                .peer
3436                .send(
3437                    *connection_id,
3438                    proto::RemoveProjectCollaborator {
3439                        project_id: project.id.to_proto(),
3440                        peer_id: Some(session.connection_id.into()),
3441                    },
3442                )
3443                .trace_err();
3444        }
3445    }
3446}
3447
3448pub trait ResultExt {
3449    type Ok;
3450
3451    fn trace_err(self) -> Option<Self::Ok>;
3452}
3453
3454impl<T, E> ResultExt for Result<T, E>
3455where
3456    E: std::fmt::Debug,
3457{
3458    type Ok = T;
3459
3460    fn trace_err(self) -> Option<T> {
3461        match self {
3462            Ok(value) => Some(value),
3463            Err(error) => {
3464                tracing::error!("{:?}", error);
3465                None
3466            }
3467        }
3468    }
3469}