rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, ProjectId,
   8        RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, User,
   9        UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, Result,
  13};
  14use anyhow::anyhow;
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use futures::{
  34    channel::oneshot,
  35    future::{self, BoxFuture},
  36    stream::FuturesUnordered,
  37    FutureExt, SinkExt, StreamExt, TryStreamExt,
  38};
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc, OnceLock,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
/// Reconnection grace period. Not referenced in this section; presumably how
/// long a disconnected client may take to reconnect before its state is torn
/// down — TODO confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left behind by
/// stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in this section; the
// descriptions are inferred from their names — confirm at the use sites.

/// Presumably the page size for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message (whether this
/// counts bytes or chars is not established here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: given a boxed envelope and
/// the sending connection's `Session`, returns a `'static` boxed future that
/// performs the handling and resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  80
  81struct Response<R> {
  82    peer: Arc<Peer>,
  83    receipt: Receipt<R>,
  84    responded: Arc<AtomicBool>,
  85}
  86
  87impl<R: RequestMessage> Response<R> {
  88    fn send(self, payload: R::Response) -> Result<()> {
  89        self.responded.store(true, SeqCst);
  90        self.peer.respond(self.receipt, payload)?;
  91        Ok(())
  92    }
  93}
  94
  95#[derive(Clone)]
  96struct Session {
  97    user_id: UserId,
  98    connection_id: ConnectionId,
  99    db: Arc<tokio::sync::Mutex<DbHandle>>,
 100    peer: Arc<Peer>,
 101    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 102    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 103    _executor: Executor,
 104}
 105
 106impl Session {
 107    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 108        #[cfg(test)]
 109        tokio::task::yield_now().await;
 110        let guard = self.db.lock().await;
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        guard
 114    }
 115
 116    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        let guard = self.connection_pool.lock();
 120        ConnectionPoolGuard {
 121            guard,
 122            _not_send: PhantomData,
 123        }
 124    }
 125}
 126
 127impl fmt::Debug for Session {
 128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 129        f.debug_struct("Session")
 130            .field("user_id", &self.user_id)
 131            .field("connection_id", &self.connection_id)
 132            .finish()
 133    }
 134}
 135
 136struct DbHandle(Arc<Database>);
 137
 138impl Deref for DbHandle {
 139    type Target = Database;
 140
 141    fn deref(&self) -> &Self::Target {
 142        self.0.as_ref()
 143    }
 144}
 145
/// The collaboration RPC server: owns the peer, the connection pool, and the
/// table of per-message-type handlers that incoming messages are dispatched to.
pub struct Server {
    // Behind a mutex because the test-only `reset` replaces it on a shared `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Exactly one handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` on teardown; each connection loop holds a subscribed receiver.
    teardown: watch::Sender<bool>,
}

/// Lock guard for the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a handler future
/// that holds it across an `.await` fails the `Send` bound `add_handler`
/// requires, rather than keeping the pool locked while suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// The pool stays locked for as long as the snapshot lives. Fields serialize in
/// declaration order; the pool is serialized through `serialize_deref`, which
/// assumes `ConnectionPoolGuard` derefs to `ConnectionPool` (impl not in view).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 167
 168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 169where
 170    S: Serializer,
 171    T: Deref<Target = U>,
 172    U: Serialize,
 173{
 174    Serialize::serialize(value.deref(), serializer)
 175}
 176
 177impl Server {
    /// Creates a server identified by `id` and registers a handler for every
    /// RPC message type it understands.
    ///
    /// Request handlers (`add_request_handler`) receive a `Response` they are
    /// expected to reply through. Message handlers (`add_message_handler`)
    /// receive only the payload.
    ///
    /// # Panics
    ///
    /// Panics if the same message type is registered twice, because
    /// `add_handler` rejects duplicate registrations.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/wraps if the inner id does
            // not fit in a u32 — confirm the range of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created per connection via
            // `subscribe` in `handle_connection`.
            teardown: watch::channel(false).0,
        };

        // Dispatch looks handlers up by message type, so registration order does
        // not affect routing.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 287
    /// Spawns a detached background task that cleans up resources left behind
    /// by stale servers, and returns `Ok(())` immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// 1. fetches stale room ids and channel-buffer ids for this environment
    ///    (presumably those owned by servers other than this one — confirm in
    ///    `Database::stale_server_resource_ids`);
    /// 2. for each stale channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the buffer's remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the room
    ///    (and, if the room belongs to a channel, the channel) update, sends
    ///    `CallCanceled` to users whose calls were canceled, pushes each affected
    ///    user's updated contact entry to their accepted contacts, and deletes
    ///    the LiveKit room if no participants remain;
    /// 4. deletes stale server records.
    ///
    /// Each database or send failure is logged via `trace_err` and skipped rather
    /// than aborting the task. Step 4 runs even if step 1 fails.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited first thing inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if refreshing the room succeeds; otherwise the
                        // notification and LiveKit steps below are no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Moved out so they outlive `refreshed_room`'s scope.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either query failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after the awaits above; not held across any `.await`.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 439
 440    pub fn teardown(&self) {
 441        self.peer.teardown();
 442        self.connection_pool.lock().reset();
 443        let _ = self.teardown.send(true);
 444    }
 445
 446    #[cfg(test)]
 447    pub fn reset(&self, id: ServerId) {
 448        self.teardown();
 449        *self.id.lock() = id;
 450        self.peer.reset(id.0 as u32);
 451        let _ = self.teardown.send(false);
 452    }
 453
 454    #[cfg(test)]
 455    pub fn id(&self) -> ServerId {
 456        *self.id.lock()
 457    }
 458
 459    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 460    where
 461        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 462        Fut: 'static + Send + Future<Output = Result<()>>,
 463        M: EnvelopedMessage,
 464    {
 465        let prev_handler = self.handlers.insert(
 466            TypeId::of::<M>(),
 467            Box::new(move |envelope, session| {
 468                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 469                let received_at = envelope.received_at;
 470                let span = info_span!(
 471                    "handle message",
 472                    payload_type = envelope.payload_type_name()
 473                );
 474                span.in_scope(|| {
 475                    tracing::info!(
 476                        payload_type = envelope.payload_type_name(),
 477                        "message received"
 478                    );
 479                });
 480                let start_time = Instant::now();
 481                let future = (handler)(*envelope, session);
 482                async move {
 483                    let result = future.await;
 484                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 485                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 486                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 487                    match result {
 488                        Err(error) => {
 489                            tracing::error!(%error, ?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "error handling message")
 490                        }
 491                        Ok(()) => tracing::info!(?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "finished handling message"),
 492                    }
 493                }
 494                .instrument(span)
 495                .boxed()
 496            }),
 497        );
 498        if prev_handler.is_some() {
 499            panic!("registered a handler for the same message twice");
 500        }
 501        self
 502    }
 503
 504    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 505    where
 506        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 507        Fut: 'static + Send + Future<Output = Result<()>>,
 508        M: EnvelopedMessage,
 509    {
 510        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 511        self
 512    }
 513
 514    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 515    where
 516        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 517        Fut: Send + Future<Output = Result<()>>,
 518        M: RequestMessage,
 519    {
 520        let handler = Arc::new(handler);
 521        self.add_handler(move |envelope, session| {
 522            let receipt = envelope.receipt();
 523            let handler = handler.clone();
 524            async move {
 525                let peer = session.peer.clone();
 526                let responded = Arc::new(AtomicBool::default());
 527                let response = Response {
 528                    peer: peer.clone(),
 529                    responded: responded.clone(),
 530                    receipt,
 531                };
 532                match (handler)(envelope.payload, response, session).await {
 533                    Ok(()) => {
 534                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 535                            Ok(())
 536                        } else {
 537                            Err(anyhow!("handler did not send a response"))?
 538                        }
 539                    }
 540                    Err(error) => {
 541                        let proto_err = match &error {
 542                            Error::Internal(err) => err.to_proto(),
 543                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 544                        };
 545                        peer.respond_with_error(receipt, proto_err)?;
 546                        Err(error)
 547                    }
 548                }
 549            }
 550        })
 551    }
 552
    /// Drives one client connection from handshake until it ends.
    ///
    /// Setup, in order:
    /// - register the connection with the peer and send `Hello`;
    /// - report the new connection id through `send_connection_id`, if given;
    /// - on a user's first-ever connection, send `ShowContacts` and record it;
    /// - add the connection to the pool and send the initial contacts and
    ///   channels state, plus any pending incoming call;
    /// - push contact updates for this user.
    ///
    /// Afterwards, incoming messages are dispatched to the registered handlers.
    /// Background messages are spawned on `executor`; foreground ones run
    /// concurrently inside this future. At most 256 handlers per connection are
    /// in flight at once.
    ///
    /// When the connection closes or its I/O fails, `connection_lost` is run. If
    /// the server tears down instead, the future returns `Ok(())` without it.
    /// `impersonator`, if present, is only recorded on the tracing span.
    ///
    /// # Errors
    ///
    /// Returns an error if the server is already tearing down or any setup step
    /// fails.
    ///
    /// NOTE(review): setup errors propagate via `?` without calling
    /// `connection_lost`, so a connection already registered with the peer and/or
    /// pool is not cleaned up on that path — confirm whether that happens elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed before the future runs, so a later `teardown()` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                return Err(anyhow!("server is tearing down"))?;
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Sync-only block: the pool lock is never held across an `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Bounds in-flight handlers (foreground and background) for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken *before* reading the next message and released only
                // when that message's handler finishes. If another branch wins the
                // select below, this future is dropped and its permit returned.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are checked in order — teardown, then I/O, then
                // in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown exits without `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is held
                                // until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 697
 698    pub async fn invite_code_redeemed(
 699        self: &Arc<Self>,
 700        inviter_id: UserId,
 701        invitee_id: UserId,
 702    ) -> Result<()> {
 703        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 704            if let Some(code) = &user.invite_code {
 705                let pool = self.connection_pool.lock();
 706                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 707                for connection_id in pool.user_connection_ids(inviter_id) {
 708                    self.peer.send(
 709                        connection_id,
 710                        proto::UpdateContacts {
 711                            contacts: vec![invitee_contact.clone()],
 712                            ..Default::default()
 713                        },
 714                    )?;
 715                    self.peer.send(
 716                        connection_id,
 717                        proto::UpdateInviteInfo {
 718                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 719                            count: user.invite_count as u32,
 720                        },
 721                    )?;
 722                }
 723            }
 724        }
 725        Ok(())
 726    }
 727
 728    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 729        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 730            if let Some(invite_code) = &user.invite_code {
 731                let pool = self.connection_pool.lock();
 732                for connection_id in pool.user_connection_ids(user_id) {
 733                    self.peer.send(
 734                        connection_id,
 735                        proto::UpdateInviteInfo {
 736                            url: format!(
 737                                "{}{}",
 738                                self.app_state.config.invite_link_prefix, invite_code
 739                            ),
 740                            count: user.invite_count as u32,
 741                        },
 742                    )?;
 743                }
 744            }
 745        }
 746        Ok(())
 747    }
 748
 749    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 750        ServerSnapshot {
 751            connection_pool: ConnectionPoolGuard {
 752                guard: self.connection_pool.lock(),
 753                _not_send: PhantomData,
 754            },
 755            peer: &self.peer,
 756        }
 757    }
 758}
 759
 760impl<'a> Deref for ConnectionPoolGuard<'a> {
 761    type Target = ConnectionPool;
 762
 763    fn deref(&self) -> &Self::Target {
 764        &self.guard
 765    }
 766}
 767
 768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 769    fn deref_mut(&mut self) -> &mut Self::Target {
 770        &mut self.guard
 771    }
 772}
 773
/// In test builds, validates the connection pool every time a guard is released.
///
/// `Drop::drop` runs before the guard's fields are dropped, so the underlying lock is
/// still held while `check_invariants` runs.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled out entirely unless `cfg(test)` is set. `check_invariants` is
        // presumably reached through `Deref` to `ConnectionPool` — verify.
        #[cfg(test)]
        self.check_invariants();
    }
}
 780
 781fn broadcast<F>(
 782    sender_id: Option<ConnectionId>,
 783    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 784    mut f: F,
 785) where
 786    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 787{
 788    for receiver_id in receiver_ids {
 789        if Some(receiver_id) != sender_id {
 790            if let Err(error) = f(receiver_id) {
 791                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 792            }
 793        }
 794    }
 795}
 796
 797pub struct ProtocolVersion(u32);
 798
 799impl Header for ProtocolVersion {
 800    fn name() -> &'static HeaderName {
 801        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 802        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 803    }
 804
 805    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 806    where
 807        Self: Sized,
 808        I: Iterator<Item = &'i axum::http::HeaderValue>,
 809    {
 810        let version = values
 811            .next()
 812            .ok_or_else(axum::headers::Error::invalid)?
 813            .to_str()
 814            .map_err(|_| axum::headers::Error::invalid())?
 815            .parse()
 816            .map_err(|_| axum::headers::Error::invalid())?;
 817        Ok(Self(version))
 818    }
 819
 820    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 821        values.extend([self.0.to_string().parse().unwrap()]);
 822    }
 823}
 824
 825pub struct AppVersionHeader(SemanticVersion);
 826impl Header for AppVersionHeader {
 827    fn name() -> &'static HeaderName {
 828        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 829        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 830    }
 831
 832    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 833    where
 834        Self: Sized,
 835        I: Iterator<Item = &'i axum::http::HeaderValue>,
 836    {
 837        let version = values
 838            .next()
 839            .ok_or_else(axum::headers::Error::invalid)?
 840            .to_str()
 841            .map_err(|_| axum::headers::Error::invalid())?
 842            .parse()
 843            .map_err(|_| axum::headers::Error::invalid())?;
 844        Ok(Self(version))
 845    }
 846
 847    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 848        values.extend([self.0.to_string().parse().unwrap()]);
 849    }
 850}
 851
 852pub fn routes(server: Arc<Server>) -> Router<(), Body> {
 853    Router::new()
 854        .route("/rpc", get(handle_websocket_request))
 855        .layer(
 856            ServiceBuilder::new()
 857                .layer(Extension(server.app_state.clone()))
 858                .layer(middleware::from_fn(auth::validate_header)),
 859        )
 860        .route("/metrics", get(handle_metrics))
 861        .layer(Extension(server))
 862}
 863
 864pub async fn handle_websocket_request(
 865    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 866    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 867    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 868    Extension(server): Extension<Arc<Server>>,
 869    Extension(user): Extension<User>,
 870    Extension(impersonator): Extension<Impersonator>,
 871    ws: WebSocketUpgrade,
 872) -> axum::response::Response {
 873    if protocol_version != rpc::PROTOCOL_VERSION {
 874        return (
 875            StatusCode::UPGRADE_REQUIRED,
 876            "client must be upgraded".to_string(),
 877        )
 878            .into_response();
 879    }
 880
 881    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 882        return (
 883            StatusCode::UPGRADE_REQUIRED,
 884            "no version header found".to_string(),
 885        )
 886            .into_response();
 887    };
 888
 889    if !version.is_supported() {
 890        return (
 891            StatusCode::UPGRADE_REQUIRED,
 892            "client must be upgraded".to_string(),
 893        )
 894            .into_response();
 895    }
 896
 897    let socket_address = socket_address.to_string();
 898    ws.on_upgrade(move |socket| {
 899        use util::ResultExt;
 900        let socket = socket
 901            .map_ok(to_tungstenite_message)
 902            .err_into()
 903            .with(|message| async move { Ok(to_axum_message(message)) });
 904        let connection = Connection::new(Box::pin(socket));
 905        async move {
 906            server
 907                .handle_connection(
 908                    connection,
 909                    socket_address,
 910                    user,
 911                    version,
 912                    impersonator.0,
 913                    None,
 914                    Executor::Production,
 915                )
 916                .await
 917                .log_err();
 918        }
 919    })
 920}
 921
 922pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 923    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 924    let connections_metric = CONNECTIONS_METRIC
 925        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 926
 927    let connections = server
 928        .connection_pool
 929        .lock()
 930        .connections()
 931        .filter(|connection| !connection.admin)
 932        .count();
 933    connections_metric.set(connections as _);
 934
 935    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 936    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 937        register_int_gauge!(
 938            "shared_projects",
 939            "number of open projects with one or more guests"
 940        )
 941        .unwrap()
 942    });
 943
 944    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 945    shared_projects_metric.set(shared_projects as _);
 946
 947    let encoder = prometheus::TextEncoder::new();
 948    let metric_families = prometheus::gather();
 949    let encoded_metrics = encoder
 950        .encode_to_string(&metric_families)
 951        .map_err(|err| anyhow!("{}", err))?;
 952    Ok(encoded_metrics)
 953}
 954
/// Cleans up after a connection has gone away.
///
/// The steps run in this order:
///
/// * The connection is disconnected at the peer level.
/// * It is removed from the connection pool. An error here is propagated and stops the
///   teardown.
/// * The database is told the connection was lost. Errors from this step are only traced.
///
/// It then waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to change:
///
/// * If the timeout wins, the session leaves its room and channel buffers, and
///   `decline_call` is called for the user if they are no longer online. The user's
///   contacts are then updated; an error from that update is propagated.
/// * If `teardown` changes first, no further cleanup is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Drop the peer connection and remove it from the pool before anything else; a pool
    // error aborts the remaining teardown.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort database bookkeeping: failures are traced, not returned.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so if the reconnect timeout
    // and a teardown signal are ready at the same time, the timeout branch is taken.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // This connection was already removed from the pool above, so this only
            // declines calls when `is_user_online` reports no other live connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled first: skip the resource cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1000
1001/// Acknowledges a ping from a client, used to keep the connection alive.
1002async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1003    response.send(proto::Ack {})?;
1004    Ok(())
1005}
1006
1007/// Creates a new room for calling (outside of channels)
1008async fn create_room(
1009    _request: proto::CreateRoom,
1010    response: Response<proto::CreateRoom>,
1011    session: Session,
1012) -> Result<()> {
1013    let live_kit_room = nanoid::nanoid!(30);
1014
1015    let live_kit_connection_info = {
1016        let live_kit_room = live_kit_room.clone();
1017        let live_kit = session.live_kit_client.as_ref();
1018
1019        util::async_maybe!({
1020            let live_kit = live_kit?;
1021
1022            let token = live_kit
1023                .room_token(&live_kit_room, &session.user_id.to_string())
1024                .trace_err()?;
1025
1026            Some(proto::LiveKitConnectionInfo {
1027                server_url: live_kit.url().into(),
1028                token,
1029                can_publish: true,
1030            })
1031        })
1032    }
1033    .await;
1034
1035    let room = session
1036        .db()
1037        .await
1038        .create_room(session.user_id, session.connection_id, &live_kit_room)
1039        .await?;
1040
1041    response.send(proto::CreateRoomResponse {
1042        room: Some(room.clone()),
1043        live_kit_connection_info,
1044    })?;
1045
1046    update_user_contacts(session.user_id, &session).await?;
1047    Ok(())
1048}
1049
1050/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1051async fn join_room(
1052    request: proto::JoinRoom,
1053    response: Response<proto::JoinRoom>,
1054    session: Session,
1055) -> Result<()> {
1056    let room_id = RoomId::from_proto(request.id);
1057
1058    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1059
1060    if let Some(channel_id) = channel_id {
1061        return join_channel_internal(channel_id, Box::new(response), session).await;
1062    }
1063
1064    let joined_room = {
1065        let room = session
1066            .db()
1067            .await
1068            .join_room(room_id, session.user_id, session.connection_id)
1069            .await?;
1070        room_updated(&room.room, &session.peer);
1071        room.into_inner()
1072    };
1073
1074    for connection_id in session
1075        .connection_pool()
1076        .await
1077        .user_connection_ids(session.user_id)
1078    {
1079        session
1080            .peer
1081            .send(
1082                connection_id,
1083                proto::CallCanceled {
1084                    room_id: room_id.to_proto(),
1085                },
1086            )
1087            .trace_err();
1088    }
1089
1090    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1091        if let Some(token) = live_kit
1092            .room_token(
1093                &joined_room.room.live_kit_room,
1094                &session.user_id.to_string(),
1095            )
1096            .trace_err()
1097        {
1098            Some(proto::LiveKitConnectionInfo {
1099                server_url: live_kit.url().into(),
1100                token,
1101                can_publish: true,
1102            })
1103        } else {
1104            None
1105        }
1106    } else {
1107        None
1108    };
1109
1110    response.send(proto::JoinRoomResponse {
1111        room: Some(joined_room.room),
1112        channel_id: None,
1113        live_kit_connection_info,
1114    })?;
1115
1116    update_user_contacts(session.user_id, &session).await?;
1117    Ok(())
1118}
1119
1120/// Rejoin room is used to reconnect to a room after connection errors.
1121async fn rejoin_room(
1122    request: proto::RejoinRoom,
1123    response: Response<proto::RejoinRoom>,
1124    session: Session,
1125) -> Result<()> {
1126    let room;
1127    let channel_id;
1128    let channel_members;
1129    {
1130        let mut rejoined_room = session
1131            .db()
1132            .await
1133            .rejoin_room(request, session.user_id, session.connection_id)
1134            .await?;
1135
1136        response.send(proto::RejoinRoomResponse {
1137            room: Some(rejoined_room.room.clone()),
1138            reshared_projects: rejoined_room
1139                .reshared_projects
1140                .iter()
1141                .map(|project| proto::ResharedProject {
1142                    id: project.id.to_proto(),
1143                    collaborators: project
1144                        .collaborators
1145                        .iter()
1146                        .map(|collaborator| collaborator.to_proto())
1147                        .collect(),
1148                })
1149                .collect(),
1150            rejoined_projects: rejoined_room
1151                .rejoined_projects
1152                .iter()
1153                .map(|rejoined_project| proto::RejoinedProject {
1154                    id: rejoined_project.id.to_proto(),
1155                    worktrees: rejoined_project
1156                        .worktrees
1157                        .iter()
1158                        .map(|worktree| proto::WorktreeMetadata {
1159                            id: worktree.id,
1160                            root_name: worktree.root_name.clone(),
1161                            visible: worktree.visible,
1162                            abs_path: worktree.abs_path.clone(),
1163                        })
1164                        .collect(),
1165                    collaborators: rejoined_project
1166                        .collaborators
1167                        .iter()
1168                        .map(|collaborator| collaborator.to_proto())
1169                        .collect(),
1170                    language_servers: rejoined_project.language_servers.clone(),
1171                })
1172                .collect(),
1173        })?;
1174        room_updated(&rejoined_room.room, &session.peer);
1175
1176        for project in &rejoined_room.reshared_projects {
1177            for collaborator in &project.collaborators {
1178                session
1179                    .peer
1180                    .send(
1181                        collaborator.connection_id,
1182                        proto::UpdateProjectCollaborator {
1183                            project_id: project.id.to_proto(),
1184                            old_peer_id: Some(project.old_connection_id.into()),
1185                            new_peer_id: Some(session.connection_id.into()),
1186                        },
1187                    )
1188                    .trace_err();
1189            }
1190
1191            broadcast(
1192                Some(session.connection_id),
1193                project
1194                    .collaborators
1195                    .iter()
1196                    .map(|collaborator| collaborator.connection_id),
1197                |connection_id| {
1198                    session.peer.forward_send(
1199                        session.connection_id,
1200                        connection_id,
1201                        proto::UpdateProject {
1202                            project_id: project.id.to_proto(),
1203                            worktrees: project.worktrees.clone(),
1204                        },
1205                    )
1206                },
1207            );
1208        }
1209
1210        for project in &rejoined_room.rejoined_projects {
1211            for collaborator in &project.collaborators {
1212                session
1213                    .peer
1214                    .send(
1215                        collaborator.connection_id,
1216                        proto::UpdateProjectCollaborator {
1217                            project_id: project.id.to_proto(),
1218                            old_peer_id: Some(project.old_connection_id.into()),
1219                            new_peer_id: Some(session.connection_id.into()),
1220                        },
1221                    )
1222                    .trace_err();
1223            }
1224        }
1225
1226        for project in &mut rejoined_room.rejoined_projects {
1227            for worktree in mem::take(&mut project.worktrees) {
1228                #[cfg(any(test, feature = "test-support"))]
1229                const MAX_CHUNK_SIZE: usize = 2;
1230                #[cfg(not(any(test, feature = "test-support")))]
1231                const MAX_CHUNK_SIZE: usize = 256;
1232
1233                // Stream this worktree's entries.
1234                let message = proto::UpdateWorktree {
1235                    project_id: project.id.to_proto(),
1236                    worktree_id: worktree.id,
1237                    abs_path: worktree.abs_path.clone(),
1238                    root_name: worktree.root_name,
1239                    updated_entries: worktree.updated_entries,
1240                    removed_entries: worktree.removed_entries,
1241                    scan_id: worktree.scan_id,
1242                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1243                    updated_repositories: worktree.updated_repositories,
1244                    removed_repositories: worktree.removed_repositories,
1245                };
1246                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1247                    session.peer.send(session.connection_id, update.clone())?;
1248                }
1249
1250                // Stream this worktree's diagnostics.
1251                for summary in worktree.diagnostic_summaries {
1252                    session.peer.send(
1253                        session.connection_id,
1254                        proto::UpdateDiagnosticSummary {
1255                            project_id: project.id.to_proto(),
1256                            worktree_id: worktree.id,
1257                            summary: Some(summary),
1258                        },
1259                    )?;
1260                }
1261
1262                for settings_file in worktree.settings_files {
1263                    session.peer.send(
1264                        session.connection_id,
1265                        proto::UpdateWorktreeSettings {
1266                            project_id: project.id.to_proto(),
1267                            worktree_id: worktree.id,
1268                            path: settings_file.path,
1269                            content: Some(settings_file.content),
1270                        },
1271                    )?;
1272                }
1273            }
1274
1275            for language_server in &project.language_servers {
1276                session.peer.send(
1277                    session.connection_id,
1278                    proto::UpdateLanguageServer {
1279                        project_id: project.id.to_proto(),
1280                        language_server_id: language_server.id,
1281                        variant: Some(
1282                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1283                                proto::LspDiskBasedDiagnosticsUpdated {},
1284                            ),
1285                        ),
1286                    },
1287                )?;
1288            }
1289        }
1290
1291        let rejoined_room = rejoined_room.into_inner();
1292
1293        room = rejoined_room.room;
1294        channel_id = rejoined_room.channel_id;
1295        channel_members = rejoined_room.channel_members;
1296    }
1297
1298    if let Some(channel_id) = channel_id {
1299        channel_updated(
1300            channel_id,
1301            &room,
1302            &channel_members,
1303            &session.peer,
1304            &*session.connection_pool().await,
1305        );
1306    }
1307
1308    update_user_contacts(session.user_id, &session).await?;
1309    Ok(())
1310}
1311
1312/// leave room disconnects from the room.
1313async fn leave_room(
1314    _: proto::LeaveRoom,
1315    response: Response<proto::LeaveRoom>,
1316    session: Session,
1317) -> Result<()> {
1318    leave_room_for_session(&session).await?;
1319    response.send(proto::Ack {})?;
1320    Ok(())
1321}
1322
1323/// Updates the permissions of someone else in the room.
1324async fn set_room_participant_role(
1325    request: proto::SetRoomParticipantRole,
1326    response: Response<proto::SetRoomParticipantRole>,
1327    session: Session,
1328) -> Result<()> {
1329    let user_id = UserId::from_proto(request.user_id);
1330    let role = ChannelRole::from(request.role());
1331
1332    if role == ChannelRole::Talker {
1333        let pool = session.connection_pool().await;
1334
1335        for connection in pool.user_connections(user_id) {
1336            if !connection.zed_version.supports_talker_role() {
1337                Err(anyhow!(
1338                    "This user is on zed {} which does not support unmute",
1339                    connection.zed_version
1340                ))?;
1341            }
1342        }
1343    }
1344
1345    let (live_kit_room, can_publish) = {
1346        let room = session
1347            .db()
1348            .await
1349            .set_room_participant_role(
1350                session.user_id,
1351                RoomId::from_proto(request.room_id),
1352                user_id,
1353                role,
1354            )
1355            .await?;
1356
1357        let live_kit_room = room.live_kit_room.clone();
1358        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1359        room_updated(&room, &session.peer);
1360        (live_kit_room, can_publish)
1361    };
1362
1363    if let Some(live_kit) = session.live_kit_client.as_ref() {
1364        live_kit
1365            .update_participant(
1366                live_kit_room.clone(),
1367                request.user_id.to_string(),
1368                live_kit_server::proto::ParticipantPermission {
1369                    can_subscribe: true,
1370                    can_publish,
1371                    can_publish_data: can_publish,
1372                    hidden: false,
1373                    recorder: false,
1374                },
1375            )
1376            .await
1377            .trace_err();
1378    }
1379
1380    response.send(proto::Ack {})?;
1381    Ok(())
1382}
1383
1384/// Call someone else into the current room
1385async fn call(
1386    request: proto::Call,
1387    response: Response<proto::Call>,
1388    session: Session,
1389) -> Result<()> {
1390    let room_id = RoomId::from_proto(request.room_id);
1391    let calling_user_id = session.user_id;
1392    let calling_connection_id = session.connection_id;
1393    let called_user_id = UserId::from_proto(request.called_user_id);
1394    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1395    if !session
1396        .db()
1397        .await
1398        .has_contact(calling_user_id, called_user_id)
1399        .await?
1400    {
1401        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1402    }
1403
1404    let incoming_call = {
1405        let (room, incoming_call) = &mut *session
1406            .db()
1407            .await
1408            .call(
1409                room_id,
1410                calling_user_id,
1411                calling_connection_id,
1412                called_user_id,
1413                initial_project_id,
1414            )
1415            .await?;
1416        room_updated(&room, &session.peer);
1417        mem::take(incoming_call)
1418    };
1419    update_user_contacts(called_user_id, &session).await?;
1420
1421    let mut calls = session
1422        .connection_pool()
1423        .await
1424        .user_connection_ids(called_user_id)
1425        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1426        .collect::<FuturesUnordered<_>>();
1427
1428    while let Some(call_response) = calls.next().await {
1429        match call_response.as_ref() {
1430            Ok(_) => {
1431                response.send(proto::Ack {})?;
1432                return Ok(());
1433            }
1434            Err(_) => {
1435                call_response.trace_err();
1436            }
1437        }
1438    }
1439
1440    {
1441        let room = session
1442            .db()
1443            .await
1444            .call_failed(room_id, called_user_id)
1445            .await?;
1446        room_updated(&room, &session.peer);
1447    }
1448    update_user_contacts(called_user_id, &session).await?;
1449
1450    Err(anyhow!("failed to ring user"))?
1451}
1452
1453/// Cancel an outgoing call.
1454async fn cancel_call(
1455    request: proto::CancelCall,
1456    response: Response<proto::CancelCall>,
1457    session: Session,
1458) -> Result<()> {
1459    let called_user_id = UserId::from_proto(request.called_user_id);
1460    let room_id = RoomId::from_proto(request.room_id);
1461    {
1462        let room = session
1463            .db()
1464            .await
1465            .cancel_call(room_id, session.connection_id, called_user_id)
1466            .await?;
1467        room_updated(&room, &session.peer);
1468    }
1469
1470    for connection_id in session
1471        .connection_pool()
1472        .await
1473        .user_connection_ids(called_user_id)
1474    {
1475        session
1476            .peer
1477            .send(
1478                connection_id,
1479                proto::CallCanceled {
1480                    room_id: room_id.to_proto(),
1481                },
1482            )
1483            .trace_err();
1484    }
1485    response.send(proto::Ack {})?;
1486
1487    update_user_contacts(called_user_id, &session).await?;
1488    Ok(())
1489}
1490
1491/// Decline an incoming call.
1492async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1493    let room_id = RoomId::from_proto(message.room_id);
1494    {
1495        let room = session
1496            .db()
1497            .await
1498            .decline_call(Some(room_id), session.user_id)
1499            .await?
1500            .ok_or_else(|| anyhow!("failed to decline call"))?;
1501        room_updated(&room, &session.peer);
1502    }
1503
1504    for connection_id in session
1505        .connection_pool()
1506        .await
1507        .user_connection_ids(session.user_id)
1508    {
1509        session
1510            .peer
1511            .send(
1512                connection_id,
1513                proto::CallCanceled {
1514                    room_id: room_id.to_proto(),
1515                },
1516            )
1517            .trace_err();
1518    }
1519    update_user_contacts(session.user_id, &session).await?;
1520    Ok(())
1521}
1522
1523/// Updates other participants in the room with your current location.
1524async fn update_participant_location(
1525    request: proto::UpdateParticipantLocation,
1526    response: Response<proto::UpdateParticipantLocation>,
1527    session: Session,
1528) -> Result<()> {
1529    let room_id = RoomId::from_proto(request.room_id);
1530    let location = request
1531        .location
1532        .ok_or_else(|| anyhow!("invalid location"))?;
1533
1534    let db = session.db().await;
1535    let room = db
1536        .update_room_participant_location(room_id, session.connection_id, location)
1537        .await?;
1538
1539    room_updated(&room, &session.peer);
1540    response.send(proto::Ack {})?;
1541    Ok(())
1542}
1543
1544/// Share a project into the room.
1545async fn share_project(
1546    request: proto::ShareProject,
1547    response: Response<proto::ShareProject>,
1548    session: Session,
1549) -> Result<()> {
1550    let (project_id, room) = &*session
1551        .db()
1552        .await
1553        .share_project(
1554            RoomId::from_proto(request.room_id),
1555            session.connection_id,
1556            &request.worktrees,
1557        )
1558        .await?;
1559    response.send(proto::ShareProjectResponse {
1560        project_id: project_id.to_proto(),
1561    })?;
1562    room_updated(&room, &session.peer);
1563
1564    Ok(())
1565}
1566
1567/// Unshare a project from the room.
1568async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1569    let project_id = ProjectId::from_proto(message.project_id);
1570
1571    let (room, guest_connection_ids) = &*session
1572        .db()
1573        .await
1574        .unshare_project(project_id, session.connection_id)
1575        .await?;
1576
1577    broadcast(
1578        Some(session.connection_id),
1579        guest_connection_ids.iter().copied(),
1580        |conn_id| session.peer.send(conn_id, message.clone()),
1581    );
1582    room_updated(&room, &session.peer);
1583
1584    Ok(())
1585}
1586
1587/// Join someone elses shared project.
1588async fn join_project(
1589    request: proto::JoinProject,
1590    response: Response<proto::JoinProject>,
1591    session: Session,
1592) -> Result<()> {
1593    let project_id = ProjectId::from_proto(request.project_id);
1594
1595    tracing::info!(%project_id, "join project");
1596
1597    let (project, replica_id) = &mut *session
1598        .db()
1599        .await
1600        .join_project_in_room(project_id, session.connection_id)
1601        .await?;
1602
1603    join_project_internal(response, session, project, replica_id)
1604}
1605
1606trait JoinProjectInternalResponse {
1607    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1608}
1609impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1610    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1611        Response::<proto::JoinProject>::send(self, result)
1612    }
1613}
1614impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1615    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1616        Response::<proto::JoinHostedProject>::send(self, result)
1617    }
1618}
1619
1620fn join_project_internal(
1621    response: impl JoinProjectInternalResponse,
1622    session: Session,
1623    project: &mut Project,
1624    replica_id: &ReplicaId,
1625) -> Result<()> {
1626    let collaborators = project
1627        .collaborators
1628        .iter()
1629        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1630        .map(|collaborator| collaborator.to_proto())
1631        .collect::<Vec<_>>();
1632    let project_id = project.id;
1633    let guest_user_id = session.user_id;
1634
1635    let worktrees = project
1636        .worktrees
1637        .iter()
1638        .map(|(id, worktree)| proto::WorktreeMetadata {
1639            id: *id,
1640            root_name: worktree.root_name.clone(),
1641            visible: worktree.visible,
1642            abs_path: worktree.abs_path.clone(),
1643        })
1644        .collect::<Vec<_>>();
1645
1646    for collaborator in &collaborators {
1647        session
1648            .peer
1649            .send(
1650                collaborator.peer_id.unwrap().into(),
1651                proto::AddProjectCollaborator {
1652                    project_id: project_id.to_proto(),
1653                    collaborator: Some(proto::Collaborator {
1654                        peer_id: Some(session.connection_id.into()),
1655                        replica_id: replica_id.0 as u32,
1656                        user_id: guest_user_id.to_proto(),
1657                    }),
1658                },
1659            )
1660            .trace_err();
1661    }
1662
1663    // First, we send the metadata associated with each worktree.
1664    response.send(proto::JoinProjectResponse {
1665        project_id: project.id.0 as u64,
1666        worktrees: worktrees.clone(),
1667        replica_id: replica_id.0 as u32,
1668        collaborators: collaborators.clone(),
1669        language_servers: project.language_servers.clone(),
1670        role: project.role.into(), // todo
1671    })?;
1672
1673    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1674        #[cfg(any(test, feature = "test-support"))]
1675        const MAX_CHUNK_SIZE: usize = 2;
1676        #[cfg(not(any(test, feature = "test-support")))]
1677        const MAX_CHUNK_SIZE: usize = 256;
1678
1679        // Stream this worktree's entries.
1680        let message = proto::UpdateWorktree {
1681            project_id: project_id.to_proto(),
1682            worktree_id,
1683            abs_path: worktree.abs_path.clone(),
1684            root_name: worktree.root_name,
1685            updated_entries: worktree.entries,
1686            removed_entries: Default::default(),
1687            scan_id: worktree.scan_id,
1688            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1689            updated_repositories: worktree.repository_entries.into_values().collect(),
1690            removed_repositories: Default::default(),
1691        };
1692        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1693            session.peer.send(session.connection_id, update.clone())?;
1694        }
1695
1696        // Stream this worktree's diagnostics.
1697        for summary in worktree.diagnostic_summaries {
1698            session.peer.send(
1699                session.connection_id,
1700                proto::UpdateDiagnosticSummary {
1701                    project_id: project_id.to_proto(),
1702                    worktree_id: worktree.id,
1703                    summary: Some(summary),
1704                },
1705            )?;
1706        }
1707
1708        for settings_file in worktree.settings_files {
1709            session.peer.send(
1710                session.connection_id,
1711                proto::UpdateWorktreeSettings {
1712                    project_id: project_id.to_proto(),
1713                    worktree_id: worktree.id,
1714                    path: settings_file.path,
1715                    content: Some(settings_file.content),
1716                },
1717            )?;
1718        }
1719    }
1720
1721    for language_server in &project.language_servers {
1722        session.peer.send(
1723            session.connection_id,
1724            proto::UpdateLanguageServer {
1725                project_id: project_id.to_proto(),
1726                language_server_id: language_server.id,
1727                variant: Some(
1728                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1729                        proto::LspDiskBasedDiagnosticsUpdated {},
1730                    ),
1731                ),
1732            },
1733        )?;
1734    }
1735
1736    Ok(())
1737}
1738
1739/// Leave someone elses shared project.
1740async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1741    let sender_id = session.connection_id;
1742    let project_id = ProjectId::from_proto(request.project_id);
1743    let db = session.db().await;
1744    if db.is_hosted_project(project_id).await? {
1745        let project = db.leave_hosted_project(project_id, sender_id).await?;
1746        project_left(&project, &session);
1747        return Ok(());
1748    }
1749
1750    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1751    tracing::info!(
1752        %project_id,
1753        host_user_id = ?project.host_user_id,
1754        host_connection_id = ?project.host_connection_id,
1755        "leave project"
1756    );
1757
1758    project_left(&project, &session);
1759    room_updated(&room, &session.peer);
1760
1761    Ok(())
1762}
1763
1764async fn join_hosted_project(
1765    request: proto::JoinHostedProject,
1766    response: Response<proto::JoinHostedProject>,
1767    session: Session,
1768) -> Result<()> {
1769    let (mut project, replica_id) = session
1770        .db()
1771        .await
1772        .join_hosted_project(
1773            ProjectId(request.project_id as i32),
1774            session.user_id,
1775            session.connection_id,
1776        )
1777        .await?;
1778
1779    join_project_internal(response, session, &mut project, &replica_id)
1780}
1781
1782/// Updates other participants with changes to the project
1783async fn update_project(
1784    request: proto::UpdateProject,
1785    response: Response<proto::UpdateProject>,
1786    session: Session,
1787) -> Result<()> {
1788    let project_id = ProjectId::from_proto(request.project_id);
1789    let (room, guest_connection_ids) = &*session
1790        .db()
1791        .await
1792        .update_project(project_id, session.connection_id, &request.worktrees)
1793        .await?;
1794    broadcast(
1795        Some(session.connection_id),
1796        guest_connection_ids.iter().copied(),
1797        |connection_id| {
1798            session
1799                .peer
1800                .forward_send(session.connection_id, connection_id, request.clone())
1801        },
1802    );
1803    room_updated(&room, &session.peer);
1804    response.send(proto::Ack {})?;
1805
1806    Ok(())
1807}
1808
1809/// Updates other participants with changes to the worktree
1810async fn update_worktree(
1811    request: proto::UpdateWorktree,
1812    response: Response<proto::UpdateWorktree>,
1813    session: Session,
1814) -> Result<()> {
1815    let guest_connection_ids = session
1816        .db()
1817        .await
1818        .update_worktree(&request, session.connection_id)
1819        .await?;
1820
1821    broadcast(
1822        Some(session.connection_id),
1823        guest_connection_ids.iter().copied(),
1824        |connection_id| {
1825            session
1826                .peer
1827                .forward_send(session.connection_id, connection_id, request.clone())
1828        },
1829    );
1830    response.send(proto::Ack {})?;
1831    Ok(())
1832}
1833
1834/// Updates other participants with changes to the diagnostics
1835async fn update_diagnostic_summary(
1836    message: proto::UpdateDiagnosticSummary,
1837    session: Session,
1838) -> Result<()> {
1839    let guest_connection_ids = session
1840        .db()
1841        .await
1842        .update_diagnostic_summary(&message, session.connection_id)
1843        .await?;
1844
1845    broadcast(
1846        Some(session.connection_id),
1847        guest_connection_ids.iter().copied(),
1848        |connection_id| {
1849            session
1850                .peer
1851                .forward_send(session.connection_id, connection_id, message.clone())
1852        },
1853    );
1854
1855    Ok(())
1856}
1857
1858/// Updates other participants with changes to the worktree settings
1859async fn update_worktree_settings(
1860    message: proto::UpdateWorktreeSettings,
1861    session: Session,
1862) -> Result<()> {
1863    let guest_connection_ids = session
1864        .db()
1865        .await
1866        .update_worktree_settings(&message, session.connection_id)
1867        .await?;
1868
1869    broadcast(
1870        Some(session.connection_id),
1871        guest_connection_ids.iter().copied(),
1872        |connection_id| {
1873            session
1874                .peer
1875                .forward_send(session.connection_id, connection_id, message.clone())
1876        },
1877    );
1878
1879    Ok(())
1880}
1881
1882/// Notify other participants that a  language server has started.
1883async fn start_language_server(
1884    request: proto::StartLanguageServer,
1885    session: Session,
1886) -> Result<()> {
1887    let guest_connection_ids = session
1888        .db()
1889        .await
1890        .start_language_server(&request, session.connection_id)
1891        .await?;
1892
1893    broadcast(
1894        Some(session.connection_id),
1895        guest_connection_ids.iter().copied(),
1896        |connection_id| {
1897            session
1898                .peer
1899                .forward_send(session.connection_id, connection_id, request.clone())
1900        },
1901    );
1902    Ok(())
1903}
1904
1905/// Notify other participants that a language server has changed.
1906async fn update_language_server(
1907    request: proto::UpdateLanguageServer,
1908    session: Session,
1909) -> Result<()> {
1910    let project_id = ProjectId::from_proto(request.project_id);
1911    let project_connection_ids = session
1912        .db()
1913        .await
1914        .project_connection_ids(project_id, session.connection_id)
1915        .await?;
1916    broadcast(
1917        Some(session.connection_id),
1918        project_connection_ids.iter().copied(),
1919        |connection_id| {
1920            session
1921                .peer
1922                .forward_send(session.connection_id, connection_id, request.clone())
1923        },
1924    );
1925    Ok(())
1926}
1927
1928/// forward a project request to the host. These requests should be read only
1929/// as guests are allowed to send them.
1930async fn forward_read_only_project_request<T>(
1931    request: T,
1932    response: Response<T>,
1933    session: Session,
1934) -> Result<()>
1935where
1936    T: EntityMessage + RequestMessage,
1937{
1938    let project_id = ProjectId::from_proto(request.remote_entity_id());
1939    let host_connection_id = session
1940        .db()
1941        .await
1942        .host_for_read_only_project_request(project_id, session.connection_id)
1943        .await?;
1944    let payload = session
1945        .peer
1946        .forward_request(session.connection_id, host_connection_id, request)
1947        .await?;
1948    response.send(payload)?;
1949    Ok(())
1950}
1951
1952/// forward a project request to the host. These requests are disallowed
1953/// for guests.
1954async fn forward_mutating_project_request<T>(
1955    request: T,
1956    response: Response<T>,
1957    session: Session,
1958) -> Result<()>
1959where
1960    T: EntityMessage + RequestMessage,
1961{
1962    let project_id = ProjectId::from_proto(request.remote_entity_id());
1963    let host_connection_id = session
1964        .db()
1965        .await
1966        .host_for_mutating_project_request(project_id, session.connection_id)
1967        .await?;
1968    let payload = session
1969        .peer
1970        .forward_request(session.connection_id, host_connection_id, request)
1971        .await?;
1972    response.send(payload)?;
1973    Ok(())
1974}
1975
1976/// Notify other participants that a new buffer has been created
1977async fn create_buffer_for_peer(
1978    request: proto::CreateBufferForPeer,
1979    session: Session,
1980) -> Result<()> {
1981    session
1982        .db()
1983        .await
1984        .check_user_is_project_host(
1985            ProjectId::from_proto(request.project_id),
1986            session.connection_id,
1987        )
1988        .await?;
1989    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1990    session
1991        .peer
1992        .forward_send(session.connection_id, peer_id.into(), request)?;
1993    Ok(())
1994}
1995
1996/// Notify other participants that a buffer has been updated. This is
1997/// allowed for guests as long as the update is limited to selections.
1998async fn update_buffer(
1999    request: proto::UpdateBuffer,
2000    response: Response<proto::UpdateBuffer>,
2001    session: Session,
2002) -> Result<()> {
2003    let project_id = ProjectId::from_proto(request.project_id);
2004    let mut guest_connection_ids;
2005    let mut host_connection_id = None;
2006
2007    let mut requires_write_permission = false;
2008
2009    for op in request.operations.iter() {
2010        match op.variant {
2011            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2012            Some(_) => requires_write_permission = true,
2013        }
2014    }
2015
2016    {
2017        let collaborators = session
2018            .db()
2019            .await
2020            .project_collaborators_for_buffer_update(
2021                project_id,
2022                session.connection_id,
2023                requires_write_permission,
2024            )
2025            .await?;
2026        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2027        for collaborator in collaborators.iter() {
2028            if collaborator.is_host {
2029                host_connection_id = Some(collaborator.connection_id);
2030            } else {
2031                guest_connection_ids.push(collaborator.connection_id);
2032            }
2033        }
2034    }
2035    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2036
2037    broadcast(
2038        Some(session.connection_id),
2039        guest_connection_ids,
2040        |connection_id| {
2041            session
2042                .peer
2043                .forward_send(session.connection_id, connection_id, request.clone())
2044        },
2045    );
2046    if host_connection_id != session.connection_id {
2047        session
2048            .peer
2049            .forward_request(session.connection_id, host_connection_id, request.clone())
2050            .await?;
2051    }
2052
2053    response.send(proto::Ack {})?;
2054    Ok(())
2055}
2056
2057/// Notify other participants that a project has been updated.
2058async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2059    request: T,
2060    session: Session,
2061) -> Result<()> {
2062    let project_id = ProjectId::from_proto(request.remote_entity_id());
2063    let project_connection_ids = session
2064        .db()
2065        .await
2066        .project_connection_ids(project_id, session.connection_id)
2067        .await?;
2068
2069    broadcast(
2070        Some(session.connection_id),
2071        project_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, request.clone())
2076        },
2077    );
2078    Ok(())
2079}
2080
2081/// Start following another user in a call.
2082async fn follow(
2083    request: proto::Follow,
2084    response: Response<proto::Follow>,
2085    session: Session,
2086) -> Result<()> {
2087    let room_id = RoomId::from_proto(request.room_id);
2088    let project_id = request.project_id.map(ProjectId::from_proto);
2089    let leader_id = request
2090        .leader_id
2091        .ok_or_else(|| anyhow!("invalid leader id"))?
2092        .into();
2093    let follower_id = session.connection_id;
2094
2095    session
2096        .db()
2097        .await
2098        .check_room_participants(room_id, leader_id, session.connection_id)
2099        .await?;
2100
2101    let response_payload = session
2102        .peer
2103        .forward_request(session.connection_id, leader_id, request)
2104        .await?;
2105    response.send(response_payload)?;
2106
2107    if let Some(project_id) = project_id {
2108        let room = session
2109            .db()
2110            .await
2111            .follow(room_id, project_id, leader_id, follower_id)
2112            .await?;
2113        room_updated(&room, &session.peer);
2114    }
2115
2116    Ok(())
2117}
2118
2119/// Stop following another user in a call.
2120async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2121    let room_id = RoomId::from_proto(request.room_id);
2122    let project_id = request.project_id.map(ProjectId::from_proto);
2123    let leader_id = request
2124        .leader_id
2125        .ok_or_else(|| anyhow!("invalid leader id"))?
2126        .into();
2127    let follower_id = session.connection_id;
2128
2129    session
2130        .db()
2131        .await
2132        .check_room_participants(room_id, leader_id, session.connection_id)
2133        .await?;
2134
2135    session
2136        .peer
2137        .forward_send(session.connection_id, leader_id, request)?;
2138
2139    if let Some(project_id) = project_id {
2140        let room = session
2141            .db()
2142            .await
2143            .unfollow(room_id, project_id, leader_id, follower_id)
2144            .await?;
2145        room_updated(&room, &session.peer);
2146    }
2147
2148    Ok(())
2149}
2150
2151/// Notify everyone following you of your current location.
2152async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2153    let room_id = RoomId::from_proto(request.room_id);
2154    let database = session.db.lock().await;
2155
2156    let connection_ids = if let Some(project_id) = request.project_id {
2157        let project_id = ProjectId::from_proto(project_id);
2158        database
2159            .project_connection_ids(project_id, session.connection_id)
2160            .await?
2161    } else {
2162        database
2163            .room_connection_ids(room_id, session.connection_id)
2164            .await?
2165    };
2166
2167    // For now, don't send view update messages back to that view's current leader.
2168    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2169        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2170        _ => None,
2171    });
2172
2173    for connection_id in connection_ids.iter().cloned() {
2174        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2175            session
2176                .peer
2177                .forward_send(session.connection_id, connection_id, request.clone())?;
2178        }
2179    }
2180    Ok(())
2181}
2182
2183/// Get public data about users.
2184async fn get_users(
2185    request: proto::GetUsers,
2186    response: Response<proto::GetUsers>,
2187    session: Session,
2188) -> Result<()> {
2189    let user_ids = request
2190        .user_ids
2191        .into_iter()
2192        .map(UserId::from_proto)
2193        .collect();
2194    let users = session
2195        .db()
2196        .await
2197        .get_users_by_ids(user_ids)
2198        .await?
2199        .into_iter()
2200        .map(|user| proto::User {
2201            id: user.id.to_proto(),
2202            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2203            github_login: user.github_login,
2204        })
2205        .collect();
2206    response.send(proto::UsersResponse { users })?;
2207    Ok(())
2208}
2209
2210/// Search for users (to invite) buy Github login
2211async fn fuzzy_search_users(
2212    request: proto::FuzzySearchUsers,
2213    response: Response<proto::FuzzySearchUsers>,
2214    session: Session,
2215) -> Result<()> {
2216    let query = request.query;
2217    let users = match query.len() {
2218        0 => vec![],
2219        1 | 2 => session
2220            .db()
2221            .await
2222            .get_user_by_github_login(&query)
2223            .await?
2224            .into_iter()
2225            .collect(),
2226        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2227    };
2228    let users = users
2229        .into_iter()
2230        .filter(|user| user.id != session.user_id)
2231        .map(|user| proto::User {
2232            id: user.id.to_proto(),
2233            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2234            github_login: user.github_login,
2235        })
2236        .collect();
2237    response.send(proto::UsersResponse { users })?;
2238    Ok(())
2239}
2240
2241/// Send a contact request to another user.
2242async fn request_contact(
2243    request: proto::RequestContact,
2244    response: Response<proto::RequestContact>,
2245    session: Session,
2246) -> Result<()> {
2247    let requester_id = session.user_id;
2248    let responder_id = UserId::from_proto(request.responder_id);
2249    if requester_id == responder_id {
2250        return Err(anyhow!("cannot add yourself as a contact"))?;
2251    }
2252
2253    let notifications = session
2254        .db()
2255        .await
2256        .send_contact_request(requester_id, responder_id)
2257        .await?;
2258
2259    // Update outgoing contact requests of requester
2260    let mut update = proto::UpdateContacts::default();
2261    update.outgoing_requests.push(responder_id.to_proto());
2262    for connection_id in session
2263        .connection_pool()
2264        .await
2265        .user_connection_ids(requester_id)
2266    {
2267        session.peer.send(connection_id, update.clone())?;
2268    }
2269
2270    // Update incoming contact requests of responder
2271    let mut update = proto::UpdateContacts::default();
2272    update
2273        .incoming_requests
2274        .push(proto::IncomingContactRequest {
2275            requester_id: requester_id.to_proto(),
2276        });
2277    let connection_pool = session.connection_pool().await;
2278    for connection_id in connection_pool.user_connection_ids(responder_id) {
2279        session.peer.send(connection_id, update.clone())?;
2280    }
2281
2282    send_notifications(&connection_pool, &session.peer, notifications);
2283
2284    response.send(proto::Ack {})?;
2285    Ok(())
2286}
2287
2288/// Accept or decline a contact request
2289async fn respond_to_contact_request(
2290    request: proto::RespondToContactRequest,
2291    response: Response<proto::RespondToContactRequest>,
2292    session: Session,
2293) -> Result<()> {
2294    let responder_id = session.user_id;
2295    let requester_id = UserId::from_proto(request.requester_id);
2296    let db = session.db().await;
2297    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2298        db.dismiss_contact_notification(responder_id, requester_id)
2299            .await?;
2300    } else {
2301        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2302
2303        let notifications = db
2304            .respond_to_contact_request(responder_id, requester_id, accept)
2305            .await?;
2306        let requester_busy = db.is_user_busy(requester_id).await?;
2307        let responder_busy = db.is_user_busy(responder_id).await?;
2308
2309        let pool = session.connection_pool().await;
2310        // Update responder with new contact
2311        let mut update = proto::UpdateContacts::default();
2312        if accept {
2313            update
2314                .contacts
2315                .push(contact_for_user(requester_id, requester_busy, &pool));
2316        }
2317        update
2318            .remove_incoming_requests
2319            .push(requester_id.to_proto());
2320        for connection_id in pool.user_connection_ids(responder_id) {
2321            session.peer.send(connection_id, update.clone())?;
2322        }
2323
2324        // Update requester with new contact
2325        let mut update = proto::UpdateContacts::default();
2326        if accept {
2327            update
2328                .contacts
2329                .push(contact_for_user(responder_id, responder_busy, &pool));
2330        }
2331        update
2332            .remove_outgoing_requests
2333            .push(responder_id.to_proto());
2334
2335        for connection_id in pool.user_connection_ids(requester_id) {
2336            session.peer.send(connection_id, update.clone())?;
2337        }
2338
2339        send_notifications(&pool, &session.peer, notifications);
2340    }
2341
2342    response.send(proto::Ack {})?;
2343    Ok(())
2344}
2345
2346/// Remove a contact.
2347async fn remove_contact(
2348    request: proto::RemoveContact,
2349    response: Response<proto::RemoveContact>,
2350    session: Session,
2351) -> Result<()> {
2352    let requester_id = session.user_id;
2353    let responder_id = UserId::from_proto(request.user_id);
2354    let db = session.db().await;
2355    let (contact_accepted, deleted_notification_id) =
2356        db.remove_contact(requester_id, responder_id).await?;
2357
2358    let pool = session.connection_pool().await;
2359    // Update outgoing contact requests of requester
2360    let mut update = proto::UpdateContacts::default();
2361    if contact_accepted {
2362        update.remove_contacts.push(responder_id.to_proto());
2363    } else {
2364        update
2365            .remove_outgoing_requests
2366            .push(responder_id.to_proto());
2367    }
2368    for connection_id in pool.user_connection_ids(requester_id) {
2369        session.peer.send(connection_id, update.clone())?;
2370    }
2371
2372    // Update incoming contact requests of responder
2373    let mut update = proto::UpdateContacts::default();
2374    if contact_accepted {
2375        update.remove_contacts.push(requester_id.to_proto());
2376    } else {
2377        update
2378            .remove_incoming_requests
2379            .push(requester_id.to_proto());
2380    }
2381    for connection_id in pool.user_connection_ids(responder_id) {
2382        session.peer.send(connection_id, update.clone())?;
2383        if let Some(notification_id) = deleted_notification_id {
2384            session.peer.send(
2385                connection_id,
2386                proto::DeleteNotification {
2387                    notification_id: notification_id.to_proto(),
2388                },
2389            )?;
2390        }
2391    }
2392
2393    response.send(proto::Ack {})?;
2394    Ok(())
2395}
2396
2397/// Creates a new channel.
2398async fn create_channel(
2399    request: proto::CreateChannel,
2400    response: Response<proto::CreateChannel>,
2401    session: Session,
2402) -> Result<()> {
2403    let db = session.db().await;
2404
2405    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2406    let (channel, owner, channel_members) = db
2407        .create_channel(&request.name, parent_id, session.user_id)
2408        .await?;
2409
2410    response.send(proto::CreateChannelResponse {
2411        channel: Some(channel.to_proto()),
2412        parent_id: request.parent_id,
2413    })?;
2414
2415    let connection_pool = session.connection_pool().await;
2416    if let Some(owner) = owner {
2417        let update = proto::UpdateUserChannels {
2418            channel_memberships: vec![proto::ChannelMembership {
2419                channel_id: owner.channel_id.to_proto(),
2420                role: owner.role.into(),
2421            }],
2422            ..Default::default()
2423        };
2424        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2425            session.peer.send(connection_id, update.clone())?;
2426        }
2427    }
2428
2429    for channel_member in channel_members {
2430        if !channel_member.role.can_see_channel(channel.visibility) {
2431            continue;
2432        }
2433
2434        let update = proto::UpdateChannels {
2435            channels: vec![channel.to_proto()],
2436            ..Default::default()
2437        };
2438        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2439            session.peer.send(connection_id, update.clone())?;
2440        }
2441    }
2442
2443    Ok(())
2444}
2445
2446/// Delete a channel
2447async fn delete_channel(
2448    request: proto::DeleteChannel,
2449    response: Response<proto::DeleteChannel>,
2450    session: Session,
2451) -> Result<()> {
2452    let db = session.db().await;
2453
2454    let channel_id = request.channel_id;
2455    let (removed_channels, member_ids) = db
2456        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2457        .await?;
2458    response.send(proto::Ack {})?;
2459
2460    // Notify members of removed channels
2461    let mut update = proto::UpdateChannels::default();
2462    update
2463        .delete_channels
2464        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2465
2466    let connection_pool = session.connection_pool().await;
2467    for member_id in member_ids {
2468        for connection_id in connection_pool.user_connection_ids(member_id) {
2469            session.peer.send(connection_id, update.clone())?;
2470        }
2471    }
2472
2473    Ok(())
2474}
2475
2476/// Invite someone to join a channel.
2477async fn invite_channel_member(
2478    request: proto::InviteChannelMember,
2479    response: Response<proto::InviteChannelMember>,
2480    session: Session,
2481) -> Result<()> {
2482    let db = session.db().await;
2483    let channel_id = ChannelId::from_proto(request.channel_id);
2484    let invitee_id = UserId::from_proto(request.user_id);
2485    let InviteMemberResult {
2486        channel,
2487        notifications,
2488    } = db
2489        .invite_channel_member(
2490            channel_id,
2491            invitee_id,
2492            session.user_id,
2493            request.role().into(),
2494        )
2495        .await?;
2496
2497    let update = proto::UpdateChannels {
2498        channel_invitations: vec![channel.to_proto()],
2499        ..Default::default()
2500    };
2501
2502    let connection_pool = session.connection_pool().await;
2503    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2504        session.peer.send(connection_id, update.clone())?;
2505    }
2506
2507    send_notifications(&connection_pool, &session.peer, notifications);
2508
2509    response.send(proto::Ack {})?;
2510    Ok(())
2511}
2512
2513/// remove someone from a channel
2514async fn remove_channel_member(
2515    request: proto::RemoveChannelMember,
2516    response: Response<proto::RemoveChannelMember>,
2517    session: Session,
2518) -> Result<()> {
2519    let db = session.db().await;
2520    let channel_id = ChannelId::from_proto(request.channel_id);
2521    let member_id = UserId::from_proto(request.user_id);
2522
2523    let RemoveChannelMemberResult {
2524        membership_update,
2525        notification_id,
2526    } = db
2527        .remove_channel_member(channel_id, member_id, session.user_id)
2528        .await?;
2529
2530    let connection_pool = &session.connection_pool().await;
2531    notify_membership_updated(
2532        &connection_pool,
2533        membership_update,
2534        member_id,
2535        &session.peer,
2536    );
2537    for connection_id in connection_pool.user_connection_ids(member_id) {
2538        if let Some(notification_id) = notification_id {
2539            session
2540                .peer
2541                .send(
2542                    connection_id,
2543                    proto::DeleteNotification {
2544                        notification_id: notification_id.to_proto(),
2545                    },
2546                )
2547                .trace_err();
2548        }
2549    }
2550
2551    response.send(proto::Ack {})?;
2552    Ok(())
2553}
2554
2555/// Toggle the channel between public and private.
2556/// Care is taken to maintain the invariant that public channels only descend from public channels,
2557/// (though members-only channels can appear at any point in the hierarchy).
2558async fn set_channel_visibility(
2559    request: proto::SetChannelVisibility,
2560    response: Response<proto::SetChannelVisibility>,
2561    session: Session,
2562) -> Result<()> {
2563    let db = session.db().await;
2564    let channel_id = ChannelId::from_proto(request.channel_id);
2565    let visibility = request.visibility().into();
2566
2567    let (channel, channel_members) = db
2568        .set_channel_visibility(channel_id, visibility, session.user_id)
2569        .await?;
2570
2571    let connection_pool = session.connection_pool().await;
2572    for member in channel_members {
2573        let update = if member.role.can_see_channel(channel.visibility) {
2574            proto::UpdateChannels {
2575                channels: vec![channel.to_proto()],
2576                ..Default::default()
2577            }
2578        } else {
2579            proto::UpdateChannels {
2580                delete_channels: vec![channel.id.to_proto()],
2581                ..Default::default()
2582            }
2583        };
2584
2585        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2586            session.peer.send(connection_id, update.clone())?;
2587        }
2588    }
2589
2590    response.send(proto::Ack {})?;
2591    Ok(())
2592}
2593
2594/// Alter the role for a user in the channel.
2595async fn set_channel_member_role(
2596    request: proto::SetChannelMemberRole,
2597    response: Response<proto::SetChannelMemberRole>,
2598    session: Session,
2599) -> Result<()> {
2600    let db = session.db().await;
2601    let channel_id = ChannelId::from_proto(request.channel_id);
2602    let member_id = UserId::from_proto(request.user_id);
2603    let result = db
2604        .set_channel_member_role(
2605            channel_id,
2606            session.user_id,
2607            member_id,
2608            request.role().into(),
2609        )
2610        .await?;
2611
2612    match result {
2613        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2614            let connection_pool = session.connection_pool().await;
2615            notify_membership_updated(
2616                &connection_pool,
2617                membership_update,
2618                member_id,
2619                &session.peer,
2620            )
2621        }
2622        db::SetMemberRoleResult::InviteUpdated(channel) => {
2623            let update = proto::UpdateChannels {
2624                channel_invitations: vec![channel.to_proto()],
2625                ..Default::default()
2626            };
2627
2628            for connection_id in session
2629                .connection_pool()
2630                .await
2631                .user_connection_ids(member_id)
2632            {
2633                session.peer.send(connection_id, update.clone())?;
2634            }
2635        }
2636    }
2637
2638    response.send(proto::Ack {})?;
2639    Ok(())
2640}
2641
2642/// Change the name of a channel
2643async fn rename_channel(
2644    request: proto::RenameChannel,
2645    response: Response<proto::RenameChannel>,
2646    session: Session,
2647) -> Result<()> {
2648    let db = session.db().await;
2649    let channel_id = ChannelId::from_proto(request.channel_id);
2650    let (channel, channel_members) = db
2651        .rename_channel(channel_id, session.user_id, &request.name)
2652        .await?;
2653
2654    response.send(proto::RenameChannelResponse {
2655        channel: Some(channel.to_proto()),
2656    })?;
2657
2658    let connection_pool = session.connection_pool().await;
2659    for channel_member in channel_members {
2660        if !channel_member.role.can_see_channel(channel.visibility) {
2661            continue;
2662        }
2663        let update = proto::UpdateChannels {
2664            channels: vec![channel.to_proto()],
2665            ..Default::default()
2666        };
2667        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2668            session.peer.send(connection_id, update.clone())?;
2669        }
2670    }
2671
2672    Ok(())
2673}
2674
2675/// Move a channel to a new parent.
2676async fn move_channel(
2677    request: proto::MoveChannel,
2678    response: Response<proto::MoveChannel>,
2679    session: Session,
2680) -> Result<()> {
2681    let channel_id = ChannelId::from_proto(request.channel_id);
2682    let to = ChannelId::from_proto(request.to);
2683
2684    let (channels, channel_members) = session
2685        .db()
2686        .await
2687        .move_channel(channel_id, to, session.user_id)
2688        .await?;
2689
2690    let connection_pool = session.connection_pool().await;
2691    for member in channel_members {
2692        let channels = channels
2693            .iter()
2694            .filter_map(|channel| {
2695                if member.role.can_see_channel(channel.visibility) {
2696                    Some(channel.to_proto())
2697                } else {
2698                    None
2699                }
2700            })
2701            .collect::<Vec<_>>();
2702        if channels.is_empty() {
2703            continue;
2704        }
2705
2706        let update = proto::UpdateChannels {
2707            channels,
2708            ..Default::default()
2709        };
2710
2711        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2712            session.peer.send(connection_id, update.clone())?;
2713        }
2714    }
2715
2716    response.send(Ack {})?;
2717    Ok(())
2718}
2719
2720/// Get the list of channel members
2721async fn get_channel_members(
2722    request: proto::GetChannelMembers,
2723    response: Response<proto::GetChannelMembers>,
2724    session: Session,
2725) -> Result<()> {
2726    let db = session.db().await;
2727    let channel_id = ChannelId::from_proto(request.channel_id);
2728    let members = db
2729        .get_channel_participant_details(channel_id, session.user_id)
2730        .await?;
2731    response.send(proto::GetChannelMembersResponse { members })?;
2732    Ok(())
2733}
2734
2735/// Accept or decline a channel invitation.
2736async fn respond_to_channel_invite(
2737    request: proto::RespondToChannelInvite,
2738    response: Response<proto::RespondToChannelInvite>,
2739    session: Session,
2740) -> Result<()> {
2741    let db = session.db().await;
2742    let channel_id = ChannelId::from_proto(request.channel_id);
2743    let RespondToChannelInvite {
2744        membership_update,
2745        notifications,
2746    } = db
2747        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2748        .await?;
2749
2750    let connection_pool = session.connection_pool().await;
2751    if let Some(membership_update) = membership_update {
2752        notify_membership_updated(
2753            &connection_pool,
2754            membership_update,
2755            session.user_id,
2756            &session.peer,
2757        );
2758    } else {
2759        let update = proto::UpdateChannels {
2760            remove_channel_invitations: vec![channel_id.to_proto()],
2761            ..Default::default()
2762        };
2763
2764        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2765            session.peer.send(connection_id, update.clone())?;
2766        }
2767    };
2768
2769    send_notifications(&connection_pool, &session.peer, notifications);
2770
2771    response.send(proto::Ack {})?;
2772
2773    Ok(())
2774}
2775
2776/// Join the channels' room
2777async fn join_channel(
2778    request: proto::JoinChannel,
2779    response: Response<proto::JoinChannel>,
2780    session: Session,
2781) -> Result<()> {
2782    let channel_id = ChannelId::from_proto(request.channel_id);
2783    join_channel_internal(channel_id, Box::new(response), session).await
2784}
2785
/// A responder that can answer a room join with a `proto::JoinRoomResponse`.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with the same payload type, and
/// this trait lets `join_channel_internal` reply to either one.
trait JoinChannelInternalResponse {
    /// Consume the responder and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): this relies on `Response` having an inherent `send`, which path
        // resolution prefers over this trait method; otherwise this would recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2799
/// Shared implementation for joining a channel's room.
///
/// The steps run in this order:
/// 1. Leave whatever room this session may currently be in (`leave_room_for_session`).
/// 2. Join the channel in the database.
/// 3. Build LiveKit credentials.
/// 4. Reply through `response`.
/// 5. Notify about any membership change.
/// 6. Broadcast the room update.
/// 7. Once the scoped guards are dropped, broadcast the channel update and refresh the
///    user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so `db` and the first connection-pool guard are dropped before
    // `channel_updated` below acquires the pool again.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when there is no LiveKit client configured. A token error also makes this
        // closure return `None` (via `trace_err()?`), so the join still succeeds, just
        // without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; every other role gets a
            // regular room token with publishing allowed.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joiner before fanning updates out to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell the channel's members about the room's new state (the pool is re-acquired here).
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2875
2876/// Start editing the channel notes
2877async fn join_channel_buffer(
2878    request: proto::JoinChannelBuffer,
2879    response: Response<proto::JoinChannelBuffer>,
2880    session: Session,
2881) -> Result<()> {
2882    let db = session.db().await;
2883    let channel_id = ChannelId::from_proto(request.channel_id);
2884
2885    let open_response = db
2886        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2887        .await?;
2888
2889    let collaborators = open_response.collaborators.clone();
2890    response.send(open_response)?;
2891
2892    let update = UpdateChannelBufferCollaborators {
2893        channel_id: channel_id.to_proto(),
2894        collaborators: collaborators.clone(),
2895    };
2896    channel_buffer_updated(
2897        session.connection_id,
2898        collaborators
2899            .iter()
2900            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2901        &update,
2902        &session.peer,
2903    );
2904
2905    Ok(())
2906}
2907
2908/// Edit the channel notes
2909async fn update_channel_buffer(
2910    request: proto::UpdateChannelBuffer,
2911    session: Session,
2912) -> Result<()> {
2913    let db = session.db().await;
2914    let channel_id = ChannelId::from_proto(request.channel_id);
2915
2916    let (collaborators, non_collaborators, epoch, version) = db
2917        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2918        .await?;
2919
2920    channel_buffer_updated(
2921        session.connection_id,
2922        collaborators,
2923        &proto::UpdateChannelBuffer {
2924            channel_id: channel_id.to_proto(),
2925            operations: request.operations,
2926        },
2927        &session.peer,
2928    );
2929
2930    let pool = &*session.connection_pool().await;
2931
2932    broadcast(
2933        None,
2934        non_collaborators
2935            .iter()
2936            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2937        |peer_id| {
2938            session.peer.send(
2939                peer_id,
2940                proto::UpdateChannels {
2941                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2942                        channel_id: channel_id.to_proto(),
2943                        epoch: epoch as u64,
2944                        version: version.clone(),
2945                    }],
2946                    ..Default::default()
2947                },
2948            )
2949        },
2950    );
2951
2952    Ok(())
2953}
2954
2955/// Rejoin the channel notes after a connection blip
2956async fn rejoin_channel_buffers(
2957    request: proto::RejoinChannelBuffers,
2958    response: Response<proto::RejoinChannelBuffers>,
2959    session: Session,
2960) -> Result<()> {
2961    let db = session.db().await;
2962    let buffers = db
2963        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2964        .await?;
2965
2966    for rejoined_buffer in &buffers {
2967        let collaborators_to_notify = rejoined_buffer
2968            .buffer
2969            .collaborators
2970            .iter()
2971            .filter_map(|c| Some(c.peer_id?.into()));
2972        channel_buffer_updated(
2973            session.connection_id,
2974            collaborators_to_notify,
2975            &proto::UpdateChannelBufferCollaborators {
2976                channel_id: rejoined_buffer.buffer.channel_id,
2977                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2978            },
2979            &session.peer,
2980        );
2981    }
2982
2983    response.send(proto::RejoinChannelBuffersResponse {
2984        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2985    })?;
2986
2987    Ok(())
2988}
2989
2990/// Stop editing the channel notes
2991async fn leave_channel_buffer(
2992    request: proto::LeaveChannelBuffer,
2993    response: Response<proto::LeaveChannelBuffer>,
2994    session: Session,
2995) -> Result<()> {
2996    let db = session.db().await;
2997    let channel_id = ChannelId::from_proto(request.channel_id);
2998
2999    let left_buffer = db
3000        .leave_channel_buffer(channel_id, session.connection_id)
3001        .await?;
3002
3003    response.send(Ack {})?;
3004
3005    channel_buffer_updated(
3006        session.connection_id,
3007        left_buffer.connections,
3008        &proto::UpdateChannelBufferCollaborators {
3009            channel_id: channel_id.to_proto(),
3010            collaborators: left_buffer.collaborators,
3011        },
3012        &session.peer,
3013    );
3014
3015    Ok(())
3016}
3017
3018fn channel_buffer_updated<T: EnvelopedMessage>(
3019    sender_id: ConnectionId,
3020    collaborators: impl IntoIterator<Item = ConnectionId>,
3021    message: &T,
3022    peer: &Peer,
3023) {
3024    broadcast(Some(sender_id), collaborators, |peer_id| {
3025        peer.send(peer_id, message.clone())
3026    });
3027}
3028
3029fn send_notifications(
3030    connection_pool: &ConnectionPool,
3031    peer: &Peer,
3032    notifications: db::NotificationBatch,
3033) {
3034    for (user_id, notification) in notifications {
3035        for connection_id in connection_pool.user_connection_ids(user_id) {
3036            if let Err(error) = peer.send(
3037                connection_id,
3038                proto::AddNotification {
3039                    notification: Some(notification.clone()),
3040                },
3041            ) {
3042                tracing::error!(
3043                    "failed to send notification to {:?} {}",
3044                    connection_id,
3045                    error
3046                );
3047            }
3048        }
3049    }
3050}
3051
3052/// Send a message to the channel
3053async fn send_channel_message(
3054    request: proto::SendChannelMessage,
3055    response: Response<proto::SendChannelMessage>,
3056    session: Session,
3057) -> Result<()> {
3058    // Validate the message body.
3059    let body = request.body.trim().to_string();
3060    if body.len() > MAX_MESSAGE_LEN {
3061        return Err(anyhow!("message is too long"))?;
3062    }
3063    if body.is_empty() {
3064        return Err(anyhow!("message can't be blank"))?;
3065    }
3066
3067    // TODO: adjust mentions if body is trimmed
3068
3069    let timestamp = OffsetDateTime::now_utc();
3070    let nonce = request
3071        .nonce
3072        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3073
3074    let channel_id = ChannelId::from_proto(request.channel_id);
3075    let CreatedChannelMessage {
3076        message_id,
3077        participant_connection_ids,
3078        channel_members,
3079        notifications,
3080    } = session
3081        .db()
3082        .await
3083        .create_channel_message(
3084            channel_id,
3085            session.user_id,
3086            &body,
3087            &request.mentions,
3088            timestamp,
3089            nonce.clone().into(),
3090            match request.reply_to_message_id {
3091                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3092                None => None,
3093            },
3094        )
3095        .await?;
3096    let message = proto::ChannelMessage {
3097        sender_id: session.user_id.to_proto(),
3098        id: message_id.to_proto(),
3099        body,
3100        mentions: request.mentions,
3101        timestamp: timestamp.unix_timestamp() as u64,
3102        nonce: Some(nonce),
3103        reply_to_message_id: request.reply_to_message_id,
3104    };
3105    broadcast(
3106        Some(session.connection_id),
3107        participant_connection_ids,
3108        |connection| {
3109            session.peer.send(
3110                connection,
3111                proto::ChannelMessageSent {
3112                    channel_id: channel_id.to_proto(),
3113                    message: Some(message.clone()),
3114                },
3115            )
3116        },
3117    );
3118    response.send(proto::SendChannelMessageResponse {
3119        message: Some(message),
3120    })?;
3121
3122    let pool = &*session.connection_pool().await;
3123    broadcast(
3124        None,
3125        channel_members
3126            .iter()
3127            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3128        |peer_id| {
3129            session.peer.send(
3130                peer_id,
3131                proto::UpdateChannels {
3132                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3133                        channel_id: channel_id.to_proto(),
3134                        message_id: message_id.to_proto(),
3135                    }],
3136                    ..Default::default()
3137                },
3138            )
3139        },
3140    );
3141    send_notifications(pool, &session.peer, notifications);
3142
3143    Ok(())
3144}
3145
3146/// Delete a channel message
3147async fn remove_channel_message(
3148    request: proto::RemoveChannelMessage,
3149    response: Response<proto::RemoveChannelMessage>,
3150    session: Session,
3151) -> Result<()> {
3152    let channel_id = ChannelId::from_proto(request.channel_id);
3153    let message_id = MessageId::from_proto(request.message_id);
3154    let connection_ids = session
3155        .db()
3156        .await
3157        .remove_channel_message(channel_id, message_id, session.user_id)
3158        .await?;
3159    broadcast(Some(session.connection_id), connection_ids, |connection| {
3160        session.peer.send(connection, request.clone())
3161    });
3162    response.send(proto::Ack {})?;
3163    Ok(())
3164}
3165
3166/// Mark a channel message as read
3167async fn acknowledge_channel_message(
3168    request: proto::AckChannelMessage,
3169    session: Session,
3170) -> Result<()> {
3171    let channel_id = ChannelId::from_proto(request.channel_id);
3172    let message_id = MessageId::from_proto(request.message_id);
3173    let notifications = session
3174        .db()
3175        .await
3176        .observe_channel_message(channel_id, session.user_id, message_id)
3177        .await?;
3178    send_notifications(
3179        &*session.connection_pool().await,
3180        &session.peer,
3181        notifications,
3182    );
3183    Ok(())
3184}
3185
3186/// Mark a buffer version as synced
3187async fn acknowledge_buffer_version(
3188    request: proto::AckBufferOperation,
3189    session: Session,
3190) -> Result<()> {
3191    let buffer_id = BufferId::from_proto(request.buffer_id);
3192    session
3193        .db()
3194        .await
3195        .observe_buffer_version(
3196            buffer_id,
3197            session.user_id,
3198            request.epoch as i32,
3199            &request.version,
3200        )
3201        .await?;
3202    Ok(())
3203}
3204
3205/// Start receiving chat updates for a channel
3206async fn join_channel_chat(
3207    request: proto::JoinChannelChat,
3208    response: Response<proto::JoinChannelChat>,
3209    session: Session,
3210) -> Result<()> {
3211    let channel_id = ChannelId::from_proto(request.channel_id);
3212
3213    let db = session.db().await;
3214    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3215        .await?;
3216    let messages = db
3217        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3218        .await?;
3219    response.send(proto::JoinChannelChatResponse {
3220        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3221        messages,
3222    })?;
3223    Ok(())
3224}
3225
3226/// Stop receiving chat updates for a channel
3227async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3228    let channel_id = ChannelId::from_proto(request.channel_id);
3229    session
3230        .db()
3231        .await
3232        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3233        .await?;
3234    Ok(())
3235}
3236
3237/// Retrieve the chat history for a channel
3238async fn get_channel_messages(
3239    request: proto::GetChannelMessages,
3240    response: Response<proto::GetChannelMessages>,
3241    session: Session,
3242) -> Result<()> {
3243    let channel_id = ChannelId::from_proto(request.channel_id);
3244    let messages = session
3245        .db()
3246        .await
3247        .get_channel_messages(
3248            channel_id,
3249            session.user_id,
3250            MESSAGE_COUNT_PER_PAGE,
3251            Some(MessageId::from_proto(request.before_message_id)),
3252        )
3253        .await?;
3254    response.send(proto::GetChannelMessagesResponse {
3255        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3256        messages,
3257    })?;
3258    Ok(())
3259}
3260
3261/// Retrieve specific chat messages
3262async fn get_channel_messages_by_id(
3263    request: proto::GetChannelMessagesById,
3264    response: Response<proto::GetChannelMessagesById>,
3265    session: Session,
3266) -> Result<()> {
3267    let message_ids = request
3268        .message_ids
3269        .iter()
3270        .map(|id| MessageId::from_proto(*id))
3271        .collect::<Vec<_>>();
3272    let messages = session
3273        .db()
3274        .await
3275        .get_channel_messages_by_id(session.user_id, &message_ids)
3276        .await?;
3277    response.send(proto::GetChannelMessagesResponse {
3278        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3279        messages,
3280    })?;
3281    Ok(())
3282}
3283
3284/// Retrieve the current users notifications
3285async fn get_notifications(
3286    request: proto::GetNotifications,
3287    response: Response<proto::GetNotifications>,
3288    session: Session,
3289) -> Result<()> {
3290    let notifications = session
3291        .db()
3292        .await
3293        .get_notifications(
3294            session.user_id,
3295            NOTIFICATION_COUNT_PER_PAGE,
3296            request
3297                .before_id
3298                .map(|id| db::NotificationId::from_proto(id)),
3299        )
3300        .await?;
3301    response.send(proto::GetNotificationsResponse {
3302        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3303        notifications,
3304    })?;
3305    Ok(())
3306}
3307
3308/// Mark notifications as read
3309async fn mark_notification_as_read(
3310    request: proto::MarkNotificationRead,
3311    response: Response<proto::MarkNotificationRead>,
3312    session: Session,
3313) -> Result<()> {
3314    let database = &session.db().await;
3315    let notifications = database
3316        .mark_notification_as_read_by_id(
3317            session.user_id,
3318            NotificationId::from_proto(request.notification_id),
3319        )
3320        .await?;
3321    send_notifications(
3322        &*session.connection_pool().await,
3323        &session.peer,
3324        notifications,
3325    );
3326    response.send(proto::Ack {})?;
3327    Ok(())
3328}
3329
3330/// Get the current users information
3331async fn get_private_user_info(
3332    _request: proto::GetPrivateUserInfo,
3333    response: Response<proto::GetPrivateUserInfo>,
3334    session: Session,
3335) -> Result<()> {
3336    let db = session.db().await;
3337
3338    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3339    let user = db
3340        .get_user_by_id(session.user_id)
3341        .await?
3342        .ok_or_else(|| anyhow!("user not found"))?;
3343    let flags = db.get_user_flags(session.user_id).await?;
3344
3345    response.send(proto::GetPrivateUserInfoResponse {
3346        metrics_id,
3347        staff: user.admin,
3348        flags,
3349    })?;
3350    Ok(())
3351}
3352
3353fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3354    match message {
3355        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3356        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3357        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3358        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3359        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3360            code: frame.code.into(),
3361            reason: frame.reason,
3362        })),
3363    }
3364}
3365
3366fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3367    match message {
3368        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3369        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3370        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3371        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3372        AxumMessage::Close(frame) => {
3373            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3374                code: frame.code.into(),
3375                reason: frame.reason,
3376            }))
3377        }
3378    }
3379}
3380
3381fn notify_membership_updated(
3382    connection_pool: &ConnectionPool,
3383    result: MembershipUpdated,
3384    user_id: UserId,
3385    peer: &Peer,
3386) {
3387    let user_channels_update = proto::UpdateUserChannels {
3388        channel_memberships: result
3389            .new_channels
3390            .channel_memberships
3391            .iter()
3392            .map(|cm| proto::ChannelMembership {
3393                channel_id: cm.channel_id.to_proto(),
3394                role: cm.role.into(),
3395            })
3396            .collect(),
3397        ..Default::default()
3398    };
3399    let mut update = build_channels_update(result.new_channels, vec![]);
3400    update.delete_channels = result
3401        .removed_channels
3402        .into_iter()
3403        .map(|id| id.to_proto())
3404        .collect();
3405    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3406
3407    for connection_id in connection_pool.user_connection_ids(user_id) {
3408        peer.send(connection_id, user_channels_update.clone())
3409            .trace_err();
3410        peer.send(connection_id, update.clone()).trace_err();
3411    }
3412}
3413
3414fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3415    proto::UpdateUserChannels {
3416        channel_memberships: channels
3417            .channel_memberships
3418            .iter()
3419            .map(|m| proto::ChannelMembership {
3420                channel_id: m.channel_id.to_proto(),
3421                role: m.role.into(),
3422            })
3423            .collect(),
3424        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3425        observed_channel_message_id: channels.observed_channel_messages.clone(),
3426    }
3427}
3428
3429fn build_channels_update(
3430    channels: ChannelsForUser,
3431    channel_invites: Vec<db::Channel>,
3432) -> proto::UpdateChannels {
3433    let mut update = proto::UpdateChannels::default();
3434
3435    for channel in channels.channels {
3436        update.channels.push(channel.to_proto());
3437    }
3438
3439    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3440    update.latest_channel_message_ids = channels.latest_channel_messages;
3441
3442    for (channel_id, participants) in channels.channel_participants {
3443        update
3444            .channel_participants
3445            .push(proto::ChannelParticipants {
3446                channel_id: channel_id.to_proto(),
3447                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3448            });
3449    }
3450
3451    for channel in channel_invites {
3452        update.channel_invitations.push(channel.to_proto());
3453    }
3454    for project in channels.hosted_projects {
3455        update.hosted_projects.push(project);
3456    }
3457
3458    update
3459}
3460
3461fn build_initial_contacts_update(
3462    contacts: Vec<db::Contact>,
3463    pool: &ConnectionPool,
3464) -> proto::UpdateContacts {
3465    let mut update = proto::UpdateContacts::default();
3466
3467    for contact in contacts {
3468        match contact {
3469            db::Contact::Accepted { user_id, busy } => {
3470                update.contacts.push(contact_for_user(user_id, busy, &pool));
3471            }
3472            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3473            db::Contact::Incoming { user_id } => {
3474                update
3475                    .incoming_requests
3476                    .push(proto::IncomingContactRequest {
3477                        requester_id: user_id.to_proto(),
3478                    })
3479            }
3480        }
3481    }
3482
3483    update
3484}
3485
3486fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3487    proto::Contact {
3488        user_id: user_id.to_proto(),
3489        online: pool.is_user_online(user_id),
3490        busy,
3491    }
3492}
3493
3494fn room_updated(room: &proto::Room, peer: &Peer) {
3495    broadcast(
3496        None,
3497        room.participants
3498            .iter()
3499            .filter_map(|participant| Some(participant.peer_id?.into())),
3500        |peer_id| {
3501            peer.send(
3502                peer_id,
3503                proto::RoomUpdated {
3504                    room: Some(room.clone()),
3505                },
3506            )
3507        },
3508    );
3509}
3510
3511fn channel_updated(
3512    channel_id: ChannelId,
3513    room: &proto::Room,
3514    channel_members: &[UserId],
3515    peer: &Peer,
3516    pool: &ConnectionPool,
3517) {
3518    let participants = room
3519        .participants
3520        .iter()
3521        .map(|p| p.user_id)
3522        .collect::<Vec<_>>();
3523
3524    broadcast(
3525        None,
3526        channel_members
3527            .iter()
3528            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3529        |peer_id| {
3530            peer.send(
3531                peer_id,
3532                proto::UpdateChannels {
3533                    channel_participants: vec![proto::ChannelParticipants {
3534                        channel_id: channel_id.to_proto(),
3535                        participant_user_ids: participants.clone(),
3536                    }],
3537                    ..Default::default()
3538                },
3539            )
3540        },
3541    );
3542}
3543
3544async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3545    let db = session.db().await;
3546
3547    let contacts = db.get_contacts(user_id).await?;
3548    let busy = db.is_user_busy(user_id).await?;
3549
3550    let pool = session.connection_pool().await;
3551    let updated_contact = contact_for_user(user_id, busy, &pool);
3552    for contact in contacts {
3553        if let db::Contact::Accepted {
3554            user_id: contact_user_id,
3555            ..
3556        } = contact
3557        {
3558            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3559                session
3560                    .peer
3561                    .send(
3562                        contact_conn_id,
3563                        proto::UpdateContacts {
3564                            contacts: vec![updated_contact.clone()],
3565                            remove_contacts: Default::default(),
3566                            incoming_requests: Default::default(),
3567                            remove_incoming_requests: Default::default(),
3568                            outgoing_requests: Default::default(),
3569                            remove_outgoing_requests: Default::default(),
3570                        },
3571                    )
3572                    .trace_err();
3573            }
3574        }
3575    }
3576    Ok(())
3577}
3578
/// Removes this session's connection from its current room, if any, and
/// notifies everyone affected.
///
/// If the database reports no room to leave, this returns `Ok(())` without
/// doing anything else. Otherwise it:
/// 1. notifies the other connections of each project the session left
///    (`project_left`) and broadcasts the updated room (`room_updated`);
/// 2. if the room belongs to a channel, sends the channel's members the new
///    participant list (`channel_updated`);
/// 3. sends `CallCanceled` to every connection of each user whose pending call
///    was cancelled;
/// 4. refreshes contact status for this user and for every cancelled callee;
/// 5. if a LiveKit client is configured, removes this user from the LiveKit
///    room and deletes that room when the database marked it as deleted.
///
/// Database errors and `update_user_contacts` errors are propagated. Send and
/// LiveKit failures are logged via `trace_err` and otherwise ignored.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialisation: these are assigned inside the `if let` below,
    // which is the only branch that falls through (the `else` returns early).
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the `session.db().await` temporary lives in the `if let`
    // scrutinee. Under pre-2024-edition temporary-scope rules it is kept alive
    // for the whole `if let` body. Confirm before restructuring (e.g. into
    // `let ... else`), because that would drop it earlier.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken out of `left_room.room` before the room itself
        // is taken, so the `room` broadcast below has `live_kit_room` reset to
        // its `Default` value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        // The connection-pool guard is a temporary here and is dropped at the
        // end of this statement.
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // This explicit scope drops the pool guard before the loop below. Each
    // `update_user_contacts` call acquires `session.connection_pool()` again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error from any `update_user_contacts` call returns early
    // via `?`, which skips the LiveKit cleanup below. Confirm this is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3655
3656async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3657    let left_channel_buffers = session
3658        .db()
3659        .await
3660        .leave_channel_buffers(session.connection_id)
3661        .await?;
3662
3663    for left_buffer in left_channel_buffers {
3664        channel_buffer_updated(
3665            session.connection_id,
3666            left_buffer.connections,
3667            &proto::UpdateChannelBufferCollaborators {
3668                channel_id: left_buffer.channel_id.to_proto(),
3669                collaborators: left_buffer.collaborators,
3670            },
3671            &session.peer,
3672        );
3673    }
3674
3675    Ok(())
3676}
3677
3678fn project_left(project: &db::LeftProject, session: &Session) {
3679    for connection_id in &project.connection_ids {
3680        if project.host_user_id == Some(session.user_id) {
3681            session
3682                .peer
3683                .send(
3684                    *connection_id,
3685                    proto::UnshareProject {
3686                        project_id: project.id.to_proto(),
3687                    },
3688                )
3689                .trace_err();
3690        } else {
3691            session
3692                .peer
3693                .send(
3694                    *connection_id,
3695                    proto::RemoveProjectCollaborator {
3696                        project_id: project.id.to_proto(),
3697                        peer_id: Some(session.connection_id.into()),
3698                    },
3699                )
3700                .trace_err();
3701        }
3702    }
3703}
3704
3705pub trait ResultExt {
3706    type Ok;
3707
3708    fn trace_err(self) -> Option<Self::Ok>;
3709}
3710
3711impl<T, E> ResultExt for Result<T, E>
3712where
3713    E: std::fmt::Debug,
3714{
3715    type Ok = T;
3716
3717    fn trace_err(self) -> Option<T> {
3718        match self {
3719            Ok(value) => Some(value),
3720            Err(error) => {
3721                tracing::error!("{:?}", error);
3722                None
3723            }
3724        }
3725    }
3726}