rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69
  70pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  71pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  72
  73const MESSAGE_COUNT_PER_PAGE: usize = 100;
  74const MAX_MESSAGE_LEN: usize = 1024;
  75const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
  77lazy_static! {
  78    static ref METRIC_CONNECTIONS: IntGauge =
  79        register_int_gauge!("connections", "number of connections").unwrap();
  80    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  81        "shared_projects",
  82        "number of open projects with one or more guests"
  83    )
  84    .unwrap();
  85}
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone)]
 105struct Session {
 106    zed_environment: Arc<str>,
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    _executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
 188    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 189        let mut server = Self {
 190            id: parking_lot::Mutex::new(id),
 191            peer: Peer::new(id.0 as u32),
 192            app_state,
 193            executor,
 194            connection_pool: Default::default(),
 195            handlers: Default::default(),
 196            teardown: watch::channel(()).0,
 197        };
 198
 199        server
 200            .add_request_handler(ping)
 201            .add_request_handler(create_room)
 202            .add_request_handler(join_room)
 203            .add_request_handler(rejoin_room)
 204            .add_request_handler(leave_room)
 205            .add_request_handler(set_room_participant_role)
 206            .add_request_handler(call)
 207            .add_request_handler(cancel_call)
 208            .add_message_handler(decline_call)
 209            .add_request_handler(update_participant_location)
 210            .add_request_handler(share_project)
 211            .add_message_handler(unshare_project)
 212            .add_request_handler(join_project)
 213            .add_message_handler(leave_project)
 214            .add_request_handler(update_project)
 215            .add_request_handler(update_worktree)
 216            .add_message_handler(start_language_server)
 217            .add_message_handler(update_language_server)
 218            .add_message_handler(update_diagnostic_summary)
 219            .add_message_handler(update_worktree_settings)
 220            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 221            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 222            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 223            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 224            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 225            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 226            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 227            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 228            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 229            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 230            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 231            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 232            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 233            .add_request_handler(
 234                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 235            )
 236            .add_request_handler(
 237                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 238            )
 239            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 240            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 241            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 242            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 243            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 244            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 245            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 246            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 247            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 248            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 249            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 250            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 251            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 252            .add_message_handler(create_buffer_for_peer)
 253            .add_request_handler(update_buffer)
 254            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 255            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 256            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 257            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 258            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 259            .add_request_handler(get_users)
 260            .add_request_handler(fuzzy_search_users)
 261            .add_request_handler(request_contact)
 262            .add_request_handler(remove_contact)
 263            .add_request_handler(respond_to_contact_request)
 264            .add_request_handler(create_channel)
 265            .add_request_handler(delete_channel)
 266            .add_request_handler(invite_channel_member)
 267            .add_request_handler(remove_channel_member)
 268            .add_request_handler(set_channel_member_role)
 269            .add_request_handler(set_channel_visibility)
 270            .add_request_handler(rename_channel)
 271            .add_request_handler(join_channel_buffer)
 272            .add_request_handler(leave_channel_buffer)
 273            .add_message_handler(update_channel_buffer)
 274            .add_request_handler(rejoin_channel_buffers)
 275            .add_request_handler(get_channel_members)
 276            .add_request_handler(respond_to_channel_invite)
 277            .add_request_handler(join_channel)
 278            .add_request_handler(join_channel_chat)
 279            .add_message_handler(leave_channel_chat)
 280            .add_request_handler(send_channel_message)
 281            .add_request_handler(remove_channel_message)
 282            .add_request_handler(get_channel_messages)
 283            .add_request_handler(get_channel_messages_by_id)
 284            .add_request_handler(get_notifications)
 285            .add_request_handler(mark_notification_as_read)
 286            .add_request_handler(move_channel)
 287            .add_request_handler(follow)
 288            .add_message_handler(unfollow)
 289            .add_message_handler(update_followers)
 290            .add_request_handler(get_private_user_info)
 291            .add_message_handler(acknowledge_channel_message)
 292            .add_message_handler(acknowledge_buffer_version);
 293
 294        Arc::new(server)
 295    }
 296
    /// Schedules a one-off background cleanup of resources the database reports as stale for
    /// this environment and server id, then returns immediately with `Ok(())`.
    ///
    /// NOTE(review): "stale" is defined by `stale_server_resource_ids`; presumably these are
    /// rows left behind by servers that are no longer running. Confirm there.
    ///
    /// After `CLEANUP_TIMEOUT` elapses, the spawned task does the following:
    ///
    /// - For each stale channel buffer, clears stale collaborators and sends
    ///   `UpdateChannelBufferCollaborators` to each connection still in that buffer.
    /// - For each stale room, clears stale participants and calls `room_updated`. It also calls
    ///   `channel_updated` if the room belongs to a channel. It then sends `CallCanceled` to
    ///   every connection of users whose calls were canceled. Next, it pushes each affected
    ///   user's refreshed contact entry to their accepted contacts. Finally, it deletes the
    ///   LiveKit room once no participants remain.
    /// - Calls `delete_stale_servers`, whether or not the lookup above succeeded.
    ///
    /// Every fallible step goes through `trace_err`, which yields `None` on failure. A failed
    /// room or buffer is therefore skipped rather than aborting the whole cleanup. Errors never
    /// reach the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Awaited inside the spawned task before any cleanup work begins.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        // Detached: `start` returns right away and the cleanup runs in the background.
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to the connections that remain in it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below. If that database call fails
                        // they stay empty/false, and nothing further is sent for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                // The pool guard is a temporary that lives only for this call.
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both stale participants and users whose calls were canceled need
                            // their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            // Taken only after `room_updated` / `channel_updated` have used the room.
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each user whose call was canceled. The pool
                        // lock is confined to this synchronous block.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to the connections of
                        // their accepted contacts. The DB calls are awaited before the pool is
                        // locked, so the lock is never held across an `.await`.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when LiveKit is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 448
 449    pub fn teardown(&self) {
 450        self.peer.teardown();
 451        self.connection_pool.lock().reset();
 452        let _ = self.teardown.send(());
 453    }
 454
 455    #[cfg(test)]
 456    pub fn reset(&self, id: ServerId) {
 457        self.teardown();
 458        *self.id.lock() = id;
 459        self.peer.reset(id.0 as u32);
 460    }
 461
 462    #[cfg(test)]
 463    pub fn id(&self) -> ServerId {
 464        *self.id.lock()
 465    }
 466
 467    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 468    where
 469        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 470        Fut: 'static + Send + Future<Output = Result<()>>,
 471        M: EnvelopedMessage,
 472    {
 473        let prev_handler = self.handlers.insert(
 474            TypeId::of::<M>(),
 475            Box::new(move |envelope, session| {
 476                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 477                let span = info_span!(
 478                    "handle message",
 479                    payload_type = envelope.payload_type_name()
 480                );
 481                span.in_scope(|| {
 482                    tracing::info!(
 483                        payload_type = envelope.payload_type_name(),
 484                        "message received"
 485                    );
 486                });
 487                let start_time = Instant::now();
 488                let future = (handler)(*envelope, session);
 489                async move {
 490                    let result = future.await;
 491                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 492                    match result {
 493                        Err(error) => {
 494                            tracing::error!(%error, ?duration_ms, "error handling message")
 495                        }
 496                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 497                    }
 498                }
 499                .instrument(span)
 500                .boxed()
 501            }),
 502        );
 503        if prev_handler.is_some() {
 504            panic!("registered a handler for the same message twice");
 505        }
 506        self
 507    }
 508
 509    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 512        Fut: 'static + Send + Future<Output = Result<()>>,
 513        M: EnvelopedMessage,
 514    {
 515        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 516        self
 517    }
 518
 519    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 520    where
 521        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 522        Fut: Send + Future<Output = Result<()>>,
 523        M: RequestMessage,
 524    {
 525        let handler = Arc::new(handler);
 526        self.add_handler(move |envelope, session| {
 527            let receipt = envelope.receipt();
 528            let handler = handler.clone();
 529            async move {
 530                let peer = session.peer.clone();
 531                let responded = Arc::new(AtomicBool::default());
 532                let response = Response {
 533                    peer: peer.clone(),
 534                    responded: responded.clone(),
 535                    receipt,
 536                };
 537                match (handler)(envelope.payload, response, session).await {
 538                    Ok(()) => {
 539                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 540                            Ok(())
 541                        } else {
 542                            Err(anyhow!("handler did not send a response"))?
 543                        }
 544                    }
 545                    Err(error) => {
 546                        peer.respond_with_error(
 547                            receipt,
 548                            proto::Error {
 549                                message: error.to_string(),
 550                            },
 551                        )?;
 552                        Err(error)
 553                    }
 554                }
 555            }
 556        })
 557    }
 558
    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// When polled, the future goes through these steps:
    /// 1. Registers the connection with the `Peer` and sends `Hello` carrying the new connection
    ///    id. The id is also delivered through `send_connection_id`, if one was provided.
    /// 2. If the user has never connected before, sends `ShowContacts` and records
    ///    `connected_once` in the database.
    /// 3. Loads contacts, channels, and channel invites concurrently. It then adds the connection
    ///    to the connection pool and sends the initial contacts and channels updates, followed by
    ///    any pending incoming call.
    /// 4. Builds the connection's `Session` and runs `update_user_contacts`.
    /// 5. Dispatches incoming messages to the registered handlers. This continues until the I/O
    ///    future finishes, the incoming stream ends, or the server is torn down.
    /// 6. Calls `connection_lost` to sign the user out, except after a teardown, where it
    ///    returns `Ok(())` immediately.
    ///
    /// At most 256 handlers run concurrently per connection. A semaphore permit is acquired
    /// before each message is read and released when that message's handler completes.
    ///
    /// NOTE(review): every `?` between `add_connection` and the message loop returns early
    /// without reaching `connection_lost`. A failure there (e.g. a DB error after
    /// `pool.add_connection`) may leave this connection registered in the pool and peer.
    /// Confirm whether something else cleans that up.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed eagerly, outside the `async` block. Presumably this ensures a teardown sent
        // before the future is first polled is still seen by `teardown.changed()`.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // First-ever connection for this user: ask the client to show contacts and record
            // that the user has now connected.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial snapshots under the pool lock.
            // There is no `.await` inside this block.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps this connection's in-flight handlers (foreground and background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The permit is acquired *before* reading, so while 256 handlers are in flight no
                // further message is read. If another branch wins below, this future (and any
                // permit it holds) is dropped and recreated on the next iteration.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written:
                // 1. teardown, which returns without calling `connection_lost`;
                // 2. I/O completion;
                // 3. completion of a foreground handler;
                // 4. reading a new message.
                // Background messages are spawned detached; foreground ones are pushed into
                // `foreground_message_handlers`, which this loop keeps polling.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight; detached
            // background handlers are not affected.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 694
 695    pub async fn invite_code_redeemed(
 696        self: &Arc<Self>,
 697        inviter_id: UserId,
 698        invitee_id: UserId,
 699    ) -> Result<()> {
 700        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 701            if let Some(code) = &user.invite_code {
 702                let pool = self.connection_pool.lock();
 703                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 704                for connection_id in pool.user_connection_ids(inviter_id) {
 705                    self.peer.send(
 706                        connection_id,
 707                        proto::UpdateContacts {
 708                            contacts: vec![invitee_contact.clone()],
 709                            ..Default::default()
 710                        },
 711                    )?;
 712                    self.peer.send(
 713                        connection_id,
 714                        proto::UpdateInviteInfo {
 715                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 716                            count: user.invite_count as u32,
 717                        },
 718                    )?;
 719                }
 720            }
 721        }
 722        Ok(())
 723    }
 724
 725    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 726        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 727            if let Some(invite_code) = &user.invite_code {
 728                let pool = self.connection_pool.lock();
 729                for connection_id in pool.user_connection_ids(user_id) {
 730                    self.peer.send(
 731                        connection_id,
 732                        proto::UpdateInviteInfo {
 733                            url: format!(
 734                                "{}{}",
 735                                self.app_state.config.invite_link_prefix, invite_code
 736                            ),
 737                            count: user.invite_count as u32,
 738                        },
 739                    )?;
 740                }
 741            }
 742        }
 743        Ok(())
 744    }
 745
 746    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 747        ServerSnapshot {
 748            connection_pool: ConnectionPoolGuard {
 749                guard: self.connection_pool.lock(),
 750                _not_send: PhantomData,
 751            },
 752            peer: &self.peer,
 753        }
 754    }
 755}
 756
 757impl<'a> Deref for ConnectionPoolGuard<'a> {
 758    type Target = ConnectionPool;
 759
 760    fn deref(&self) -> &Self::Target {
 761        &*self.guard
 762    }
 763}
 764
 765impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 766    fn deref_mut(&mut self) -> &mut Self::Target {
 767        &mut *self.guard
 768    }
 769}
 770
 771impl<'a> Drop for ConnectionPoolGuard<'a> {
 772    fn drop(&mut self) {
 773        #[cfg(test)]
 774        self.check_invariants();
 775    }
 776}
 777
 778fn broadcast<F>(
 779    sender_id: Option<ConnectionId>,
 780    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 781    mut f: F,
 782) where
 783    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 784{
 785    for receiver_id in receiver_ids {
 786        if Some(receiver_id) != sender_id {
 787            if let Err(error) = f(receiver_id) {
 788                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 789            }
 790        }
 791    }
 792}
 793
lazy_static! {
    /// Name of the HTTP header carrying the client's RPC protocol version;
    /// decoded through `ProtocolVersion`'s `Header` impl.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 797
/// Typed `x-zed-protocol-version` header: the RPC protocol version reported by
/// a connecting client, compared against `rpc::PROTOCOL_VERSION` before the
/// websocket upgrade.
pub struct ProtocolVersion(u32);
 799
 800impl Header for ProtocolVersion {
 801    fn name() -> &'static HeaderName {
 802        &ZED_PROTOCOL_VERSION
 803    }
 804
 805    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 806    where
 807        Self: Sized,
 808        I: Iterator<Item = &'i axum::http::HeaderValue>,
 809    {
 810        let version = values
 811            .next()
 812            .ok_or_else(axum::headers::Error::invalid)?
 813            .to_str()
 814            .map_err(|_| axum::headers::Error::invalid())?
 815            .parse()
 816            .map_err(|_| axum::headers::Error::invalid())?;
 817        Ok(Self(version))
 818    }
 819
 820    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 821        values.extend([self.0.to_string().parse().unwrap()]);
 822    }
 823}
 824
/// Builds the RPC server's router: the websocket endpoint at `/rpc` and the
/// Prometheus endpoint at `/metrics`.
///
/// Layer order is significant: axum's `Router::layer` only wraps routes that
/// were registered before it is called. So `/rpc` is wrapped by the
/// `ServiceBuilder` stack (the `AppState` extension plus
/// `auth::validate_header`), while `/metrics`, added afterwards, is not. The
/// final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Within a `ServiceBuilder`, earlier layers are outermost, so the
        // `AppState` extension is inserted before `validate_header` runs.
        // NOTE(review): `validate_header` presumably reads that app state and
        // supplies the `Extension<User>` that `handle_websocket_request`
        // extracts; confirm in `auth`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 836
 837pub async fn handle_websocket_request(
 838    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 839    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 840    Extension(server): Extension<Arc<Server>>,
 841    Extension(user): Extension<User>,
 842    ws: WebSocketUpgrade,
 843) -> axum::response::Response {
 844    if protocol_version != rpc::PROTOCOL_VERSION {
 845        return (
 846            StatusCode::UPGRADE_REQUIRED,
 847            "client must be upgraded".to_string(),
 848        )
 849            .into_response();
 850    }
 851    let socket_address = socket_address.to_string();
 852    ws.on_upgrade(move |socket| {
 853        use util::ResultExt;
 854        let socket = socket
 855            .map_ok(to_tungstenite_message)
 856            .err_into()
 857            .with(|message| async move { Ok(to_axum_message(message)) });
 858        let connection = Connection::new(Box::pin(socket));
 859        async move {
 860            server
 861                .handle_connection(connection, socket_address, user, None, Executor::Production)
 862                .await
 863                .log_err();
 864        }
 865    })
 866}
 867
 868pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 869    let connections = server
 870        .connection_pool
 871        .lock()
 872        .connections()
 873        .filter(|connection| !connection.admin)
 874        .count();
 875
 876    METRIC_CONNECTIONS.set(connections as _);
 877
 878    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 879    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 880
 881    let encoder = prometheus::TextEncoder::new();
 882    let metric_families = prometheus::gather();
 883    let encoded_metrics = encoder
 884        .encode_to_string(&metric_families)
 885        .map_err(|err| anyhow!("{}", err))?;
 886    Ok(encoded_metrics)
 887}
 888
/// Cleans up after `session`'s connection has been lost.
///
/// Immediately, it disconnects the connection from the peer and removes it
/// from the connection pool; an error from the removal is returned. It also
/// tells the database the connection was lost, but errors there are only
/// traced.
///
/// It then waits for whichever of these completes first. `select_biased!`
/// checks the timeout arm first when both are ready.
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers, declines any pending call if the user has no other online
///   connection, and the user's contacts are updated.
/// - `teardown` changes: the deferred cleanup above is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Defer releasing room/buffer/call state; presumably this gives the client
    // a window to reconnect before its resources are torn down.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls once no other connection for this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown fired first: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 934
 935/// Acknowledges a ping from a client, used to keep the connection alive.
 936async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 937    response.send(proto::Ack {})?;
 938    Ok(())
 939}
 940
 941/// Create a new room for calling (outside of channels)
 942async fn create_room(
 943    _request: proto::CreateRoom,
 944    response: Response<proto::CreateRoom>,
 945    session: Session,
 946) -> Result<()> {
 947    let live_kit_room = nanoid::nanoid!(30);
 948
 949    let live_kit_connection_info = {
 950        let live_kit_room = live_kit_room.clone();
 951        let live_kit = session.live_kit_client.as_ref();
 952
 953        util::async_maybe!({
 954            let live_kit = live_kit?;
 955
 956            let token = live_kit
 957                .room_token(&live_kit_room, &session.user_id.to_string())
 958                .trace_err()?;
 959
 960            Some(proto::LiveKitConnectionInfo {
 961                server_url: live_kit.url().into(),
 962                token,
 963                can_publish: true,
 964            })
 965        })
 966    }
 967    .await;
 968
 969    let room = session
 970        .db()
 971        .await
 972        .create_room(
 973            session.user_id,
 974            session.connection_id,
 975            &live_kit_room,
 976            &session.zed_environment,
 977        )
 978        .await?;
 979
 980    response.send(proto::CreateRoomResponse {
 981        room: Some(room.clone()),
 982        live_kit_connection_info,
 983    })?;
 984
 985    update_user_contacts(session.user_id, &session).await?;
 986    Ok(())
 987}
 988
 989/// Join a room from an invitation. Equivalent to joining a channel if there is one.
 990async fn join_room(
 991    request: proto::JoinRoom,
 992    response: Response<proto::JoinRoom>,
 993    session: Session,
 994) -> Result<()> {
 995    let room_id = RoomId::from_proto(request.id);
 996
 997    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 998
 999    if let Some(channel_id) = channel_id {
1000        return join_channel_internal(channel_id, Box::new(response), session).await;
1001    }
1002
1003    let joined_room = {
1004        let room = session
1005            .db()
1006            .await
1007            .join_room(
1008                room_id,
1009                session.user_id,
1010                session.connection_id,
1011                session.zed_environment.as_ref(),
1012            )
1013            .await?;
1014        room_updated(&room.room, &session.peer);
1015        room.into_inner()
1016    };
1017
1018    for connection_id in session
1019        .connection_pool()
1020        .await
1021        .user_connection_ids(session.user_id)
1022    {
1023        session
1024            .peer
1025            .send(
1026                connection_id,
1027                proto::CallCanceled {
1028                    room_id: room_id.to_proto(),
1029                },
1030            )
1031            .trace_err();
1032    }
1033
1034    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1035        if let Some(token) = live_kit
1036            .room_token(
1037                &joined_room.room.live_kit_room,
1038                &session.user_id.to_string(),
1039            )
1040            .trace_err()
1041        {
1042            Some(proto::LiveKitConnectionInfo {
1043                server_url: live_kit.url().into(),
1044                token,
1045                can_publish: true,
1046            })
1047        } else {
1048            None
1049        }
1050    } else {
1051        None
1052    };
1053
1054    response.send(proto::JoinRoomResponse {
1055        room: Some(joined_room.room),
1056        channel_id: None,
1057        live_kit_connection_info,
1058    })?;
1059
1060    update_user_contacts(session.user_id, &session).await?;
1061    Ok(())
1062}
1063
/// Reconnects this session to a room it was in before a connection error.
///
/// It asks the database to rejoin the room under the new `connection_id`. The
/// response then describes the room, the projects this user is resharing
/// (presumably ones they host), and the projects they are rejoining
/// (presumably as a guest).
///
/// After that, collaborators are told that this user's peer id changed from
/// `old_connection_id` to this connection. The rejoined projects' worktrees,
/// diagnostic summaries, settings files and language servers are streamed
/// back to this connection. Finally, the room's channel (if any) and the
/// user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below, so the database guard for the rejoined
    // room goes out of scope before the channel/contact updates that follow.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator that this user's peer id
        // changed, then forward the project's worktree metadata to every
        // collaborator except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: tell each collaborator that this user's peer id changed.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's state back to this connection.
        // `mem::take` moves the worktrees out so their contents can be sent
        // without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks under test, presumably to exercise the splitting path.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each of the project's language servers with a
            // `DiskBasedDiagnosticsUpdated` status update.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Consume the guard, keeping only the data needed after this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1255
1256/// leave room disonnects from the room.
1257async fn leave_room(
1258    _: proto::LeaveRoom,
1259    response: Response<proto::LeaveRoom>,
1260    session: Session,
1261) -> Result<()> {
1262    leave_room_for_session(&session).await?;
1263    response.send(proto::Ack {})?;
1264    Ok(())
1265}
1266
1267/// Update the permissions of someone else in the room.
1268async fn set_room_participant_role(
1269    request: proto::SetRoomParticipantRole,
1270    response: Response<proto::SetRoomParticipantRole>,
1271    session: Session,
1272) -> Result<()> {
1273    let (live_kit_room, can_publish) = {
1274        let room = session
1275            .db()
1276            .await
1277            .set_room_participant_role(
1278                session.user_id,
1279                RoomId::from_proto(request.room_id),
1280                UserId::from_proto(request.user_id),
1281                ChannelRole::from(request.role()),
1282            )
1283            .await?;
1284
1285        let live_kit_room = room.live_kit_room.clone();
1286        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1287        room_updated(&room, &session.peer);
1288        (live_kit_room, can_publish)
1289    };
1290
1291    if let Some(live_kit) = session.live_kit_client.as_ref() {
1292        live_kit
1293            .update_participant(
1294                live_kit_room.clone(),
1295                request.user_id.to_string(),
1296                live_kit_server::proto::ParticipantPermission {
1297                    can_subscribe: true,
1298                    can_publish,
1299                    can_publish_data: can_publish,
1300                    hidden: false,
1301                    recorder: false,
1302                },
1303            )
1304            .await
1305            .trace_err();
1306    }
1307
1308    response.send(proto::Ack {})?;
1309    Ok(())
1310}
1311
1312/// Call someone else into the current room
1313async fn call(
1314    request: proto::Call,
1315    response: Response<proto::Call>,
1316    session: Session,
1317) -> Result<()> {
1318    let room_id = RoomId::from_proto(request.room_id);
1319    let calling_user_id = session.user_id;
1320    let calling_connection_id = session.connection_id;
1321    let called_user_id = UserId::from_proto(request.called_user_id);
1322    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1323    if !session
1324        .db()
1325        .await
1326        .has_contact(calling_user_id, called_user_id)
1327        .await?
1328    {
1329        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1330    }
1331
1332    let incoming_call = {
1333        let (room, incoming_call) = &mut *session
1334            .db()
1335            .await
1336            .call(
1337                room_id,
1338                calling_user_id,
1339                calling_connection_id,
1340                called_user_id,
1341                initial_project_id,
1342            )
1343            .await?;
1344        room_updated(&room, &session.peer);
1345        mem::take(incoming_call)
1346    };
1347    update_user_contacts(called_user_id, &session).await?;
1348
1349    let mut calls = session
1350        .connection_pool()
1351        .await
1352        .user_connection_ids(called_user_id)
1353        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1354        .collect::<FuturesUnordered<_>>();
1355
1356    while let Some(call_response) = calls.next().await {
1357        match call_response.as_ref() {
1358            Ok(_) => {
1359                response.send(proto::Ack {})?;
1360                return Ok(());
1361            }
1362            Err(_) => {
1363                call_response.trace_err();
1364            }
1365        }
1366    }
1367
1368    {
1369        let room = session
1370            .db()
1371            .await
1372            .call_failed(room_id, called_user_id)
1373            .await?;
1374        room_updated(&room, &session.peer);
1375    }
1376    update_user_contacts(called_user_id, &session).await?;
1377
1378    Err(anyhow!("failed to ring user"))?
1379}
1380
1381/// Cancel an outgoing call.
1382async fn cancel_call(
1383    request: proto::CancelCall,
1384    response: Response<proto::CancelCall>,
1385    session: Session,
1386) -> Result<()> {
1387    let called_user_id = UserId::from_proto(request.called_user_id);
1388    let room_id = RoomId::from_proto(request.room_id);
1389    {
1390        let room = session
1391            .db()
1392            .await
1393            .cancel_call(room_id, session.connection_id, called_user_id)
1394            .await?;
1395        room_updated(&room, &session.peer);
1396    }
1397
1398    for connection_id in session
1399        .connection_pool()
1400        .await
1401        .user_connection_ids(called_user_id)
1402    {
1403        session
1404            .peer
1405            .send(
1406                connection_id,
1407                proto::CallCanceled {
1408                    room_id: room_id.to_proto(),
1409                },
1410            )
1411            .trace_err();
1412    }
1413    response.send(proto::Ack {})?;
1414
1415    update_user_contacts(called_user_id, &session).await?;
1416    Ok(())
1417}
1418
1419/// Decline an incoming call.
1420async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1421    let room_id = RoomId::from_proto(message.room_id);
1422    {
1423        let room = session
1424            .db()
1425            .await
1426            .decline_call(Some(room_id), session.user_id)
1427            .await?
1428            .ok_or_else(|| anyhow!("failed to decline call"))?;
1429        room_updated(&room, &session.peer);
1430    }
1431
1432    for connection_id in session
1433        .connection_pool()
1434        .await
1435        .user_connection_ids(session.user_id)
1436    {
1437        session
1438            .peer
1439            .send(
1440                connection_id,
1441                proto::CallCanceled {
1442                    room_id: room_id.to_proto(),
1443                },
1444            )
1445            .trace_err();
1446    }
1447    update_user_contacts(session.user_id, &session).await?;
1448    Ok(())
1449}
1450
1451/// Update other participants in the room with your current location.
1452async fn update_participant_location(
1453    request: proto::UpdateParticipantLocation,
1454    response: Response<proto::UpdateParticipantLocation>,
1455    session: Session,
1456) -> Result<()> {
1457    let room_id = RoomId::from_proto(request.room_id);
1458    let location = request
1459        .location
1460        .ok_or_else(|| anyhow!("invalid location"))?;
1461
1462    let db = session.db().await;
1463    let room = db
1464        .update_room_participant_location(room_id, session.connection_id, location)
1465        .await?;
1466
1467    room_updated(&room, &session.peer);
1468    response.send(proto::Ack {})?;
1469    Ok(())
1470}
1471
1472/// Share a project into the room.
1473async fn share_project(
1474    request: proto::ShareProject,
1475    response: Response<proto::ShareProject>,
1476    session: Session,
1477) -> Result<()> {
1478    let (project_id, room) = &*session
1479        .db()
1480        .await
1481        .share_project(
1482            RoomId::from_proto(request.room_id),
1483            session.connection_id,
1484            &request.worktrees,
1485        )
1486        .await?;
1487    response.send(proto::ShareProjectResponse {
1488        project_id: project_id.to_proto(),
1489    })?;
1490    room_updated(&room, &session.peer);
1491
1492    Ok(())
1493}
1494
1495/// Unshare a project from the room.
1496async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1497    let project_id = ProjectId::from_proto(message.project_id);
1498
1499    let (room, guest_connection_ids) = &*session
1500        .db()
1501        .await
1502        .unshare_project(project_id, session.connection_id)
1503        .await?;
1504
1505    broadcast(
1506        Some(session.connection_id),
1507        guest_connection_ids.iter().copied(),
1508        |conn_id| session.peer.send(conn_id, message.clone()),
1509    );
1510    room_updated(&room, &session.peer);
1511
1512    Ok(())
1513}
1514
/// Join someone else's shared project.
1516async fn join_project(
1517    request: proto::JoinProject,
1518    response: Response<proto::JoinProject>,
1519    session: Session,
1520) -> Result<()> {
1521    let project_id = ProjectId::from_proto(request.project_id);
1522    let guest_user_id = session.user_id;
1523
1524    tracing::info!(%project_id, "join project");
1525
1526    let (project, replica_id) = &mut *session
1527        .db()
1528        .await
1529        .join_project(project_id, session.connection_id)
1530        .await?;
1531
1532    let collaborators = project
1533        .collaborators
1534        .iter()
1535        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1536        .map(|collaborator| collaborator.to_proto())
1537        .collect::<Vec<_>>();
1538
1539    let worktrees = project
1540        .worktrees
1541        .iter()
1542        .map(|(id, worktree)| proto::WorktreeMetadata {
1543            id: *id,
1544            root_name: worktree.root_name.clone(),
1545            visible: worktree.visible,
1546            abs_path: worktree.abs_path.clone(),
1547        })
1548        .collect::<Vec<_>>();
1549
1550    for collaborator in &collaborators {
1551        session
1552            .peer
1553            .send(
1554                collaborator.peer_id.unwrap().into(),
1555                proto::AddProjectCollaborator {
1556                    project_id: project_id.to_proto(),
1557                    collaborator: Some(proto::Collaborator {
1558                        peer_id: Some(session.connection_id.into()),
1559                        replica_id: replica_id.0 as u32,
1560                        user_id: guest_user_id.to_proto(),
1561                    }),
1562                },
1563            )
1564            .trace_err();
1565    }
1566
1567    // First, we send the metadata associated with each worktree.
1568    response.send(proto::JoinProjectResponse {
1569        worktrees: worktrees.clone(),
1570        replica_id: replica_id.0 as u32,
1571        collaborators: collaborators.clone(),
1572        language_servers: project.language_servers.clone(),
1573    })?;
1574
1575    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1576        #[cfg(any(test, feature = "test-support"))]
1577        const MAX_CHUNK_SIZE: usize = 2;
1578        #[cfg(not(any(test, feature = "test-support")))]
1579        const MAX_CHUNK_SIZE: usize = 256;
1580
1581        // Stream this worktree's entries.
1582        let message = proto::UpdateWorktree {
1583            project_id: project_id.to_proto(),
1584            worktree_id,
1585            abs_path: worktree.abs_path.clone(),
1586            root_name: worktree.root_name,
1587            updated_entries: worktree.entries,
1588            removed_entries: Default::default(),
1589            scan_id: worktree.scan_id,
1590            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1591            updated_repositories: worktree.repository_entries.into_values().collect(),
1592            removed_repositories: Default::default(),
1593        };
1594        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1595            session.peer.send(session.connection_id, update.clone())?;
1596        }
1597
1598        // Stream this worktree's diagnostics.
1599        for summary in worktree.diagnostic_summaries {
1600            session.peer.send(
1601                session.connection_id,
1602                proto::UpdateDiagnosticSummary {
1603                    project_id: project_id.to_proto(),
1604                    worktree_id: worktree.id,
1605                    summary: Some(summary),
1606                },
1607            )?;
1608        }
1609
1610        for settings_file in worktree.settings_files {
1611            session.peer.send(
1612                session.connection_id,
1613                proto::UpdateWorktreeSettings {
1614                    project_id: project_id.to_proto(),
1615                    worktree_id: worktree.id,
1616                    path: settings_file.path,
1617                    content: Some(settings_file.content),
1618                },
1619            )?;
1620        }
1621    }
1622
1623    for language_server in &project.language_servers {
1624        session.peer.send(
1625            session.connection_id,
1626            proto::UpdateLanguageServer {
1627                project_id: project_id.to_proto(),
1628                language_server_id: language_server.id,
1629                variant: Some(
1630                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1631                        proto::LspDiskBasedDiagnosticsUpdated {},
1632                    ),
1633                ),
1634            },
1635        )?;
1636    }
1637
1638    Ok(())
1639}
1640
1641/// Leave someone elses shared project.
1642async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1643    let sender_id = session.connection_id;
1644    let project_id = ProjectId::from_proto(request.project_id);
1645
1646    let (room, project) = &*session
1647        .db()
1648        .await
1649        .leave_project(project_id, sender_id)
1650        .await?;
1651    tracing::info!(
1652        %project_id,
1653        host_user_id = %project.host_user_id,
1654        host_connection_id = %project.host_connection_id,
1655        "leave project"
1656    );
1657
1658    project_left(&project, &session);
1659    room_updated(&room, &session.peer);
1660
1661    Ok(())
1662}
1663
1664/// Update other participants with changes to the project
1665async fn update_project(
1666    request: proto::UpdateProject,
1667    response: Response<proto::UpdateProject>,
1668    session: Session,
1669) -> Result<()> {
1670    let project_id = ProjectId::from_proto(request.project_id);
1671    let (room, guest_connection_ids) = &*session
1672        .db()
1673        .await
1674        .update_project(project_id, session.connection_id, &request.worktrees)
1675        .await?;
1676    broadcast(
1677        Some(session.connection_id),
1678        guest_connection_ids.iter().copied(),
1679        |connection_id| {
1680            session
1681                .peer
1682                .forward_send(session.connection_id, connection_id, request.clone())
1683        },
1684    );
1685    room_updated(&room, &session.peer);
1686    response.send(proto::Ack {})?;
1687
1688    Ok(())
1689}
1690
1691/// Update other participants with changes to the worktree
1692async fn update_worktree(
1693    request: proto::UpdateWorktree,
1694    response: Response<proto::UpdateWorktree>,
1695    session: Session,
1696) -> Result<()> {
1697    let guest_connection_ids = session
1698        .db()
1699        .await
1700        .update_worktree(&request, session.connection_id)
1701        .await?;
1702
1703    broadcast(
1704        Some(session.connection_id),
1705        guest_connection_ids.iter().copied(),
1706        |connection_id| {
1707            session
1708                .peer
1709                .forward_send(session.connection_id, connection_id, request.clone())
1710        },
1711    );
1712    response.send(proto::Ack {})?;
1713    Ok(())
1714}
1715
1716/// Update other participants with changes to the diagnostics
1717async fn update_diagnostic_summary(
1718    message: proto::UpdateDiagnosticSummary,
1719    session: Session,
1720) -> Result<()> {
1721    let guest_connection_ids = session
1722        .db()
1723        .await
1724        .update_diagnostic_summary(&message, session.connection_id)
1725        .await?;
1726
1727    broadcast(
1728        Some(session.connection_id),
1729        guest_connection_ids.iter().copied(),
1730        |connection_id| {
1731            session
1732                .peer
1733                .forward_send(session.connection_id, connection_id, message.clone())
1734        },
1735    );
1736
1737    Ok(())
1738}
1739
1740/// Update other participants with changes to the worktree settings
1741async fn update_worktree_settings(
1742    message: proto::UpdateWorktreeSettings,
1743    session: Session,
1744) -> Result<()> {
1745    let guest_connection_ids = session
1746        .db()
1747        .await
1748        .update_worktree_settings(&message, session.connection_id)
1749        .await?;
1750
1751    broadcast(
1752        Some(session.connection_id),
1753        guest_connection_ids.iter().copied(),
1754        |connection_id| {
1755            session
1756                .peer
1757                .forward_send(session.connection_id, connection_id, message.clone())
1758        },
1759    );
1760
1761    Ok(())
1762}
1763
1764/// Notify other participants that a  language server has started.
1765async fn start_language_server(
1766    request: proto::StartLanguageServer,
1767    session: Session,
1768) -> Result<()> {
1769    let guest_connection_ids = session
1770        .db()
1771        .await
1772        .start_language_server(&request, session.connection_id)
1773        .await?;
1774
1775    broadcast(
1776        Some(session.connection_id),
1777        guest_connection_ids.iter().copied(),
1778        |connection_id| {
1779            session
1780                .peer
1781                .forward_send(session.connection_id, connection_id, request.clone())
1782        },
1783    );
1784    Ok(())
1785}
1786
1787/// Notify other participants that a language server has changed.
1788async fn update_language_server(
1789    request: proto::UpdateLanguageServer,
1790    session: Session,
1791) -> Result<()> {
1792    let project_id = ProjectId::from_proto(request.project_id);
1793    let project_connection_ids = session
1794        .db()
1795        .await
1796        .project_connection_ids(project_id, session.connection_id)
1797        .await?;
1798    broadcast(
1799        Some(session.connection_id),
1800        project_connection_ids.iter().copied(),
1801        |connection_id| {
1802            session
1803                .peer
1804                .forward_send(session.connection_id, connection_id, request.clone())
1805        },
1806    );
1807    Ok(())
1808}
1809
1810/// forward a project request to the host. These requests should be read only
1811/// as guests are allowed to send them.
1812async fn forward_read_only_project_request<T>(
1813    request: T,
1814    response: Response<T>,
1815    session: Session,
1816) -> Result<()>
1817where
1818    T: EntityMessage + RequestMessage,
1819{
1820    let project_id = ProjectId::from_proto(request.remote_entity_id());
1821    let host_connection_id = session
1822        .db()
1823        .await
1824        .host_for_read_only_project_request(project_id, session.connection_id)
1825        .await?;
1826    let payload = session
1827        .peer
1828        .forward_request(session.connection_id, host_connection_id, request)
1829        .await?;
1830    response.send(payload)?;
1831    Ok(())
1832}
1833
1834/// forward a project request to the host. These requests are disallowed
1835/// for guests.
1836async fn forward_mutating_project_request<T>(
1837    request: T,
1838    response: Response<T>,
1839    session: Session,
1840) -> Result<()>
1841where
1842    T: EntityMessage + RequestMessage,
1843{
1844    let project_id = ProjectId::from_proto(request.remote_entity_id());
1845    let host_connection_id = session
1846        .db()
1847        .await
1848        .host_for_mutating_project_request(project_id, session.connection_id)
1849        .await?;
1850    let payload = session
1851        .peer
1852        .forward_request(session.connection_id, host_connection_id, request)
1853        .await?;
1854    response.send(payload)?;
1855    Ok(())
1856}
1857
1858/// Notify other participants that a new buffer has been created
1859async fn create_buffer_for_peer(
1860    request: proto::CreateBufferForPeer,
1861    session: Session,
1862) -> Result<()> {
1863    session
1864        .db()
1865        .await
1866        .check_user_is_project_host(
1867            ProjectId::from_proto(request.project_id),
1868            session.connection_id,
1869        )
1870        .await?;
1871    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1872    session
1873        .peer
1874        .forward_send(session.connection_id, peer_id.into(), request)?;
1875    Ok(())
1876}
1877
1878/// Notify other participants that a buffer has been updated. This is
1879/// allowed for guests as long as the update is limited to selections.
1880async fn update_buffer(
1881    request: proto::UpdateBuffer,
1882    response: Response<proto::UpdateBuffer>,
1883    session: Session,
1884) -> Result<()> {
1885    let project_id = ProjectId::from_proto(request.project_id);
1886    let mut guest_connection_ids;
1887    let mut host_connection_id = None;
1888
1889    let mut requires_write_permission = false;
1890
1891    for op in request.operations.iter() {
1892        match op.variant {
1893            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1894            Some(_) => requires_write_permission = true,
1895        }
1896    }
1897
1898    {
1899        let collaborators = session
1900            .db()
1901            .await
1902            .project_collaborators_for_buffer_update(
1903                project_id,
1904                session.connection_id,
1905                requires_write_permission,
1906            )
1907            .await?;
1908        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1909        for collaborator in collaborators.iter() {
1910            if collaborator.is_host {
1911                host_connection_id = Some(collaborator.connection_id);
1912            } else {
1913                guest_connection_ids.push(collaborator.connection_id);
1914            }
1915        }
1916    }
1917    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1918
1919    broadcast(
1920        Some(session.connection_id),
1921        guest_connection_ids,
1922        |connection_id| {
1923            session
1924                .peer
1925                .forward_send(session.connection_id, connection_id, request.clone())
1926        },
1927    );
1928    if host_connection_id != session.connection_id {
1929        session
1930            .peer
1931            .forward_request(session.connection_id, host_connection_id, request.clone())
1932            .await?;
1933    }
1934
1935    response.send(proto::Ack {})?;
1936    Ok(())
1937}
1938
1939/// Notify other participants that a project has been updated.
1940async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1941    request: T,
1942    session: Session,
1943) -> Result<()> {
1944    let project_id = ProjectId::from_proto(request.remote_entity_id());
1945    let project_connection_ids = session
1946        .db()
1947        .await
1948        .project_connection_ids(project_id, session.connection_id)
1949        .await?;
1950
1951    broadcast(
1952        Some(session.connection_id),
1953        project_connection_ids.iter().copied(),
1954        |connection_id| {
1955            session
1956                .peer
1957                .forward_send(session.connection_id, connection_id, request.clone())
1958        },
1959    );
1960    Ok(())
1961}
1962
1963/// Start following another user in a call.
1964async fn follow(
1965    request: proto::Follow,
1966    response: Response<proto::Follow>,
1967    session: Session,
1968) -> Result<()> {
1969    let room_id = RoomId::from_proto(request.room_id);
1970    let project_id = request.project_id.map(ProjectId::from_proto);
1971    let leader_id = request
1972        .leader_id
1973        .ok_or_else(|| anyhow!("invalid leader id"))?
1974        .into();
1975    let follower_id = session.connection_id;
1976
1977    session
1978        .db()
1979        .await
1980        .check_room_participants(room_id, leader_id, session.connection_id)
1981        .await?;
1982
1983    let response_payload = session
1984        .peer
1985        .forward_request(session.connection_id, leader_id, request)
1986        .await?;
1987    response.send(response_payload)?;
1988
1989    if let Some(project_id) = project_id {
1990        let room = session
1991            .db()
1992            .await
1993            .follow(room_id, project_id, leader_id, follower_id)
1994            .await?;
1995        room_updated(&room, &session.peer);
1996    }
1997
1998    Ok(())
1999}
2000
2001/// Stop following another user in a call.
2002async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2003    let room_id = RoomId::from_proto(request.room_id);
2004    let project_id = request.project_id.map(ProjectId::from_proto);
2005    let leader_id = request
2006        .leader_id
2007        .ok_or_else(|| anyhow!("invalid leader id"))?
2008        .into();
2009    let follower_id = session.connection_id;
2010
2011    session
2012        .db()
2013        .await
2014        .check_room_participants(room_id, leader_id, session.connection_id)
2015        .await?;
2016
2017    session
2018        .peer
2019        .forward_send(session.connection_id, leader_id, request)?;
2020
2021    if let Some(project_id) = project_id {
2022        let room = session
2023            .db()
2024            .await
2025            .unfollow(room_id, project_id, leader_id, follower_id)
2026            .await?;
2027        room_updated(&room, &session.peer);
2028    }
2029
2030    Ok(())
2031}
2032
2033/// Notify everyone following you of your current location.
2034async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2035    let room_id = RoomId::from_proto(request.room_id);
2036    let database = session.db.lock().await;
2037
2038    let connection_ids = if let Some(project_id) = request.project_id {
2039        let project_id = ProjectId::from_proto(project_id);
2040        database
2041            .project_connection_ids(project_id, session.connection_id)
2042            .await?
2043    } else {
2044        database
2045            .room_connection_ids(room_id, session.connection_id)
2046            .await?
2047    };
2048
2049    // For now, don't send view update messages back to that view's current leader.
2050    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2051        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2052        _ => None,
2053    });
2054
2055    for follower_peer_id in request.follower_ids.iter().copied() {
2056        let follower_connection_id = follower_peer_id.into();
2057        if Some(follower_peer_id) != connection_id_to_omit
2058            && connection_ids.contains(&follower_connection_id)
2059        {
2060            session.peer.forward_send(
2061                session.connection_id,
2062                follower_connection_id,
2063                request.clone(),
2064            )?;
2065        }
2066    }
2067    Ok(())
2068}
2069
2070/// Get public data about users.
2071async fn get_users(
2072    request: proto::GetUsers,
2073    response: Response<proto::GetUsers>,
2074    session: Session,
2075) -> Result<()> {
2076    let user_ids = request
2077        .user_ids
2078        .into_iter()
2079        .map(UserId::from_proto)
2080        .collect();
2081    let users = session
2082        .db()
2083        .await
2084        .get_users_by_ids(user_ids)
2085        .await?
2086        .into_iter()
2087        .map(|user| proto::User {
2088            id: user.id.to_proto(),
2089            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2090            github_login: user.github_login,
2091        })
2092        .collect();
2093    response.send(proto::UsersResponse { users })?;
2094    Ok(())
2095}
2096
2097/// Search for users (to invite) buy Github login
2098async fn fuzzy_search_users(
2099    request: proto::FuzzySearchUsers,
2100    response: Response<proto::FuzzySearchUsers>,
2101    session: Session,
2102) -> Result<()> {
2103    let query = request.query;
2104    let users = match query.len() {
2105        0 => vec![],
2106        1 | 2 => session
2107            .db()
2108            .await
2109            .get_user_by_github_login(&query)
2110            .await?
2111            .into_iter()
2112            .collect(),
2113        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2114    };
2115    let users = users
2116        .into_iter()
2117        .filter(|user| user.id != session.user_id)
2118        .map(|user| proto::User {
2119            id: user.id.to_proto(),
2120            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2121            github_login: user.github_login,
2122        })
2123        .collect();
2124    response.send(proto::UsersResponse { users })?;
2125    Ok(())
2126}
2127
2128/// Send a contact request to another user.
2129async fn request_contact(
2130    request: proto::RequestContact,
2131    response: Response<proto::RequestContact>,
2132    session: Session,
2133) -> Result<()> {
2134    let requester_id = session.user_id;
2135    let responder_id = UserId::from_proto(request.responder_id);
2136    if requester_id == responder_id {
2137        return Err(anyhow!("cannot add yourself as a contact"))?;
2138    }
2139
2140    let notifications = session
2141        .db()
2142        .await
2143        .send_contact_request(requester_id, responder_id)
2144        .await?;
2145
2146    // Update outgoing contact requests of requester
2147    let mut update = proto::UpdateContacts::default();
2148    update.outgoing_requests.push(responder_id.to_proto());
2149    for connection_id in session
2150        .connection_pool()
2151        .await
2152        .user_connection_ids(requester_id)
2153    {
2154        session.peer.send(connection_id, update.clone())?;
2155    }
2156
2157    // Update incoming contact requests of responder
2158    let mut update = proto::UpdateContacts::default();
2159    update
2160        .incoming_requests
2161        .push(proto::IncomingContactRequest {
2162            requester_id: requester_id.to_proto(),
2163        });
2164    let connection_pool = session.connection_pool().await;
2165    for connection_id in connection_pool.user_connection_ids(responder_id) {
2166        session.peer.send(connection_id, update.clone())?;
2167    }
2168
2169    send_notifications(&*connection_pool, &session.peer, notifications);
2170
2171    response.send(proto::Ack {})?;
2172    Ok(())
2173}
2174
2175/// Accept or decline a contact request
2176async fn respond_to_contact_request(
2177    request: proto::RespondToContactRequest,
2178    response: Response<proto::RespondToContactRequest>,
2179    session: Session,
2180) -> Result<()> {
2181    let responder_id = session.user_id;
2182    let requester_id = UserId::from_proto(request.requester_id);
2183    let db = session.db().await;
2184    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2185        db.dismiss_contact_notification(responder_id, requester_id)
2186            .await?;
2187    } else {
2188        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2189
2190        let notifications = db
2191            .respond_to_contact_request(responder_id, requester_id, accept)
2192            .await?;
2193        let requester_busy = db.is_user_busy(requester_id).await?;
2194        let responder_busy = db.is_user_busy(responder_id).await?;
2195
2196        let pool = session.connection_pool().await;
2197        // Update responder with new contact
2198        let mut update = proto::UpdateContacts::default();
2199        if accept {
2200            update
2201                .contacts
2202                .push(contact_for_user(requester_id, requester_busy, &pool));
2203        }
2204        update
2205            .remove_incoming_requests
2206            .push(requester_id.to_proto());
2207        for connection_id in pool.user_connection_ids(responder_id) {
2208            session.peer.send(connection_id, update.clone())?;
2209        }
2210
2211        // Update requester with new contact
2212        let mut update = proto::UpdateContacts::default();
2213        if accept {
2214            update
2215                .contacts
2216                .push(contact_for_user(responder_id, responder_busy, &pool));
2217        }
2218        update
2219            .remove_outgoing_requests
2220            .push(responder_id.to_proto());
2221
2222        for connection_id in pool.user_connection_ids(requester_id) {
2223            session.peer.send(connection_id, update.clone())?;
2224        }
2225
2226        send_notifications(&*pool, &session.peer, notifications);
2227    }
2228
2229    response.send(proto::Ack {})?;
2230    Ok(())
2231}
2232
2233/// Remove a contact.
2234async fn remove_contact(
2235    request: proto::RemoveContact,
2236    response: Response<proto::RemoveContact>,
2237    session: Session,
2238) -> Result<()> {
2239    let requester_id = session.user_id;
2240    let responder_id = UserId::from_proto(request.user_id);
2241    let db = session.db().await;
2242    let (contact_accepted, deleted_notification_id) =
2243        db.remove_contact(requester_id, responder_id).await?;
2244
2245    let pool = session.connection_pool().await;
2246    // Update outgoing contact requests of requester
2247    let mut update = proto::UpdateContacts::default();
2248    if contact_accepted {
2249        update.remove_contacts.push(responder_id.to_proto());
2250    } else {
2251        update
2252            .remove_outgoing_requests
2253            .push(responder_id.to_proto());
2254    }
2255    for connection_id in pool.user_connection_ids(requester_id) {
2256        session.peer.send(connection_id, update.clone())?;
2257    }
2258
2259    // Update incoming contact requests of responder
2260    let mut update = proto::UpdateContacts::default();
2261    if contact_accepted {
2262        update.remove_contacts.push(requester_id.to_proto());
2263    } else {
2264        update
2265            .remove_incoming_requests
2266            .push(requester_id.to_proto());
2267    }
2268    for connection_id in pool.user_connection_ids(responder_id) {
2269        session.peer.send(connection_id, update.clone())?;
2270        if let Some(notification_id) = deleted_notification_id {
2271            session.peer.send(
2272                connection_id,
2273                proto::DeleteNotification {
2274                    notification_id: notification_id.to_proto(),
2275                },
2276            )?;
2277        }
2278    }
2279
2280    response.send(proto::Ack {})?;
2281    Ok(())
2282}
2283
2284/// Create a new channel.
2285async fn create_channel(
2286    request: proto::CreateChannel,
2287    response: Response<proto::CreateChannel>,
2288    session: Session,
2289) -> Result<()> {
2290    let db = session.db().await;
2291
2292    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2293    let CreateChannelResult {
2294        channel,
2295        participants_to_update,
2296    } = db
2297        .create_channel(&request.name, parent_id, session.user_id)
2298        .await?;
2299
2300    response.send(proto::CreateChannelResponse {
2301        channel: Some(channel.to_proto()),
2302        parent_id: request.parent_id,
2303    })?;
2304
2305    let connection_pool = session.connection_pool().await;
2306    for (user_id, channels) in participants_to_update {
2307        let update = build_channels_update(channels, vec![]);
2308        for connection_id in connection_pool.user_connection_ids(user_id) {
2309            if user_id == session.user_id {
2310                continue;
2311            }
2312            session.peer.send(connection_id, update.clone())?;
2313        }
2314    }
2315
2316    Ok(())
2317}
2318
2319/// Delete a channel
2320async fn delete_channel(
2321    request: proto::DeleteChannel,
2322    response: Response<proto::DeleteChannel>,
2323    session: Session,
2324) -> Result<()> {
2325    let db = session.db().await;
2326
2327    let channel_id = request.channel_id;
2328    let (removed_channels, member_ids) = db
2329        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2330        .await?;
2331    response.send(proto::Ack {})?;
2332
2333    // Notify members of removed channels
2334    let mut update = proto::UpdateChannels::default();
2335    update
2336        .delete_channels
2337        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2338
2339    let connection_pool = session.connection_pool().await;
2340    for member_id in member_ids {
2341        for connection_id in connection_pool.user_connection_ids(member_id) {
2342            session.peer.send(connection_id, update.clone())?;
2343        }
2344    }
2345
2346    Ok(())
2347}
2348
2349/// Invite someone to join a channel.
2350async fn invite_channel_member(
2351    request: proto::InviteChannelMember,
2352    response: Response<proto::InviteChannelMember>,
2353    session: Session,
2354) -> Result<()> {
2355    let db = session.db().await;
2356    let channel_id = ChannelId::from_proto(request.channel_id);
2357    let invitee_id = UserId::from_proto(request.user_id);
2358    let InviteMemberResult {
2359        channel,
2360        notifications,
2361    } = db
2362        .invite_channel_member(
2363            channel_id,
2364            invitee_id,
2365            session.user_id,
2366            request.role().into(),
2367        )
2368        .await?;
2369
2370    let update = proto::UpdateChannels {
2371        channel_invitations: vec![channel.to_proto()],
2372        ..Default::default()
2373    };
2374
2375    let connection_pool = session.connection_pool().await;
2376    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2377        session.peer.send(connection_id, update.clone())?;
2378    }
2379
2380    send_notifications(&*connection_pool, &session.peer, notifications);
2381
2382    response.send(proto::Ack {})?;
2383    Ok(())
2384}
2385
2386/// remove someone from a channel
2387async fn remove_channel_member(
2388    request: proto::RemoveChannelMember,
2389    response: Response<proto::RemoveChannelMember>,
2390    session: Session,
2391) -> Result<()> {
2392    let db = session.db().await;
2393    let channel_id = ChannelId::from_proto(request.channel_id);
2394    let member_id = UserId::from_proto(request.user_id);
2395
2396    let RemoveChannelMemberResult {
2397        membership_update,
2398        notification_id,
2399    } = db
2400        .remove_channel_member(channel_id, member_id, session.user_id)
2401        .await?;
2402
2403    let connection_pool = &session.connection_pool().await;
2404    notify_membership_updated(
2405        &connection_pool,
2406        membership_update,
2407        member_id,
2408        &session.peer,
2409    );
2410    for connection_id in connection_pool.user_connection_ids(member_id) {
2411        if let Some(notification_id) = notification_id {
2412            session
2413                .peer
2414                .send(
2415                    connection_id,
2416                    proto::DeleteNotification {
2417                        notification_id: notification_id.to_proto(),
2418                    },
2419                )
2420                .trace_err();
2421        }
2422    }
2423
2424    response.send(proto::Ack {})?;
2425    Ok(())
2426}
2427
2428/// Toggle the channel between public and private
2429async fn set_channel_visibility(
2430    request: proto::SetChannelVisibility,
2431    response: Response<proto::SetChannelVisibility>,
2432    session: Session,
2433) -> Result<()> {
2434    let db = session.db().await;
2435    let channel_id = ChannelId::from_proto(request.channel_id);
2436    let visibility = request.visibility().into();
2437
2438    let SetChannelVisibilityResult {
2439        participants_to_update,
2440        participants_to_remove,
2441        channels_to_remove,
2442    } = db
2443        .set_channel_visibility(channel_id, visibility, session.user_id)
2444        .await?;
2445
2446    let connection_pool = session.connection_pool().await;
2447    for (user_id, channels) in participants_to_update {
2448        let update = build_channels_update(channels, vec![]);
2449        for connection_id in connection_pool.user_connection_ids(user_id) {
2450            session.peer.send(connection_id, update.clone())?;
2451        }
2452    }
2453    for user_id in participants_to_remove {
2454        let update = proto::UpdateChannels {
2455            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2456            ..Default::default()
2457        };
2458        for connection_id in connection_pool.user_connection_ids(user_id) {
2459            session.peer.send(connection_id, update.clone())?;
2460        }
2461    }
2462
2463    response.send(proto::Ack {})?;
2464    Ok(())
2465}
2466
2467/// Alter the role for a user in the channel
2468async fn set_channel_member_role(
2469    request: proto::SetChannelMemberRole,
2470    response: Response<proto::SetChannelMemberRole>,
2471    session: Session,
2472) -> Result<()> {
2473    let db = session.db().await;
2474    let channel_id = ChannelId::from_proto(request.channel_id);
2475    let member_id = UserId::from_proto(request.user_id);
2476    let result = db
2477        .set_channel_member_role(
2478            channel_id,
2479            session.user_id,
2480            member_id,
2481            request.role().into(),
2482        )
2483        .await?;
2484
2485    match result {
2486        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2487            let connection_pool = session.connection_pool().await;
2488            notify_membership_updated(
2489                &connection_pool,
2490                membership_update,
2491                member_id,
2492                &session.peer,
2493            )
2494        }
2495        db::SetMemberRoleResult::InviteUpdated(channel) => {
2496            let update = proto::UpdateChannels {
2497                channel_invitations: vec![channel.to_proto()],
2498                ..Default::default()
2499            };
2500
2501            for connection_id in session
2502                .connection_pool()
2503                .await
2504                .user_connection_ids(member_id)
2505            {
2506                session.peer.send(connection_id, update.clone())?;
2507            }
2508        }
2509    }
2510
2511    response.send(proto::Ack {})?;
2512    Ok(())
2513}
2514
2515/// Change the name of a channel
2516async fn rename_channel(
2517    request: proto::RenameChannel,
2518    response: Response<proto::RenameChannel>,
2519    session: Session,
2520) -> Result<()> {
2521    let db = session.db().await;
2522    let channel_id = ChannelId::from_proto(request.channel_id);
2523    let RenameChannelResult {
2524        channel,
2525        participants_to_update,
2526    } = db
2527        .rename_channel(channel_id, session.user_id, &request.name)
2528        .await?;
2529
2530    response.send(proto::RenameChannelResponse {
2531        channel: Some(channel.to_proto()),
2532    })?;
2533
2534    let connection_pool = session.connection_pool().await;
2535    for (user_id, channel) in participants_to_update {
2536        for connection_id in connection_pool.user_connection_ids(user_id) {
2537            let update = proto::UpdateChannels {
2538                channels: vec![channel.to_proto()],
2539                ..Default::default()
2540            };
2541
2542            session.peer.send(connection_id, update.clone())?;
2543        }
2544    }
2545
2546    Ok(())
2547}
2548
2549/// Move a channel to a new parent.
2550async fn move_channel(
2551    request: proto::MoveChannel,
2552    response: Response<proto::MoveChannel>,
2553    session: Session,
2554) -> Result<()> {
2555    let channel_id = ChannelId::from_proto(request.channel_id);
2556    let to = request.to.map(ChannelId::from_proto);
2557
2558    let result = session
2559        .db()
2560        .await
2561        .move_channel(channel_id, to, session.user_id)
2562        .await?;
2563
2564    notify_channel_moved(result, session).await?;
2565
2566    response.send(Ack {})?;
2567    Ok(())
2568}
2569
2570async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2571    let Some(MoveChannelResult {
2572        participants_to_remove,
2573        participants_to_update,
2574        moved_channels,
2575    }) = result
2576    else {
2577        return Ok(());
2578    };
2579    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2580
2581    let connection_pool = session.connection_pool().await;
2582    for (user_id, channels) in participants_to_update {
2583        let mut update = build_channels_update(channels, vec![]);
2584        update.delete_channels = moved_channels.clone();
2585        for connection_id in connection_pool.user_connection_ids(user_id) {
2586            session.peer.send(connection_id, update.clone())?;
2587        }
2588    }
2589
2590    for user_id in participants_to_remove {
2591        let update = proto::UpdateChannels {
2592            delete_channels: moved_channels.clone(),
2593            ..Default::default()
2594        };
2595        for connection_id in connection_pool.user_connection_ids(user_id) {
2596            session.peer.send(connection_id, update.clone())?;
2597        }
2598    }
2599    Ok(())
2600}
2601
2602/// Get the list of channel members
2603async fn get_channel_members(
2604    request: proto::GetChannelMembers,
2605    response: Response<proto::GetChannelMembers>,
2606    session: Session,
2607) -> Result<()> {
2608    let db = session.db().await;
2609    let channel_id = ChannelId::from_proto(request.channel_id);
2610    let members = db
2611        .get_channel_participant_details(channel_id, session.user_id)
2612        .await?;
2613    response.send(proto::GetChannelMembersResponse { members })?;
2614    Ok(())
2615}
2616
2617/// Accept or decline a channel invitation.
2618async fn respond_to_channel_invite(
2619    request: proto::RespondToChannelInvite,
2620    response: Response<proto::RespondToChannelInvite>,
2621    session: Session,
2622) -> Result<()> {
2623    let db = session.db().await;
2624    let channel_id = ChannelId::from_proto(request.channel_id);
2625    let RespondToChannelInvite {
2626        membership_update,
2627        notifications,
2628    } = db
2629        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2630        .await?;
2631
2632    let connection_pool = session.connection_pool().await;
2633    if let Some(membership_update) = membership_update {
2634        notify_membership_updated(
2635            &connection_pool,
2636            membership_update,
2637            session.user_id,
2638            &session.peer,
2639        );
2640    } else {
2641        let update = proto::UpdateChannels {
2642            remove_channel_invitations: vec![channel_id.to_proto()],
2643            ..Default::default()
2644        };
2645
2646        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2647            session.peer.send(connection_id, update.clone())?;
2648        }
2649    };
2650
2651    send_notifications(&*connection_pool, &session.peer, notifications);
2652
2653    response.send(proto::Ack {})?;
2654
2655    Ok(())
2656}
2657
2658/// Join the channels' room
2659async fn join_channel(
2660    request: proto::JoinChannel,
2661    response: Response<proto::JoinChannel>,
2662    session: Session,
2663) -> Result<()> {
2664    let channel_id = ChannelId::from_proto(request.channel_id);
2665    join_channel_internal(channel_id, Box::new(response), session).await
2666}
2667
2668trait JoinChannelInternalResponse {
2669    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2670}
2671impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2672    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2673        Response::<proto::JoinChannel>::send(self, result)
2674    }
2675}
2676impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2677    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2678        Response::<proto::JoinRoom>::send(self, result)
2679    }
2680}
2681
/// Join the room of channel `channel_id` on behalf of this session's
/// connection, then broadcast the resulting state changes.
///
/// In order, this:
/// 1. leaves any room the connection is currently in;
/// 2. joins the channel's room in the database, which also returns an optional
///    membership update and the caller's `ChannelRole`;
/// 3. builds LiveKit connection info and sends the `JoinRoomResponse`. Guests
///    (`ChannelRole::Guest`) get a guest token with `can_publish: false`;
///    everyone else gets a room token with `can_publish: true`;
/// 4. pushes the membership update, if any, to the caller's connections, and
///    sends the new room state to the room's participants;
/// 5. sends the channel's participant list to the channel members, and
///    refreshes the caller's contact entry for their accepted contacts.
///
/// `response` is generic over `JoinChannelInternalResponse`, so this can answer
/// both `JoinChannel` and `JoinRoom` requests.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that `db` and `connection_pool`, both taken inside, are dropped
    // before `connection_pool()` is requested again below.
    // NOTE(review): presumably these are lock guards, and re-acquiring one while
    // it is still held would deadlock. Confirm before restructuring.
    let joined_room = {
        // Called before `db` is taken here, because `leave_room_for_session`
        // calls `session.db()` itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when there is no LiveKit client, or when token creation fails:
        // the `?` after `trace_err()` short-circuits the closure to `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above have been dropped by this point.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2762
2763/// Start editing the channel notes
2764async fn join_channel_buffer(
2765    request: proto::JoinChannelBuffer,
2766    response: Response<proto::JoinChannelBuffer>,
2767    session: Session,
2768) -> Result<()> {
2769    let db = session.db().await;
2770    let channel_id = ChannelId::from_proto(request.channel_id);
2771
2772    let open_response = db
2773        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2774        .await?;
2775
2776    let collaborators = open_response.collaborators.clone();
2777    response.send(open_response)?;
2778
2779    let update = UpdateChannelBufferCollaborators {
2780        channel_id: channel_id.to_proto(),
2781        collaborators: collaborators.clone(),
2782    };
2783    channel_buffer_updated(
2784        session.connection_id,
2785        collaborators
2786            .iter()
2787            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2788        &update,
2789        &session.peer,
2790    );
2791
2792    Ok(())
2793}
2794
2795/// Edit the channel notes
2796async fn update_channel_buffer(
2797    request: proto::UpdateChannelBuffer,
2798    session: Session,
2799) -> Result<()> {
2800    let db = session.db().await;
2801    let channel_id = ChannelId::from_proto(request.channel_id);
2802
2803    let (collaborators, non_collaborators, epoch, version) = db
2804        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2805        .await?;
2806
2807    channel_buffer_updated(
2808        session.connection_id,
2809        collaborators,
2810        &proto::UpdateChannelBuffer {
2811            channel_id: channel_id.to_proto(),
2812            operations: request.operations,
2813        },
2814        &session.peer,
2815    );
2816
2817    let pool = &*session.connection_pool().await;
2818
2819    broadcast(
2820        None,
2821        non_collaborators
2822            .iter()
2823            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2824        |peer_id| {
2825            session.peer.send(
2826                peer_id.into(),
2827                proto::UpdateChannels {
2828                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2829                        channel_id: channel_id.to_proto(),
2830                        epoch: epoch as u64,
2831                        version: version.clone(),
2832                    }],
2833                    ..Default::default()
2834                },
2835            )
2836        },
2837    );
2838
2839    Ok(())
2840}
2841
2842/// Rejoin the channel notes after a connection blip
2843async fn rejoin_channel_buffers(
2844    request: proto::RejoinChannelBuffers,
2845    response: Response<proto::RejoinChannelBuffers>,
2846    session: Session,
2847) -> Result<()> {
2848    let db = session.db().await;
2849    let buffers = db
2850        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2851        .await?;
2852
2853    for rejoined_buffer in &buffers {
2854        let collaborators_to_notify = rejoined_buffer
2855            .buffer
2856            .collaborators
2857            .iter()
2858            .filter_map(|c| Some(c.peer_id?.into()));
2859        channel_buffer_updated(
2860            session.connection_id,
2861            collaborators_to_notify,
2862            &proto::UpdateChannelBufferCollaborators {
2863                channel_id: rejoined_buffer.buffer.channel_id,
2864                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2865            },
2866            &session.peer,
2867        );
2868    }
2869
2870    response.send(proto::RejoinChannelBuffersResponse {
2871        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2872    })?;
2873
2874    Ok(())
2875}
2876
2877/// Stop editing the channel notes
2878async fn leave_channel_buffer(
2879    request: proto::LeaveChannelBuffer,
2880    response: Response<proto::LeaveChannelBuffer>,
2881    session: Session,
2882) -> Result<()> {
2883    let db = session.db().await;
2884    let channel_id = ChannelId::from_proto(request.channel_id);
2885
2886    let left_buffer = db
2887        .leave_channel_buffer(channel_id, session.connection_id)
2888        .await?;
2889
2890    response.send(Ack {})?;
2891
2892    channel_buffer_updated(
2893        session.connection_id,
2894        left_buffer.connections,
2895        &proto::UpdateChannelBufferCollaborators {
2896            channel_id: channel_id.to_proto(),
2897            collaborators: left_buffer.collaborators,
2898        },
2899        &session.peer,
2900    );
2901
2902    Ok(())
2903}
2904
2905fn channel_buffer_updated<T: EnvelopedMessage>(
2906    sender_id: ConnectionId,
2907    collaborators: impl IntoIterator<Item = ConnectionId>,
2908    message: &T,
2909    peer: &Peer,
2910) {
2911    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2912        peer.send(peer_id.into(), message.clone())
2913    });
2914}
2915
2916fn send_notifications(
2917    connection_pool: &ConnectionPool,
2918    peer: &Peer,
2919    notifications: db::NotificationBatch,
2920) {
2921    for (user_id, notification) in notifications {
2922        for connection_id in connection_pool.user_connection_ids(user_id) {
2923            if let Err(error) = peer.send(
2924                connection_id,
2925                proto::AddNotification {
2926                    notification: Some(notification.clone()),
2927                },
2928            ) {
2929                tracing::error!(
2930                    "failed to send notification to {:?} {}",
2931                    connection_id,
2932                    error
2933                );
2934            }
2935        }
2936    }
2937}
2938
2939/// Send a message to the channel
2940async fn send_channel_message(
2941    request: proto::SendChannelMessage,
2942    response: Response<proto::SendChannelMessage>,
2943    session: Session,
2944) -> Result<()> {
2945    // Validate the message body.
2946    let body = request.body.trim().to_string();
2947    if body.len() > MAX_MESSAGE_LEN {
2948        return Err(anyhow!("message is too long"))?;
2949    }
2950    if body.is_empty() {
2951        return Err(anyhow!("message can't be blank"))?;
2952    }
2953
2954    // TODO: adjust mentions if body is trimmed
2955
2956    let timestamp = OffsetDateTime::now_utc();
2957    let nonce = request
2958        .nonce
2959        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2960
2961    let channel_id = ChannelId::from_proto(request.channel_id);
2962    let CreatedChannelMessage {
2963        message_id,
2964        participant_connection_ids,
2965        channel_members,
2966        notifications,
2967    } = session
2968        .db()
2969        .await
2970        .create_channel_message(
2971            channel_id,
2972            session.user_id,
2973            &body,
2974            &request.mentions,
2975            timestamp,
2976            nonce.clone().into(),
2977        )
2978        .await?;
2979    let message = proto::ChannelMessage {
2980        sender_id: session.user_id.to_proto(),
2981        id: message_id.to_proto(),
2982        body,
2983        mentions: request.mentions,
2984        timestamp: timestamp.unix_timestamp() as u64,
2985        nonce: Some(nonce),
2986    };
2987    broadcast(
2988        Some(session.connection_id),
2989        participant_connection_ids,
2990        |connection| {
2991            session.peer.send(
2992                connection,
2993                proto::ChannelMessageSent {
2994                    channel_id: channel_id.to_proto(),
2995                    message: Some(message.clone()),
2996                },
2997            )
2998        },
2999    );
3000    response.send(proto::SendChannelMessageResponse {
3001        message: Some(message),
3002    })?;
3003
3004    let pool = &*session.connection_pool().await;
3005    broadcast(
3006        None,
3007        channel_members
3008            .iter()
3009            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3010        |peer_id| {
3011            session.peer.send(
3012                peer_id.into(),
3013                proto::UpdateChannels {
3014                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
3015                        channel_id: channel_id.to_proto(),
3016                        message_id: message_id.to_proto(),
3017                    }],
3018                    ..Default::default()
3019                },
3020            )
3021        },
3022    );
3023    send_notifications(pool, &session.peer, notifications);
3024
3025    Ok(())
3026}
3027
3028/// Delete a channel message
3029async fn remove_channel_message(
3030    request: proto::RemoveChannelMessage,
3031    response: Response<proto::RemoveChannelMessage>,
3032    session: Session,
3033) -> Result<()> {
3034    let channel_id = ChannelId::from_proto(request.channel_id);
3035    let message_id = MessageId::from_proto(request.message_id);
3036    let connection_ids = session
3037        .db()
3038        .await
3039        .remove_channel_message(channel_id, message_id, session.user_id)
3040        .await?;
3041    broadcast(Some(session.connection_id), connection_ids, |connection| {
3042        session.peer.send(connection, request.clone())
3043    });
3044    response.send(proto::Ack {})?;
3045    Ok(())
3046}
3047
3048/// Mark a channel message as read
3049async fn acknowledge_channel_message(
3050    request: proto::AckChannelMessage,
3051    session: Session,
3052) -> Result<()> {
3053    let channel_id = ChannelId::from_proto(request.channel_id);
3054    let message_id = MessageId::from_proto(request.message_id);
3055    let notifications = session
3056        .db()
3057        .await
3058        .observe_channel_message(channel_id, session.user_id, message_id)
3059        .await?;
3060    send_notifications(
3061        &*session.connection_pool().await,
3062        &session.peer,
3063        notifications,
3064    );
3065    Ok(())
3066}
3067
3068/// Mark a buffer version as synced
3069async fn acknowledge_buffer_version(
3070    request: proto::AckBufferOperation,
3071    session: Session,
3072) -> Result<()> {
3073    let buffer_id = BufferId::from_proto(request.buffer_id);
3074    session
3075        .db()
3076        .await
3077        .observe_buffer_version(
3078            buffer_id,
3079            session.user_id,
3080            request.epoch as i32,
3081            &request.version,
3082        )
3083        .await?;
3084    Ok(())
3085}
3086
3087/// Start receiving chat updates for a channel
3088async fn join_channel_chat(
3089    request: proto::JoinChannelChat,
3090    response: Response<proto::JoinChannelChat>,
3091    session: Session,
3092) -> Result<()> {
3093    let channel_id = ChannelId::from_proto(request.channel_id);
3094
3095    let db = session.db().await;
3096    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3097        .await?;
3098    let messages = db
3099        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3100        .await?;
3101    response.send(proto::JoinChannelChatResponse {
3102        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3103        messages,
3104    })?;
3105    Ok(())
3106}
3107
3108/// Stop receiving chat updates for a channel
3109async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3110    let channel_id = ChannelId::from_proto(request.channel_id);
3111    session
3112        .db()
3113        .await
3114        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3115        .await?;
3116    Ok(())
3117}
3118
3119/// Retrieve the chat history for a channel
3120async fn get_channel_messages(
3121    request: proto::GetChannelMessages,
3122    response: Response<proto::GetChannelMessages>,
3123    session: Session,
3124) -> Result<()> {
3125    let channel_id = ChannelId::from_proto(request.channel_id);
3126    let messages = session
3127        .db()
3128        .await
3129        .get_channel_messages(
3130            channel_id,
3131            session.user_id,
3132            MESSAGE_COUNT_PER_PAGE,
3133            Some(MessageId::from_proto(request.before_message_id)),
3134        )
3135        .await?;
3136    response.send(proto::GetChannelMessagesResponse {
3137        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3138        messages,
3139    })?;
3140    Ok(())
3141}
3142
3143/// Retrieve specific chat messages
3144async fn get_channel_messages_by_id(
3145    request: proto::GetChannelMessagesById,
3146    response: Response<proto::GetChannelMessagesById>,
3147    session: Session,
3148) -> Result<()> {
3149    let message_ids = request
3150        .message_ids
3151        .iter()
3152        .map(|id| MessageId::from_proto(*id))
3153        .collect::<Vec<_>>();
3154    let messages = session
3155        .db()
3156        .await
3157        .get_channel_messages_by_id(session.user_id, &message_ids)
3158        .await?;
3159    response.send(proto::GetChannelMessagesResponse {
3160        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3161        messages,
3162    })?;
3163    Ok(())
3164}
3165
3166/// Retrieve the current users notifications
3167async fn get_notifications(
3168    request: proto::GetNotifications,
3169    response: Response<proto::GetNotifications>,
3170    session: Session,
3171) -> Result<()> {
3172    let notifications = session
3173        .db()
3174        .await
3175        .get_notifications(
3176            session.user_id,
3177            NOTIFICATION_COUNT_PER_PAGE,
3178            request
3179                .before_id
3180                .map(|id| db::NotificationId::from_proto(id)),
3181        )
3182        .await?;
3183    response.send(proto::GetNotificationsResponse {
3184        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3185        notifications,
3186    })?;
3187    Ok(())
3188}
3189
3190/// Mark notifications as read
3191async fn mark_notification_as_read(
3192    request: proto::MarkNotificationRead,
3193    response: Response<proto::MarkNotificationRead>,
3194    session: Session,
3195) -> Result<()> {
3196    let database = &session.db().await;
3197    let notifications = database
3198        .mark_notification_as_read_by_id(
3199            session.user_id,
3200            NotificationId::from_proto(request.notification_id),
3201        )
3202        .await?;
3203    send_notifications(
3204        &*session.connection_pool().await,
3205        &session.peer,
3206        notifications,
3207    );
3208    response.send(proto::Ack {})?;
3209    Ok(())
3210}
3211
3212/// Get the current users information
3213async fn get_private_user_info(
3214    _request: proto::GetPrivateUserInfo,
3215    response: Response<proto::GetPrivateUserInfo>,
3216    session: Session,
3217) -> Result<()> {
3218    let db = session.db().await;
3219
3220    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3221    let user = db
3222        .get_user_by_id(session.user_id)
3223        .await?
3224        .ok_or_else(|| anyhow!("user not found"))?;
3225    let flags = db.get_user_flags(session.user_id).await?;
3226
3227    response.send(proto::GetPrivateUserInfoResponse {
3228        metrics_id,
3229        staff: user.admin,
3230        flags,
3231    })?;
3232    Ok(())
3233}
3234
3235fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3236    match message {
3237        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3238        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3239        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3240        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3241        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3242            code: frame.code.into(),
3243            reason: frame.reason,
3244        })),
3245    }
3246}
3247
3248fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3249    match message {
3250        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3251        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3252        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3253        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3254        AxumMessage::Close(frame) => {
3255            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3256                code: frame.code.into(),
3257                reason: frame.reason,
3258            }))
3259        }
3260    }
3261}
3262
3263fn notify_membership_updated(
3264    connection_pool: &ConnectionPool,
3265    result: MembershipUpdated,
3266    user_id: UserId,
3267    peer: &Peer,
3268) {
3269    let mut update = build_channels_update(result.new_channels, vec![]);
3270    update.delete_channels = result
3271        .removed_channels
3272        .into_iter()
3273        .map(|id| id.to_proto())
3274        .collect();
3275    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3276
3277    for connection_id in connection_pool.user_connection_ids(user_id) {
3278        peer.send(connection_id, update.clone()).trace_err();
3279    }
3280}
3281
3282fn build_channels_update(
3283    channels: ChannelsForUser,
3284    channel_invites: Vec<db::Channel>,
3285) -> proto::UpdateChannels {
3286    let mut update = proto::UpdateChannels::default();
3287
3288    for channel in channels.channels {
3289        update.channels.push(channel.to_proto());
3290    }
3291
3292    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3293    update.unseen_channel_messages = channels.channel_messages;
3294
3295    for (channel_id, participants) in channels.channel_participants {
3296        update
3297            .channel_participants
3298            .push(proto::ChannelParticipants {
3299                channel_id: channel_id.to_proto(),
3300                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3301            });
3302    }
3303
3304    for channel in channel_invites {
3305        update.channel_invitations.push(channel.to_proto());
3306    }
3307
3308    update
3309}
3310
3311fn build_initial_contacts_update(
3312    contacts: Vec<db::Contact>,
3313    pool: &ConnectionPool,
3314) -> proto::UpdateContacts {
3315    let mut update = proto::UpdateContacts::default();
3316
3317    for contact in contacts {
3318        match contact {
3319            db::Contact::Accepted { user_id, busy } => {
3320                update.contacts.push(contact_for_user(user_id, busy, &pool));
3321            }
3322            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3323            db::Contact::Incoming { user_id } => {
3324                update
3325                    .incoming_requests
3326                    .push(proto::IncomingContactRequest {
3327                        requester_id: user_id.to_proto(),
3328                    })
3329            }
3330        }
3331    }
3332
3333    update
3334}
3335
3336fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3337    proto::Contact {
3338        user_id: user_id.to_proto(),
3339        online: pool.is_user_online(user_id),
3340        busy,
3341    }
3342}
3343
3344fn room_updated(room: &proto::Room, peer: &Peer) {
3345    broadcast(
3346        None,
3347        room.participants
3348            .iter()
3349            .filter_map(|participant| Some(participant.peer_id?.into())),
3350        |peer_id| {
3351            peer.send(
3352                peer_id.into(),
3353                proto::RoomUpdated {
3354                    room: Some(room.clone()),
3355                },
3356            )
3357        },
3358    );
3359}
3360
3361fn channel_updated(
3362    channel_id: ChannelId,
3363    room: &proto::Room,
3364    channel_members: &[UserId],
3365    peer: &Peer,
3366    pool: &ConnectionPool,
3367) {
3368    let participants = room
3369        .participants
3370        .iter()
3371        .map(|p| p.user_id)
3372        .collect::<Vec<_>>();
3373
3374    broadcast(
3375        None,
3376        channel_members
3377            .iter()
3378            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3379        |peer_id| {
3380            peer.send(
3381                peer_id.into(),
3382                proto::UpdateChannels {
3383                    channel_participants: vec![proto::ChannelParticipants {
3384                        channel_id: channel_id.to_proto(),
3385                        participant_user_ids: participants.clone(),
3386                    }],
3387                    ..Default::default()
3388                },
3389            )
3390        },
3391    );
3392}
3393
3394async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3395    let db = session.db().await;
3396
3397    let contacts = db.get_contacts(user_id).await?;
3398    let busy = db.is_user_busy(user_id).await?;
3399
3400    let pool = session.connection_pool().await;
3401    let updated_contact = contact_for_user(user_id, busy, &pool);
3402    for contact in contacts {
3403        if let db::Contact::Accepted {
3404            user_id: contact_user_id,
3405            ..
3406        } = contact
3407        {
3408            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3409                session
3410                    .peer
3411                    .send(
3412                        contact_conn_id,
3413                        proto::UpdateContacts {
3414                            contacts: vec![updated_contact.clone()],
3415                            remove_contacts: Default::default(),
3416                            incoming_requests: Default::default(),
3417                            remove_incoming_requests: Default::default(),
3418                            outgoing_requests: Default::default(),
3419                            remove_outgoing_requests: Default::default(),
3420                        },
3421                    )
3422                    .trace_err();
3423            }
3424        }
3425    }
3426    Ok(())
3427}
3428
3429async fn leave_room_for_session(session: &Session) -> Result<()> {
3430    let mut contacts_to_update = HashSet::default();
3431
3432    let room_id;
3433    let canceled_calls_to_user_ids;
3434    let live_kit_room;
3435    let delete_live_kit_room;
3436    let room;
3437    let channel_members;
3438    let channel_id;
3439
3440    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3441        contacts_to_update.insert(session.user_id);
3442
3443        for project in left_room.left_projects.values() {
3444            project_left(project, session);
3445        }
3446
3447        room_id = RoomId::from_proto(left_room.room.id);
3448        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3449        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3450        delete_live_kit_room = left_room.deleted;
3451        room = mem::take(&mut left_room.room);
3452        channel_members = mem::take(&mut left_room.channel_members);
3453        channel_id = left_room.channel_id;
3454
3455        room_updated(&room, &session.peer);
3456    } else {
3457        return Ok(());
3458    }
3459
3460    if let Some(channel_id) = channel_id {
3461        channel_updated(
3462            channel_id,
3463            &room,
3464            &channel_members,
3465            &session.peer,
3466            &*session.connection_pool().await,
3467        );
3468    }
3469
3470    {
3471        let pool = session.connection_pool().await;
3472        for canceled_user_id in canceled_calls_to_user_ids {
3473            for connection_id in pool.user_connection_ids(canceled_user_id) {
3474                session
3475                    .peer
3476                    .send(
3477                        connection_id,
3478                        proto::CallCanceled {
3479                            room_id: room_id.to_proto(),
3480                        },
3481                    )
3482                    .trace_err();
3483            }
3484            contacts_to_update.insert(canceled_user_id);
3485        }
3486    }
3487
3488    for contact_user_id in contacts_to_update {
3489        update_user_contacts(contact_user_id, &session).await?;
3490    }
3491
3492    if let Some(live_kit) = session.live_kit_client.as_ref() {
3493        live_kit
3494            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3495            .await
3496            .trace_err();
3497
3498        if delete_live_kit_room {
3499            live_kit.delete_room(live_kit_room).await.trace_err();
3500        }
3501    }
3502
3503    Ok(())
3504}
3505
3506async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3507    let left_channel_buffers = session
3508        .db()
3509        .await
3510        .leave_channel_buffers(session.connection_id)
3511        .await?;
3512
3513    for left_buffer in left_channel_buffers {
3514        channel_buffer_updated(
3515            session.connection_id,
3516            left_buffer.connections,
3517            &proto::UpdateChannelBufferCollaborators {
3518                channel_id: left_buffer.channel_id.to_proto(),
3519                collaborators: left_buffer.collaborators,
3520            },
3521            &session.peer,
3522        );
3523    }
3524
3525    Ok(())
3526}
3527
3528fn project_left(project: &db::LeftProject, session: &Session) {
3529    for connection_id in &project.connection_ids {
3530        if project.host_user_id == session.user_id {
3531            session
3532                .peer
3533                .send(
3534                    *connection_id,
3535                    proto::UnshareProject {
3536                        project_id: project.id.to_proto(),
3537                    },
3538                )
3539                .trace_err();
3540        } else {
3541            session
3542                .peer
3543                .send(
3544                    *connection_id,
3545                    proto::RemoveProjectCollaborator {
3546                        project_id: project.id.to_proto(),
3547                        peer_id: Some(session.connection_id.into()),
3548                    },
3549                )
3550                .trace_err();
3551        }
3552    }
3553}
3554
/// Extension for fallible values whose failure should be reported rather than propagated.
pub trait ResultExt {
    /// Type of the success value.
    type Ok;

    /// Converts `self` into an `Option` of the success value, discarding the failure.
    /// (The `Result` implementation in this module logs the error before discarding it.)
    fn trace_err(self) -> Option<Self::Ok>;
}
3560
3561impl<T, E> ResultExt for Result<T, E>
3562where
3563    E: std::fmt::Debug,
3564{
3565    type Ok = T;
3566
3567    fn trace_err(self) -> Option<T> {
3568        match self {
3569            Ok(value) => Some(value),
3570            Err(error) => {
3571                tracing::error!("{:?}", error);
3572                None
3573            }
3574        }
3575    }
3576}