rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        HostedProjectId, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
   8        ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
   9        User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, Result,
  13};
  14use anyhow::anyhow;
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use futures::{
  34    channel::oneshot,
  35    future::{self, BoxFuture},
  36    stream::FuturesUnordered,
  37    FutureExt, SinkExt, StreamExt, TryStreamExt,
  38};
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc, OnceLock,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    _executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
 144pub struct Server {
 145    id: parking_lot::Mutex<ServerId>,
 146    peer: Arc<Peer>,
 147    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 148    app_state: Arc<AppState>,
 149    executor: Executor,
 150    handlers: HashMap<TypeId, MessageHandler>,
 151    teardown: watch::Sender<bool>,
 152}
 153
 154pub(crate) struct ConnectionPoolGuard<'a> {
 155    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 156    _not_send: PhantomData<Rc<()>>,
 157}
 158
 159#[derive(Serialize)]
 160pub struct ServerSnapshot<'a> {
 161    peer: &'a Peer,
 162    #[serde(serialize_with = "serialize_deref")]
 163    connection_pool: ConnectionPoolGuard<'a>,
 164}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(false).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(set_room_participant_role)
 194            .add_request_handler(call)
 195            .add_request_handler(cancel_call)
 196            .add_message_handler(decline_call)
 197            .add_request_handler(update_participant_location)
 198            .add_request_handler(share_project)
 199            .add_message_handler(unshare_project)
 200            .add_request_handler(join_project)
 201            .add_request_handler(join_hosted_project)
 202            .add_message_handler(leave_project)
 203            .add_request_handler(update_project)
 204            .add_request_handler(update_worktree)
 205            .add_message_handler(start_language_server)
 206            .add_message_handler(update_language_server)
 207            .add_message_handler(update_diagnostic_summary)
 208            .add_message_handler(update_worktree_settings)
 209            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 210            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 211            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 212            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 213            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 214            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 215            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 216            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 217            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 218            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 219            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 220            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 221            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 222            .add_request_handler(
 223                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 224            )
 225            .add_request_handler(
 226                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 227            )
 228            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 229            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 230            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 231            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 232            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 233            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 234            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 235            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 236            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 237            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 238            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 239            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 240            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 241            .add_message_handler(create_buffer_for_peer)
 242            .add_request_handler(update_buffer)
 243            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 244            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 245            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 246            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 247            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 248            .add_request_handler(get_users)
 249            .add_request_handler(fuzzy_search_users)
 250            .add_request_handler(request_contact)
 251            .add_request_handler(remove_contact)
 252            .add_request_handler(respond_to_contact_request)
 253            .add_request_handler(create_channel)
 254            .add_request_handler(delete_channel)
 255            .add_request_handler(invite_channel_member)
 256            .add_request_handler(remove_channel_member)
 257            .add_request_handler(set_channel_member_role)
 258            .add_request_handler(set_channel_visibility)
 259            .add_request_handler(rename_channel)
 260            .add_request_handler(join_channel_buffer)
 261            .add_request_handler(leave_channel_buffer)
 262            .add_message_handler(update_channel_buffer)
 263            .add_request_handler(rejoin_channel_buffers)
 264            .add_request_handler(get_channel_members)
 265            .add_request_handler(respond_to_channel_invite)
 266            .add_request_handler(join_channel)
 267            .add_request_handler(join_channel_chat)
 268            .add_message_handler(leave_channel_chat)
 269            .add_request_handler(send_channel_message)
 270            .add_request_handler(remove_channel_message)
 271            .add_request_handler(get_channel_messages)
 272            .add_request_handler(get_channel_messages_by_id)
 273            .add_request_handler(get_notifications)
 274            .add_request_handler(mark_notification_as_read)
 275            .add_request_handler(move_channel)
 276            .add_request_handler(follow)
 277            .add_message_handler(unfollow)
 278            .add_message_handler(update_followers)
 279            .add_request_handler(get_private_user_info)
 280            .add_message_handler(acknowledge_channel_message)
 281            .add_message_handler(acknowledge_buffer_version);
 282
 283        Arc::new(server)
 284    }
 285
    /// Schedules background cleanup of resources left behind by stale servers.
    ///
    /// Returns immediately; the work runs in a task detached on `self.executor`. That task
    /// first waits for [`CLEANUP_TIMEOUT`]. It then fetches the room ids and channel ids that
    /// `stale_server_resource_ids` reports for this environment and server id, and:
    /// - for each channel, clears stale channel-buffer collaborators and sends the refreshed
    ///   collaborator list to every connection still in that buffer;
    /// - for each room:
    ///   - clears stale participants and broadcasts the updated room (and its channel, if
    ///     the room belongs to one);
    ///   - sends `CallCanceled` to users whose pending calls were canceled;
    ///   - pushes a refreshed contact entry for every affected user to that user's accepted
    ///     contacts;
    ///   - deletes the LiveKit room when no participants remain and a LiveKit client is
    ///     configured.
    ///
    /// Finally the task calls `delete_stale_servers`. Database failures go through
    /// `trace_err`: the affected step is skipped instead of aborting the whole task.
    ///
    /// # Errors
    /// The visible body always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep is created here but only awaited inside the detached task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell every remaining
                    // connection the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used if the room refresh below fails: nobody to notify and
                        // no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &pool.lock(),
                                );
                            }
                            // Both removed participants and canceled callees need their
                            // contact entry refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (including busy
                        // state) to all of their accepted contacts. Both lookups must succeed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once the refreshed room has no
                        // participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 437
 438    pub fn teardown(&self) {
 439        self.peer.teardown();
 440        self.connection_pool.lock().reset();
 441        let _ = self.teardown.send(true);
 442    }
 443
 444    #[cfg(test)]
 445    pub fn reset(&self, id: ServerId) {
 446        self.teardown();
 447        *self.id.lock() = id;
 448        self.peer.reset(id.0 as u32);
 449        let _ = self.teardown.send(false);
 450    }
 451
 452    #[cfg(test)]
 453    pub fn id(&self) -> ServerId {
 454        *self.id.lock()
 455    }
 456
 457    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 458    where
 459        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 460        Fut: 'static + Send + Future<Output = Result<()>>,
 461        M: EnvelopedMessage,
 462    {
 463        let prev_handler = self.handlers.insert(
 464            TypeId::of::<M>(),
 465            Box::new(move |envelope, session| {
 466                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 467                let span = info_span!(
 468                    "handle message",
 469                    payload_type = envelope.payload_type_name()
 470                );
 471                span.in_scope(|| {
 472                    tracing::info!(
 473                        payload_type = envelope.payload_type_name(),
 474                        "message received"
 475                    );
 476                });
 477                let start_time = Instant::now();
 478                let future = (handler)(*envelope, session);
 479                async move {
 480                    let result = future.await;
 481                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 482                    match result {
 483                        Err(error) => {
 484                            tracing::error!(%error, ?duration_ms, "error handling message")
 485                        }
 486                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 487                    }
 488                }
 489                .instrument(span)
 490                .boxed()
 491            }),
 492        );
 493        if prev_handler.is_some() {
 494            panic!("registered a handler for the same message twice");
 495        }
 496        self
 497    }
 498
 499    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 500    where
 501        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 502        Fut: 'static + Send + Future<Output = Result<()>>,
 503        M: EnvelopedMessage,
 504    {
 505        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 506        self
 507    }
 508
 509    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 512        Fut: Send + Future<Output = Result<()>>,
 513        M: RequestMessage,
 514    {
 515        let handler = Arc::new(handler);
 516        self.add_handler(move |envelope, session| {
 517            let receipt = envelope.receipt();
 518            let handler = handler.clone();
 519            async move {
 520                let peer = session.peer.clone();
 521                let responded = Arc::new(AtomicBool::default());
 522                let response = Response {
 523                    peer: peer.clone(),
 524                    responded: responded.clone(),
 525                    receipt,
 526                };
 527                match (handler)(envelope.payload, response, session).await {
 528                    Ok(()) => {
 529                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 530                            Ok(())
 531                        } else {
 532                            Err(anyhow!("handler did not send a response"))?
 533                        }
 534                    }
 535                    Err(error) => {
 536                        let proto_err = match &error {
 537                            Error::Internal(err) => err.to_proto(),
 538                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 539                        };
 540                        peer.respond_with_error(receipt, proto_err)?;
 541                        Err(error)
 542                    }
 543                }
 544            }
 545        })
 546    }
 547
 548    #[allow(clippy::too_many_arguments)]
 549    pub fn handle_connection(
 550        self: &Arc<Self>,
 551        connection: Connection,
 552        address: String,
 553        user: User,
 554        zed_version: ZedVersion,
 555        impersonator: Option<User>,
 556        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 557        executor: Executor,
 558    ) -> impl Future<Output = Result<()>> {
 559        let this = self.clone();
 560        let user_id = user.id;
 561        let login = user.github_login;
 562        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
 563        if let Some(impersonator) = impersonator {
 564            span.record("impersonator", &impersonator.github_login);
 565        }
 566        let mut teardown = self.teardown.subscribe();
 567        async move {
 568            if *teardown.borrow() {
 569                return Err(anyhow!("server is tearing down"))?;
 570            }
 571            let (connection_id, handle_io, mut incoming_rx) = this
 572                .peer
 573                .add_connection(connection, {
 574                    let executor = executor.clone();
 575                    move |duration| executor.sleep(duration)
 576                });
 577
 578            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 579            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 580            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 581
 582            if let Some(send_connection_id) = send_connection_id.take() {
 583                let _ = send_connection_id.send(connection_id);
 584            }
 585
 586            if !user.connected_once {
 587                this.peer.send(connection_id, proto::ShowContacts {})?;
 588                this.app_state.db.set_user_connected_once(user_id, true).await?;
 589            }
 590
 591            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 592                this.app_state.db.get_contacts(user_id),
 593                this.app_state.db.get_channels_for_user(user_id),
 594                this.app_state.db.get_channel_invites_for_user(user_id),
 595            ).await?;
 596
 597            {
 598                let mut pool = this.connection_pool.lock();
 599                pool.add_connection(connection_id, user_id, user.admin, zed_version);
 600                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 601                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
 602                this.peer.send(connection_id, build_channels_update(
 603                    channels_for_user,
 604                    channel_invites
 605                ))?;
 606            }
 607
 608            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 609                this.peer.send(connection_id, incoming_call)?;
 610            }
 611
 612            let session = Session {
 613                user_id,
 614                connection_id,
 615                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 616                peer: this.peer.clone(),
 617                connection_pool: this.connection_pool.clone(),
 618                live_kit_client: this.app_state.live_kit_client.clone(),
 619                _executor: executor.clone()
 620            };
 621            update_user_contacts(user_id, &session).await?;
 622
 623            let handle_io = handle_io.fuse();
 624            futures::pin_mut!(handle_io);
 625
 626            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 627            // This prevents deadlocks when e.g., client A performs a request to client B and
 628            // client B performs a request to client A. If both clients stop processing further
 629            // messages until their respective request completes, they won't have a chance to
 630            // respond to the other client's request and cause a deadlock.
 631            //
 632            // This arrangement ensures we will attempt to process earlier messages first, but fall
 633            // back to processing messages arrived later in the spirit of making progress.
 634            let mut foreground_message_handlers = FuturesUnordered::new();
 635            let concurrent_handlers = Arc::new(Semaphore::new(256));
 636            loop {
 637                let next_message = async {
 638                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 639                    let message = incoming_rx.next().await;
 640                    (permit, message)
 641                }.fuse();
 642                futures::pin_mut!(next_message);
 643                futures::select_biased! {
 644                    _ = teardown.changed().fuse() => return Ok(()),
 645                    result = handle_io => {
 646                        if let Err(error) = result {
 647                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 648                        }
 649                        break;
 650                    }
 651                    _ = foreground_message_handlers.next() => {}
 652                    next_message = next_message => {
 653                        let (permit, message) = next_message;
 654                        if let Some(message) = message {
 655                            let type_name = message.payload_type_name();
 656                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 657                            let span_enter = span.enter();
 658                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 659                                let is_background = message.is_background();
 660                                let handle_message = (handler)(message, session.clone());
 661                                drop(span_enter);
 662
 663                                let handle_message = async move {
 664                                    handle_message.await;
 665                                    drop(permit);
 666                                }.instrument(span);
 667                                if is_background {
 668                                    executor.spawn_detached(handle_message);
 669                                } else {
 670                                    foreground_message_handlers.push(handle_message);
 671                                }
 672                            } else {
 673                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 674                            }
 675                        } else {
 676                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 677                            break;
 678                        }
 679                    }
 680                }
 681            }
 682
 683            drop(foreground_message_handlers);
 684            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 685            if let Err(error) = connection_lost(session, teardown, executor).await {
 686                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 687            }
 688
 689            Ok(())
 690        }.instrument(span)
 691    }
 692
 693    pub async fn invite_code_redeemed(
 694        self: &Arc<Self>,
 695        inviter_id: UserId,
 696        invitee_id: UserId,
 697    ) -> Result<()> {
 698        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 699            if let Some(code) = &user.invite_code {
 700                let pool = self.connection_pool.lock();
 701                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 702                for connection_id in pool.user_connection_ids(inviter_id) {
 703                    self.peer.send(
 704                        connection_id,
 705                        proto::UpdateContacts {
 706                            contacts: vec![invitee_contact.clone()],
 707                            ..Default::default()
 708                        },
 709                    )?;
 710                    self.peer.send(
 711                        connection_id,
 712                        proto::UpdateInviteInfo {
 713                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 714                            count: user.invite_count as u32,
 715                        },
 716                    )?;
 717                }
 718            }
 719        }
 720        Ok(())
 721    }
 722
 723    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 724        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 725            if let Some(invite_code) = &user.invite_code {
 726                let pool = self.connection_pool.lock();
 727                for connection_id in pool.user_connection_ids(user_id) {
 728                    self.peer.send(
 729                        connection_id,
 730                        proto::UpdateInviteInfo {
 731                            url: format!(
 732                                "{}{}",
 733                                self.app_state.config.invite_link_prefix, invite_code
 734                            ),
 735                            count: user.invite_count as u32,
 736                        },
 737                    )?;
 738                }
 739            }
 740        }
 741        Ok(())
 742    }
 743
 744    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 745        ServerSnapshot {
 746            connection_pool: ConnectionPoolGuard {
 747                guard: self.connection_pool.lock(),
 748                _not_send: PhantomData,
 749            },
 750            peer: &self.peer,
 751        }
 752    }
 753}
 754
 755impl<'a> Deref for ConnectionPoolGuard<'a> {
 756    type Target = ConnectionPool;
 757
 758    fn deref(&self) -> &Self::Target {
 759        &self.guard
 760    }
 761}
 762
 763impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 764    fn deref_mut(&mut self) -> &mut Self::Target {
 765        &mut self.guard
 766    }
 767}
 768
 769impl<'a> Drop for ConnectionPoolGuard<'a> {
 770    fn drop(&mut self) {
 771        #[cfg(test)]
 772        self.check_invariants();
 773    }
 774}
 775
 776fn broadcast<F>(
 777    sender_id: Option<ConnectionId>,
 778    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 779    mut f: F,
 780) where
 781    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 782{
 783    for receiver_id in receiver_ids {
 784        if Some(receiver_id) != sender_id {
 785            if let Err(error) = f(receiver_id) {
 786                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 787            }
 788        }
 789    }
 790}
 791
 792pub struct ProtocolVersion(u32);
 793
 794impl Header for ProtocolVersion {
 795    fn name() -> &'static HeaderName {
 796        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 797        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 798    }
 799
 800    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 801    where
 802        Self: Sized,
 803        I: Iterator<Item = &'i axum::http::HeaderValue>,
 804    {
 805        let version = values
 806            .next()
 807            .ok_or_else(axum::headers::Error::invalid)?
 808            .to_str()
 809            .map_err(|_| axum::headers::Error::invalid())?
 810            .parse()
 811            .map_err(|_| axum::headers::Error::invalid())?;
 812        Ok(Self(version))
 813    }
 814
 815    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 816        values.extend([self.0.to_string().parse().unwrap()]);
 817    }
 818}
 819
 820pub struct AppVersionHeader(SemanticVersion);
 821impl Header for AppVersionHeader {
 822    fn name() -> &'static HeaderName {
 823        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 824        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 825    }
 826
 827    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 828    where
 829        Self: Sized,
 830        I: Iterator<Item = &'i axum::http::HeaderValue>,
 831    {
 832        let version = values
 833            .next()
 834            .ok_or_else(axum::headers::Error::invalid)?
 835            .to_str()
 836            .map_err(|_| axum::headers::Error::invalid())?
 837            .parse()
 838            .map_err(|_| axum::headers::Error::invalid())?;
 839        Ok(Self(version))
 840    }
 841
 842    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 843        values.extend([self.0.to_string().parse().unwrap()]);
 844    }
 845}
 846
/// Builds the RPC router.
///
/// `/rpc` (the websocket endpoint) is wrapped by the layer stack added right
/// after it: an `Extension` carrying the app state and the
/// `auth::validate_header` middleware. Axum layers only wrap routes that were
/// registered before them, so `/metrics`, which is added afterwards, is not
/// behind that middleware. The final `Extension(server)` layer applies to both
/// routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the layer above, so `auth::validate_header` does not run for it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 858
 859pub async fn handle_websocket_request(
 860    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 861    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 862    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 863    Extension(server): Extension<Arc<Server>>,
 864    Extension(user): Extension<User>,
 865    Extension(impersonator): Extension<Impersonator>,
 866    ws: WebSocketUpgrade,
 867) -> axum::response::Response {
 868    if protocol_version != rpc::PROTOCOL_VERSION {
 869        return (
 870            StatusCode::UPGRADE_REQUIRED,
 871            "client must be upgraded".to_string(),
 872        )
 873            .into_response();
 874    }
 875
 876    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 877        return (
 878            StatusCode::UPGRADE_REQUIRED,
 879            "no version header found".to_string(),
 880        )
 881            .into_response();
 882    };
 883
 884    if !version.is_supported() {
 885        return (
 886            StatusCode::UPGRADE_REQUIRED,
 887            "client must be upgraded".to_string(),
 888        )
 889            .into_response();
 890    }
 891
 892    let socket_address = socket_address.to_string();
 893    ws.on_upgrade(move |socket| {
 894        use util::ResultExt;
 895        let socket = socket
 896            .map_ok(to_tungstenite_message)
 897            .err_into()
 898            .with(|message| async move { Ok(to_axum_message(message)) });
 899        let connection = Connection::new(Box::pin(socket));
 900        async move {
 901            server
 902                .handle_connection(
 903                    connection,
 904                    socket_address,
 905                    user,
 906                    version,
 907                    impersonator.0,
 908                    None,
 909                    Executor::Production,
 910                )
 911                .await
 912                .log_err();
 913        }
 914    })
 915}
 916
 917pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 918    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 919    let connections_metric = CONNECTIONS_METRIC
 920        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 921
 922    let connections = server
 923        .connection_pool
 924        .lock()
 925        .connections()
 926        .filter(|connection| !connection.admin)
 927        .count();
 928    connections_metric.set(connections as _);
 929
 930    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 931    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 932        register_int_gauge!(
 933            "shared_projects",
 934            "number of open projects with one or more guests"
 935        )
 936        .unwrap()
 937    });
 938
 939    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 940    shared_projects_metric.set(shared_projects as _);
 941
 942    let encoder = prometheus::TextEncoder::new();
 943    let metric_families = prometheus::gather();
 944    let encoded_metrics = encoder
 945        .encode_to_string(&metric_families)
 946        .map_err(|err| anyhow!("{}", err))?;
 947    Ok(encoded_metrics)
 948}
 949
/// Cleans up after a session whose connection has gone away.
///
/// Three steps happen immediately:
/// - the peer is disconnected;
/// - the connection is removed from the connection pool;
/// - the database is told the connection was lost.
///
/// The function then waits for whichever comes first: `RECONNECT_TIMEOUT`
/// elapsing on `executor`, or a change observed on `teardown`. `select_biased!`
/// polls the timeout branch first.
///
/// Only when the timeout fires does it do the full cleanup:
/// - remove the session from its room and channel buffers;
/// - decline a call for the user if they have no other connection online;
/// - refresh the user's contacts.
///
/// Presumably this gives the client a window to reconnect before its resources
/// are released (TODO confirm).
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails, or if
/// updating contacts after the timeout fails. Other failures are only logged
/// via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Logged rather than propagated so that the rest of the cleanup still runs.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown requested before the timeout: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 995
 996/// Acknowledges a ping from a client, used to keep the connection alive.
 997async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 998    response.send(proto::Ack {})?;
 999    Ok(())
1000}
1001
1002/// Creates a new room for calling (outside of channels)
1003async fn create_room(
1004    _request: proto::CreateRoom,
1005    response: Response<proto::CreateRoom>,
1006    session: Session,
1007) -> Result<()> {
1008    let live_kit_room = nanoid::nanoid!(30);
1009
1010    let live_kit_connection_info = {
1011        let live_kit_room = live_kit_room.clone();
1012        let live_kit = session.live_kit_client.as_ref();
1013
1014        util::async_maybe!({
1015            let live_kit = live_kit?;
1016
1017            let token = live_kit
1018                .room_token(&live_kit_room, &session.user_id.to_string())
1019                .trace_err()?;
1020
1021            Some(proto::LiveKitConnectionInfo {
1022                server_url: live_kit.url().into(),
1023                token,
1024                can_publish: true,
1025            })
1026        })
1027    }
1028    .await;
1029
1030    let room = session
1031        .db()
1032        .await
1033        .create_room(session.user_id, session.connection_id, &live_kit_room)
1034        .await?;
1035
1036    response.send(proto::CreateRoomResponse {
1037        room: Some(room.clone()),
1038        live_kit_connection_info,
1039    })?;
1040
1041    update_user_contacts(session.user_id, &session).await?;
1042    Ok(())
1043}
1044
1045/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1046async fn join_room(
1047    request: proto::JoinRoom,
1048    response: Response<proto::JoinRoom>,
1049    session: Session,
1050) -> Result<()> {
1051    let room_id = RoomId::from_proto(request.id);
1052
1053    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1054
1055    if let Some(channel_id) = channel_id {
1056        return join_channel_internal(channel_id, Box::new(response), session).await;
1057    }
1058
1059    let joined_room = {
1060        let room = session
1061            .db()
1062            .await
1063            .join_room(room_id, session.user_id, session.connection_id)
1064            .await?;
1065        room_updated(&room.room, &session.peer);
1066        room.into_inner()
1067    };
1068
1069    for connection_id in session
1070        .connection_pool()
1071        .await
1072        .user_connection_ids(session.user_id)
1073    {
1074        session
1075            .peer
1076            .send(
1077                connection_id,
1078                proto::CallCanceled {
1079                    room_id: room_id.to_proto(),
1080                },
1081            )
1082            .trace_err();
1083    }
1084
1085    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1086        if let Some(token) = live_kit
1087            .room_token(
1088                &joined_room.room.live_kit_room,
1089                &session.user_id.to_string(),
1090            )
1091            .trace_err()
1092        {
1093            Some(proto::LiveKitConnectionInfo {
1094                server_url: live_kit.url().into(),
1095                token,
1096                can_publish: true,
1097            })
1098        } else {
1099            None
1100        }
1101    } else {
1102        None
1103    };
1104
1105    response.send(proto::JoinRoomResponse {
1106        room: Some(joined_room.room),
1107        channel_id: None,
1108        live_kit_connection_info,
1109    })?;
1110
1111    update_user_contacts(session.user_id, &session).await?;
1112    Ok(())
1113}
1114
1115/// Rejoin room is used to reconnect to a room after connection errors.
1116async fn rejoin_room(
1117    request: proto::RejoinRoom,
1118    response: Response<proto::RejoinRoom>,
1119    session: Session,
1120) -> Result<()> {
1121    let room;
1122    let channel_id;
1123    let channel_members;
1124    {
1125        let mut rejoined_room = session
1126            .db()
1127            .await
1128            .rejoin_room(request, session.user_id, session.connection_id)
1129            .await?;
1130
1131        response.send(proto::RejoinRoomResponse {
1132            room: Some(rejoined_room.room.clone()),
1133            reshared_projects: rejoined_room
1134                .reshared_projects
1135                .iter()
1136                .map(|project| proto::ResharedProject {
1137                    id: project.id.to_proto(),
1138                    collaborators: project
1139                        .collaborators
1140                        .iter()
1141                        .map(|collaborator| collaborator.to_proto())
1142                        .collect(),
1143                })
1144                .collect(),
1145            rejoined_projects: rejoined_room
1146                .rejoined_projects
1147                .iter()
1148                .map(|rejoined_project| proto::RejoinedProject {
1149                    id: rejoined_project.id.to_proto(),
1150                    worktrees: rejoined_project
1151                        .worktrees
1152                        .iter()
1153                        .map(|worktree| proto::WorktreeMetadata {
1154                            id: worktree.id,
1155                            root_name: worktree.root_name.clone(),
1156                            visible: worktree.visible,
1157                            abs_path: worktree.abs_path.clone(),
1158                        })
1159                        .collect(),
1160                    collaborators: rejoined_project
1161                        .collaborators
1162                        .iter()
1163                        .map(|collaborator| collaborator.to_proto())
1164                        .collect(),
1165                    language_servers: rejoined_project.language_servers.clone(),
1166                })
1167                .collect(),
1168        })?;
1169        room_updated(&rejoined_room.room, &session.peer);
1170
1171        for project in &rejoined_room.reshared_projects {
1172            for collaborator in &project.collaborators {
1173                session
1174                    .peer
1175                    .send(
1176                        collaborator.connection_id,
1177                        proto::UpdateProjectCollaborator {
1178                            project_id: project.id.to_proto(),
1179                            old_peer_id: Some(project.old_connection_id.into()),
1180                            new_peer_id: Some(session.connection_id.into()),
1181                        },
1182                    )
1183                    .trace_err();
1184            }
1185
1186            broadcast(
1187                Some(session.connection_id),
1188                project
1189                    .collaborators
1190                    .iter()
1191                    .map(|collaborator| collaborator.connection_id),
1192                |connection_id| {
1193                    session.peer.forward_send(
1194                        session.connection_id,
1195                        connection_id,
1196                        proto::UpdateProject {
1197                            project_id: project.id.to_proto(),
1198                            worktrees: project.worktrees.clone(),
1199                        },
1200                    )
1201                },
1202            );
1203        }
1204
1205        for project in &rejoined_room.rejoined_projects {
1206            for collaborator in &project.collaborators {
1207                session
1208                    .peer
1209                    .send(
1210                        collaborator.connection_id,
1211                        proto::UpdateProjectCollaborator {
1212                            project_id: project.id.to_proto(),
1213                            old_peer_id: Some(project.old_connection_id.into()),
1214                            new_peer_id: Some(session.connection_id.into()),
1215                        },
1216                    )
1217                    .trace_err();
1218            }
1219        }
1220
1221        for project in &mut rejoined_room.rejoined_projects {
1222            for worktree in mem::take(&mut project.worktrees) {
1223                #[cfg(any(test, feature = "test-support"))]
1224                const MAX_CHUNK_SIZE: usize = 2;
1225                #[cfg(not(any(test, feature = "test-support")))]
1226                const MAX_CHUNK_SIZE: usize = 256;
1227
1228                // Stream this worktree's entries.
1229                let message = proto::UpdateWorktree {
1230                    project_id: project.id.to_proto(),
1231                    worktree_id: worktree.id,
1232                    abs_path: worktree.abs_path.clone(),
1233                    root_name: worktree.root_name,
1234                    updated_entries: worktree.updated_entries,
1235                    removed_entries: worktree.removed_entries,
1236                    scan_id: worktree.scan_id,
1237                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1238                    updated_repositories: worktree.updated_repositories,
1239                    removed_repositories: worktree.removed_repositories,
1240                };
1241                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1242                    session.peer.send(session.connection_id, update.clone())?;
1243                }
1244
1245                // Stream this worktree's diagnostics.
1246                for summary in worktree.diagnostic_summaries {
1247                    session.peer.send(
1248                        session.connection_id,
1249                        proto::UpdateDiagnosticSummary {
1250                            project_id: project.id.to_proto(),
1251                            worktree_id: worktree.id,
1252                            summary: Some(summary),
1253                        },
1254                    )?;
1255                }
1256
1257                for settings_file in worktree.settings_files {
1258                    session.peer.send(
1259                        session.connection_id,
1260                        proto::UpdateWorktreeSettings {
1261                            project_id: project.id.to_proto(),
1262                            worktree_id: worktree.id,
1263                            path: settings_file.path,
1264                            content: Some(settings_file.content),
1265                        },
1266                    )?;
1267                }
1268            }
1269
1270            for language_server in &project.language_servers {
1271                session.peer.send(
1272                    session.connection_id,
1273                    proto::UpdateLanguageServer {
1274                        project_id: project.id.to_proto(),
1275                        language_server_id: language_server.id,
1276                        variant: Some(
1277                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1278                                proto::LspDiskBasedDiagnosticsUpdated {},
1279                            ),
1280                        ),
1281                    },
1282                )?;
1283            }
1284        }
1285
1286        let rejoined_room = rejoined_room.into_inner();
1287
1288        room = rejoined_room.room;
1289        channel_id = rejoined_room.channel_id;
1290        channel_members = rejoined_room.channel_members;
1291    }
1292
1293    if let Some(channel_id) = channel_id {
1294        channel_updated(
1295            channel_id,
1296            &room,
1297            &channel_members,
1298            &session.peer,
1299            &*session.connection_pool().await,
1300        );
1301    }
1302
1303    update_user_contacts(session.user_id, &session).await?;
1304    Ok(())
1305}
1306
1307/// leave room disconnects from the room.
1308async fn leave_room(
1309    _: proto::LeaveRoom,
1310    response: Response<proto::LeaveRoom>,
1311    session: Session,
1312) -> Result<()> {
1313    leave_room_for_session(&session).await?;
1314    response.send(proto::Ack {})?;
1315    Ok(())
1316}
1317
1318/// Updates the permissions of someone else in the room.
1319async fn set_room_participant_role(
1320    request: proto::SetRoomParticipantRole,
1321    response: Response<proto::SetRoomParticipantRole>,
1322    session: Session,
1323) -> Result<()> {
1324    let user_id = UserId::from_proto(request.user_id);
1325    let role = ChannelRole::from(request.role());
1326
1327    if role == ChannelRole::Talker {
1328        let pool = session.connection_pool().await;
1329
1330        for connection in pool.user_connections(user_id) {
1331            if !connection.zed_version.supports_talker_role() {
1332                Err(anyhow!(
1333                    "This user is on zed {} which does not support unmute",
1334                    connection.zed_version
1335                ))?;
1336            }
1337        }
1338    }
1339
1340    let (live_kit_room, can_publish) = {
1341        let room = session
1342            .db()
1343            .await
1344            .set_room_participant_role(
1345                session.user_id,
1346                RoomId::from_proto(request.room_id),
1347                user_id,
1348                role,
1349            )
1350            .await?;
1351
1352        let live_kit_room = room.live_kit_room.clone();
1353        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1354        room_updated(&room, &session.peer);
1355        (live_kit_room, can_publish)
1356    };
1357
1358    if let Some(live_kit) = session.live_kit_client.as_ref() {
1359        live_kit
1360            .update_participant(
1361                live_kit_room.clone(),
1362                request.user_id.to_string(),
1363                live_kit_server::proto::ParticipantPermission {
1364                    can_subscribe: true,
1365                    can_publish,
1366                    can_publish_data: can_publish,
1367                    hidden: false,
1368                    recorder: false,
1369                },
1370            )
1371            .await
1372            .trace_err();
1373    }
1374
1375    response.send(proto::Ack {})?;
1376    Ok(())
1377}
1378
1379/// Call someone else into the current room
1380async fn call(
1381    request: proto::Call,
1382    response: Response<proto::Call>,
1383    session: Session,
1384) -> Result<()> {
1385    let room_id = RoomId::from_proto(request.room_id);
1386    let calling_user_id = session.user_id;
1387    let calling_connection_id = session.connection_id;
1388    let called_user_id = UserId::from_proto(request.called_user_id);
1389    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1390    if !session
1391        .db()
1392        .await
1393        .has_contact(calling_user_id, called_user_id)
1394        .await?
1395    {
1396        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1397    }
1398
1399    let incoming_call = {
1400        let (room, incoming_call) = &mut *session
1401            .db()
1402            .await
1403            .call(
1404                room_id,
1405                calling_user_id,
1406                calling_connection_id,
1407                called_user_id,
1408                initial_project_id,
1409            )
1410            .await?;
1411        room_updated(&room, &session.peer);
1412        mem::take(incoming_call)
1413    };
1414    update_user_contacts(called_user_id, &session).await?;
1415
1416    let mut calls = session
1417        .connection_pool()
1418        .await
1419        .user_connection_ids(called_user_id)
1420        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1421        .collect::<FuturesUnordered<_>>();
1422
1423    while let Some(call_response) = calls.next().await {
1424        match call_response.as_ref() {
1425            Ok(_) => {
1426                response.send(proto::Ack {})?;
1427                return Ok(());
1428            }
1429            Err(_) => {
1430                call_response.trace_err();
1431            }
1432        }
1433    }
1434
1435    {
1436        let room = session
1437            .db()
1438            .await
1439            .call_failed(room_id, called_user_id)
1440            .await?;
1441        room_updated(&room, &session.peer);
1442    }
1443    update_user_contacts(called_user_id, &session).await?;
1444
1445    Err(anyhow!("failed to ring user"))?
1446}
1447
1448/// Cancel an outgoing call.
1449async fn cancel_call(
1450    request: proto::CancelCall,
1451    response: Response<proto::CancelCall>,
1452    session: Session,
1453) -> Result<()> {
1454    let called_user_id = UserId::from_proto(request.called_user_id);
1455    let room_id = RoomId::from_proto(request.room_id);
1456    {
1457        let room = session
1458            .db()
1459            .await
1460            .cancel_call(room_id, session.connection_id, called_user_id)
1461            .await?;
1462        room_updated(&room, &session.peer);
1463    }
1464
1465    for connection_id in session
1466        .connection_pool()
1467        .await
1468        .user_connection_ids(called_user_id)
1469    {
1470        session
1471            .peer
1472            .send(
1473                connection_id,
1474                proto::CallCanceled {
1475                    room_id: room_id.to_proto(),
1476                },
1477            )
1478            .trace_err();
1479    }
1480    response.send(proto::Ack {})?;
1481
1482    update_user_contacts(called_user_id, &session).await?;
1483    Ok(())
1484}
1485
1486/// Decline an incoming call.
1487async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1488    let room_id = RoomId::from_proto(message.room_id);
1489    {
1490        let room = session
1491            .db()
1492            .await
1493            .decline_call(Some(room_id), session.user_id)
1494            .await?
1495            .ok_or_else(|| anyhow!("failed to decline call"))?;
1496        room_updated(&room, &session.peer);
1497    }
1498
1499    for connection_id in session
1500        .connection_pool()
1501        .await
1502        .user_connection_ids(session.user_id)
1503    {
1504        session
1505            .peer
1506            .send(
1507                connection_id,
1508                proto::CallCanceled {
1509                    room_id: room_id.to_proto(),
1510                },
1511            )
1512            .trace_err();
1513    }
1514    update_user_contacts(session.user_id, &session).await?;
1515    Ok(())
1516}
1517
1518/// Updates other participants in the room with your current location.
1519async fn update_participant_location(
1520    request: proto::UpdateParticipantLocation,
1521    response: Response<proto::UpdateParticipantLocation>,
1522    session: Session,
1523) -> Result<()> {
1524    let room_id = RoomId::from_proto(request.room_id);
1525    let location = request
1526        .location
1527        .ok_or_else(|| anyhow!("invalid location"))?;
1528
1529    let db = session.db().await;
1530    let room = db
1531        .update_room_participant_location(room_id, session.connection_id, location)
1532        .await?;
1533
1534    room_updated(&room, &session.peer);
1535    response.send(proto::Ack {})?;
1536    Ok(())
1537}
1538
1539/// Share a project into the room.
1540async fn share_project(
1541    request: proto::ShareProject,
1542    response: Response<proto::ShareProject>,
1543    session: Session,
1544) -> Result<()> {
1545    let (project_id, room) = &*session
1546        .db()
1547        .await
1548        .share_project(
1549            RoomId::from_proto(request.room_id),
1550            session.connection_id,
1551            &request.worktrees,
1552        )
1553        .await?;
1554    response.send(proto::ShareProjectResponse {
1555        project_id: project_id.to_proto(),
1556    })?;
1557    room_updated(&room, &session.peer);
1558
1559    Ok(())
1560}
1561
1562/// Unshare a project from the room.
1563async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1564    let project_id = ProjectId::from_proto(message.project_id);
1565
1566    let (room, guest_connection_ids) = &*session
1567        .db()
1568        .await
1569        .unshare_project(project_id, session.connection_id)
1570        .await?;
1571
1572    broadcast(
1573        Some(session.connection_id),
1574        guest_connection_ids.iter().copied(),
1575        |conn_id| session.peer.send(conn_id, message.clone()),
1576    );
1577    room_updated(&room, &session.peer);
1578
1579    Ok(())
1580}
1581
1582/// Join someone elses shared project.
1583async fn join_project(
1584    request: proto::JoinProject,
1585    response: Response<proto::JoinProject>,
1586    session: Session,
1587) -> Result<()> {
1588    let project_id = ProjectId::from_proto(request.project_id);
1589
1590    tracing::info!(%project_id, "join project");
1591
1592    let (project, replica_id) = &mut *session
1593        .db()
1594        .await
1595        .join_project_in_room(project_id, session.connection_id)
1596        .await?;
1597
1598    join_project_internal(response, session, project, replica_id)
1599}
1600
/// A response handle that can deliver a `proto::JoinProjectResponse`.
///
/// It is implemented for both `JoinProject` and `JoinHostedProject` responses,
/// so `join_project_internal` can serve either request.
trait JoinProjectInternalResponse {
    /// Sends the join response back to the requester, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully-qualified `Response::<…>::send` calls below are meant to reach the
// inherent `Response::send` (presumably defined elsewhere in this file).
// Inherent associated functions are preferred over trait ones in path lookup,
// so these calls do not recurse into the trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1614
1615fn join_project_internal(
1616    response: impl JoinProjectInternalResponse,
1617    session: Session,
1618    project: &mut Project,
1619    replica_id: &ReplicaId,
1620) -> Result<()> {
1621    let collaborators = project
1622        .collaborators
1623        .iter()
1624        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1625        .map(|collaborator| collaborator.to_proto())
1626        .collect::<Vec<_>>();
1627    let project_id = project.id;
1628    let guest_user_id = session.user_id;
1629
1630    let worktrees = project
1631        .worktrees
1632        .iter()
1633        .map(|(id, worktree)| proto::WorktreeMetadata {
1634            id: *id,
1635            root_name: worktree.root_name.clone(),
1636            visible: worktree.visible,
1637            abs_path: worktree.abs_path.clone(),
1638        })
1639        .collect::<Vec<_>>();
1640
1641    for collaborator in &collaborators {
1642        session
1643            .peer
1644            .send(
1645                collaborator.peer_id.unwrap().into(),
1646                proto::AddProjectCollaborator {
1647                    project_id: project_id.to_proto(),
1648                    collaborator: Some(proto::Collaborator {
1649                        peer_id: Some(session.connection_id.into()),
1650                        replica_id: replica_id.0 as u32,
1651                        user_id: guest_user_id.to_proto(),
1652                    }),
1653                },
1654            )
1655            .trace_err();
1656    }
1657
1658    // First, we send the metadata associated with each worktree.
1659    response.send(proto::JoinProjectResponse {
1660        project_id: project.id.0 as u64,
1661        worktrees: worktrees.clone(),
1662        replica_id: replica_id.0 as u32,
1663        collaborators: collaborators.clone(),
1664        language_servers: project.language_servers.clone(),
1665        role: project.role.into(), // todo
1666    })?;
1667
1668    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1669        #[cfg(any(test, feature = "test-support"))]
1670        const MAX_CHUNK_SIZE: usize = 2;
1671        #[cfg(not(any(test, feature = "test-support")))]
1672        const MAX_CHUNK_SIZE: usize = 256;
1673
1674        // Stream this worktree's entries.
1675        let message = proto::UpdateWorktree {
1676            project_id: project_id.to_proto(),
1677            worktree_id,
1678            abs_path: worktree.abs_path.clone(),
1679            root_name: worktree.root_name,
1680            updated_entries: worktree.entries,
1681            removed_entries: Default::default(),
1682            scan_id: worktree.scan_id,
1683            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1684            updated_repositories: worktree.repository_entries.into_values().collect(),
1685            removed_repositories: Default::default(),
1686        };
1687        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1688            session.peer.send(session.connection_id, update.clone())?;
1689        }
1690
1691        // Stream this worktree's diagnostics.
1692        for summary in worktree.diagnostic_summaries {
1693            session.peer.send(
1694                session.connection_id,
1695                proto::UpdateDiagnosticSummary {
1696                    project_id: project_id.to_proto(),
1697                    worktree_id: worktree.id,
1698                    summary: Some(summary),
1699                },
1700            )?;
1701        }
1702
1703        for settings_file in worktree.settings_files {
1704            session.peer.send(
1705                session.connection_id,
1706                proto::UpdateWorktreeSettings {
1707                    project_id: project_id.to_proto(),
1708                    worktree_id: worktree.id,
1709                    path: settings_file.path,
1710                    content: Some(settings_file.content),
1711                },
1712            )?;
1713        }
1714    }
1715
1716    for language_server in &project.language_servers {
1717        session.peer.send(
1718            session.connection_id,
1719            proto::UpdateLanguageServer {
1720                project_id: project_id.to_proto(),
1721                language_server_id: language_server.id,
1722                variant: Some(
1723                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1724                        proto::LspDiskBasedDiagnosticsUpdated {},
1725                    ),
1726                ),
1727            },
1728        )?;
1729    }
1730
1731    Ok(())
1732}
1733
1734/// Leave someone elses shared project.
1735async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1736    let sender_id = session.connection_id;
1737    let project_id = ProjectId::from_proto(request.project_id);
1738    let db = session.db().await;
1739    if db.is_hosted_project(project_id).await? {
1740        let project = db.leave_hosted_project(project_id, sender_id).await?;
1741        project_left(&project, &session);
1742        return Ok(());
1743    }
1744
1745    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1746    tracing::info!(
1747        %project_id,
1748        host_user_id = ?project.host_user_id,
1749        host_connection_id = ?project.host_connection_id,
1750        "leave project"
1751    );
1752
1753    project_left(&project, &session);
1754    room_updated(&room, &session.peer);
1755
1756    Ok(())
1757}
1758
1759async fn join_hosted_project(
1760    request: proto::JoinHostedProject,
1761    response: Response<proto::JoinHostedProject>,
1762    session: Session,
1763) -> Result<()> {
1764    let (mut project, replica_id) = session
1765        .db()
1766        .await
1767        .join_hosted_project(
1768            HostedProjectId(request.id as i32),
1769            session.user_id,
1770            session.connection_id,
1771        )
1772        .await?;
1773
1774    join_project_internal(response, session, &mut project, &replica_id)
1775}
1776
1777/// Updates other participants with changes to the project
1778async fn update_project(
1779    request: proto::UpdateProject,
1780    response: Response<proto::UpdateProject>,
1781    session: Session,
1782) -> Result<()> {
1783    let project_id = ProjectId::from_proto(request.project_id);
1784    let (room, guest_connection_ids) = &*session
1785        .db()
1786        .await
1787        .update_project(project_id, session.connection_id, &request.worktrees)
1788        .await?;
1789    broadcast(
1790        Some(session.connection_id),
1791        guest_connection_ids.iter().copied(),
1792        |connection_id| {
1793            session
1794                .peer
1795                .forward_send(session.connection_id, connection_id, request.clone())
1796        },
1797    );
1798    room_updated(&room, &session.peer);
1799    response.send(proto::Ack {})?;
1800
1801    Ok(())
1802}
1803
1804/// Updates other participants with changes to the worktree
1805async fn update_worktree(
1806    request: proto::UpdateWorktree,
1807    response: Response<proto::UpdateWorktree>,
1808    session: Session,
1809) -> Result<()> {
1810    let guest_connection_ids = session
1811        .db()
1812        .await
1813        .update_worktree(&request, session.connection_id)
1814        .await?;
1815
1816    broadcast(
1817        Some(session.connection_id),
1818        guest_connection_ids.iter().copied(),
1819        |connection_id| {
1820            session
1821                .peer
1822                .forward_send(session.connection_id, connection_id, request.clone())
1823        },
1824    );
1825    response.send(proto::Ack {})?;
1826    Ok(())
1827}
1828
1829/// Updates other participants with changes to the diagnostics
1830async fn update_diagnostic_summary(
1831    message: proto::UpdateDiagnosticSummary,
1832    session: Session,
1833) -> Result<()> {
1834    let guest_connection_ids = session
1835        .db()
1836        .await
1837        .update_diagnostic_summary(&message, session.connection_id)
1838        .await?;
1839
1840    broadcast(
1841        Some(session.connection_id),
1842        guest_connection_ids.iter().copied(),
1843        |connection_id| {
1844            session
1845                .peer
1846                .forward_send(session.connection_id, connection_id, message.clone())
1847        },
1848    );
1849
1850    Ok(())
1851}
1852
1853/// Updates other participants with changes to the worktree settings
1854async fn update_worktree_settings(
1855    message: proto::UpdateWorktreeSettings,
1856    session: Session,
1857) -> Result<()> {
1858    let guest_connection_ids = session
1859        .db()
1860        .await
1861        .update_worktree_settings(&message, session.connection_id)
1862        .await?;
1863
1864    broadcast(
1865        Some(session.connection_id),
1866        guest_connection_ids.iter().copied(),
1867        |connection_id| {
1868            session
1869                .peer
1870                .forward_send(session.connection_id, connection_id, message.clone())
1871        },
1872    );
1873
1874    Ok(())
1875}
1876
1877/// Notify other participants that a  language server has started.
1878async fn start_language_server(
1879    request: proto::StartLanguageServer,
1880    session: Session,
1881) -> Result<()> {
1882    let guest_connection_ids = session
1883        .db()
1884        .await
1885        .start_language_server(&request, session.connection_id)
1886        .await?;
1887
1888    broadcast(
1889        Some(session.connection_id),
1890        guest_connection_ids.iter().copied(),
1891        |connection_id| {
1892            session
1893                .peer
1894                .forward_send(session.connection_id, connection_id, request.clone())
1895        },
1896    );
1897    Ok(())
1898}
1899
1900/// Notify other participants that a language server has changed.
1901async fn update_language_server(
1902    request: proto::UpdateLanguageServer,
1903    session: Session,
1904) -> Result<()> {
1905    let project_id = ProjectId::from_proto(request.project_id);
1906    let project_connection_ids = session
1907        .db()
1908        .await
1909        .project_connection_ids(project_id, session.connection_id)
1910        .await?;
1911    broadcast(
1912        Some(session.connection_id),
1913        project_connection_ids.iter().copied(),
1914        |connection_id| {
1915            session
1916                .peer
1917                .forward_send(session.connection_id, connection_id, request.clone())
1918        },
1919    );
1920    Ok(())
1921}
1922
1923/// forward a project request to the host. These requests should be read only
1924/// as guests are allowed to send them.
1925async fn forward_read_only_project_request<T>(
1926    request: T,
1927    response: Response<T>,
1928    session: Session,
1929) -> Result<()>
1930where
1931    T: EntityMessage + RequestMessage,
1932{
1933    let project_id = ProjectId::from_proto(request.remote_entity_id());
1934    let host_connection_id = session
1935        .db()
1936        .await
1937        .host_for_read_only_project_request(project_id, session.connection_id)
1938        .await?;
1939    let payload = session
1940        .peer
1941        .forward_request(session.connection_id, host_connection_id, request)
1942        .await?;
1943    response.send(payload)?;
1944    Ok(())
1945}
1946
1947/// forward a project request to the host. These requests are disallowed
1948/// for guests.
1949async fn forward_mutating_project_request<T>(
1950    request: T,
1951    response: Response<T>,
1952    session: Session,
1953) -> Result<()>
1954where
1955    T: EntityMessage + RequestMessage,
1956{
1957    let project_id = ProjectId::from_proto(request.remote_entity_id());
1958    let host_connection_id = session
1959        .db()
1960        .await
1961        .host_for_mutating_project_request(project_id, session.connection_id)
1962        .await?;
1963    let payload = session
1964        .peer
1965        .forward_request(session.connection_id, host_connection_id, request)
1966        .await?;
1967    response.send(payload)?;
1968    Ok(())
1969}
1970
1971/// Notify other participants that a new buffer has been created
1972async fn create_buffer_for_peer(
1973    request: proto::CreateBufferForPeer,
1974    session: Session,
1975) -> Result<()> {
1976    session
1977        .db()
1978        .await
1979        .check_user_is_project_host(
1980            ProjectId::from_proto(request.project_id),
1981            session.connection_id,
1982        )
1983        .await?;
1984    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1985    session
1986        .peer
1987        .forward_send(session.connection_id, peer_id.into(), request)?;
1988    Ok(())
1989}
1990
1991/// Notify other participants that a buffer has been updated. This is
1992/// allowed for guests as long as the update is limited to selections.
1993async fn update_buffer(
1994    request: proto::UpdateBuffer,
1995    response: Response<proto::UpdateBuffer>,
1996    session: Session,
1997) -> Result<()> {
1998    let project_id = ProjectId::from_proto(request.project_id);
1999    let mut guest_connection_ids;
2000    let mut host_connection_id = None;
2001
2002    let mut requires_write_permission = false;
2003
2004    for op in request.operations.iter() {
2005        match op.variant {
2006            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2007            Some(_) => requires_write_permission = true,
2008        }
2009    }
2010
2011    {
2012        let collaborators = session
2013            .db()
2014            .await
2015            .project_collaborators_for_buffer_update(
2016                project_id,
2017                session.connection_id,
2018                requires_write_permission,
2019            )
2020            .await?;
2021        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2022        for collaborator in collaborators.iter() {
2023            if collaborator.is_host {
2024                host_connection_id = Some(collaborator.connection_id);
2025            } else {
2026                guest_connection_ids.push(collaborator.connection_id);
2027            }
2028        }
2029    }
2030    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2031
2032    broadcast(
2033        Some(session.connection_id),
2034        guest_connection_ids,
2035        |connection_id| {
2036            session
2037                .peer
2038                .forward_send(session.connection_id, connection_id, request.clone())
2039        },
2040    );
2041    if host_connection_id != session.connection_id {
2042        session
2043            .peer
2044            .forward_request(session.connection_id, host_connection_id, request.clone())
2045            .await?;
2046    }
2047
2048    response.send(proto::Ack {})?;
2049    Ok(())
2050}
2051
2052/// Notify other participants that a project has been updated.
2053async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2054    request: T,
2055    session: Session,
2056) -> Result<()> {
2057    let project_id = ProjectId::from_proto(request.remote_entity_id());
2058    let project_connection_ids = session
2059        .db()
2060        .await
2061        .project_connection_ids(project_id, session.connection_id)
2062        .await?;
2063
2064    broadcast(
2065        Some(session.connection_id),
2066        project_connection_ids.iter().copied(),
2067        |connection_id| {
2068            session
2069                .peer
2070                .forward_send(session.connection_id, connection_id, request.clone())
2071        },
2072    );
2073    Ok(())
2074}
2075
2076/// Start following another user in a call.
2077async fn follow(
2078    request: proto::Follow,
2079    response: Response<proto::Follow>,
2080    session: Session,
2081) -> Result<()> {
2082    let room_id = RoomId::from_proto(request.room_id);
2083    let project_id = request.project_id.map(ProjectId::from_proto);
2084    let leader_id = request
2085        .leader_id
2086        .ok_or_else(|| anyhow!("invalid leader id"))?
2087        .into();
2088    let follower_id = session.connection_id;
2089
2090    session
2091        .db()
2092        .await
2093        .check_room_participants(room_id, leader_id, session.connection_id)
2094        .await?;
2095
2096    let response_payload = session
2097        .peer
2098        .forward_request(session.connection_id, leader_id, request)
2099        .await?;
2100    response.send(response_payload)?;
2101
2102    if let Some(project_id) = project_id {
2103        let room = session
2104            .db()
2105            .await
2106            .follow(room_id, project_id, leader_id, follower_id)
2107            .await?;
2108        room_updated(&room, &session.peer);
2109    }
2110
2111    Ok(())
2112}
2113
2114/// Stop following another user in a call.
2115async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2116    let room_id = RoomId::from_proto(request.room_id);
2117    let project_id = request.project_id.map(ProjectId::from_proto);
2118    let leader_id = request
2119        .leader_id
2120        .ok_or_else(|| anyhow!("invalid leader id"))?
2121        .into();
2122    let follower_id = session.connection_id;
2123
2124    session
2125        .db()
2126        .await
2127        .check_room_participants(room_id, leader_id, session.connection_id)
2128        .await?;
2129
2130    session
2131        .peer
2132        .forward_send(session.connection_id, leader_id, request)?;
2133
2134    if let Some(project_id) = project_id {
2135        let room = session
2136            .db()
2137            .await
2138            .unfollow(room_id, project_id, leader_id, follower_id)
2139            .await?;
2140        room_updated(&room, &session.peer);
2141    }
2142
2143    Ok(())
2144}
2145
2146/// Notify everyone following you of your current location.
2147async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2148    let room_id = RoomId::from_proto(request.room_id);
2149    let database = session.db.lock().await;
2150
2151    let connection_ids = if let Some(project_id) = request.project_id {
2152        let project_id = ProjectId::from_proto(project_id);
2153        database
2154            .project_connection_ids(project_id, session.connection_id)
2155            .await?
2156    } else {
2157        database
2158            .room_connection_ids(room_id, session.connection_id)
2159            .await?
2160    };
2161
2162    // For now, don't send view update messages back to that view's current leader.
2163    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2164        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2165        _ => None,
2166    });
2167
2168    for connection_id in connection_ids.iter().cloned() {
2169        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2170            session
2171                .peer
2172                .forward_send(session.connection_id, connection_id, request.clone())?;
2173        }
2174    }
2175    Ok(())
2176}
2177
2178/// Get public data about users.
2179async fn get_users(
2180    request: proto::GetUsers,
2181    response: Response<proto::GetUsers>,
2182    session: Session,
2183) -> Result<()> {
2184    let user_ids = request
2185        .user_ids
2186        .into_iter()
2187        .map(UserId::from_proto)
2188        .collect();
2189    let users = session
2190        .db()
2191        .await
2192        .get_users_by_ids(user_ids)
2193        .await?
2194        .into_iter()
2195        .map(|user| proto::User {
2196            id: user.id.to_proto(),
2197            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2198            github_login: user.github_login,
2199        })
2200        .collect();
2201    response.send(proto::UsersResponse { users })?;
2202    Ok(())
2203}
2204
2205/// Search for users (to invite) buy Github login
2206async fn fuzzy_search_users(
2207    request: proto::FuzzySearchUsers,
2208    response: Response<proto::FuzzySearchUsers>,
2209    session: Session,
2210) -> Result<()> {
2211    let query = request.query;
2212    let users = match query.len() {
2213        0 => vec![],
2214        1 | 2 => session
2215            .db()
2216            .await
2217            .get_user_by_github_login(&query)
2218            .await?
2219            .into_iter()
2220            .collect(),
2221        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2222    };
2223    let users = users
2224        .into_iter()
2225        .filter(|user| user.id != session.user_id)
2226        .map(|user| proto::User {
2227            id: user.id.to_proto(),
2228            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2229            github_login: user.github_login,
2230        })
2231        .collect();
2232    response.send(proto::UsersResponse { users })?;
2233    Ok(())
2234}
2235
2236/// Send a contact request to another user.
2237async fn request_contact(
2238    request: proto::RequestContact,
2239    response: Response<proto::RequestContact>,
2240    session: Session,
2241) -> Result<()> {
2242    let requester_id = session.user_id;
2243    let responder_id = UserId::from_proto(request.responder_id);
2244    if requester_id == responder_id {
2245        return Err(anyhow!("cannot add yourself as a contact"))?;
2246    }
2247
2248    let notifications = session
2249        .db()
2250        .await
2251        .send_contact_request(requester_id, responder_id)
2252        .await?;
2253
2254    // Update outgoing contact requests of requester
2255    let mut update = proto::UpdateContacts::default();
2256    update.outgoing_requests.push(responder_id.to_proto());
2257    for connection_id in session
2258        .connection_pool()
2259        .await
2260        .user_connection_ids(requester_id)
2261    {
2262        session.peer.send(connection_id, update.clone())?;
2263    }
2264
2265    // Update incoming contact requests of responder
2266    let mut update = proto::UpdateContacts::default();
2267    update
2268        .incoming_requests
2269        .push(proto::IncomingContactRequest {
2270            requester_id: requester_id.to_proto(),
2271        });
2272    let connection_pool = session.connection_pool().await;
2273    for connection_id in connection_pool.user_connection_ids(responder_id) {
2274        session.peer.send(connection_id, update.clone())?;
2275    }
2276
2277    send_notifications(&connection_pool, &session.peer, notifications);
2278
2279    response.send(proto::Ack {})?;
2280    Ok(())
2281}
2282
2283/// Accept or decline a contact request
2284async fn respond_to_contact_request(
2285    request: proto::RespondToContactRequest,
2286    response: Response<proto::RespondToContactRequest>,
2287    session: Session,
2288) -> Result<()> {
2289    let responder_id = session.user_id;
2290    let requester_id = UserId::from_proto(request.requester_id);
2291    let db = session.db().await;
2292    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2293        db.dismiss_contact_notification(responder_id, requester_id)
2294            .await?;
2295    } else {
2296        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2297
2298        let notifications = db
2299            .respond_to_contact_request(responder_id, requester_id, accept)
2300            .await?;
2301        let requester_busy = db.is_user_busy(requester_id).await?;
2302        let responder_busy = db.is_user_busy(responder_id).await?;
2303
2304        let pool = session.connection_pool().await;
2305        // Update responder with new contact
2306        let mut update = proto::UpdateContacts::default();
2307        if accept {
2308            update
2309                .contacts
2310                .push(contact_for_user(requester_id, requester_busy, &pool));
2311        }
2312        update
2313            .remove_incoming_requests
2314            .push(requester_id.to_proto());
2315        for connection_id in pool.user_connection_ids(responder_id) {
2316            session.peer.send(connection_id, update.clone())?;
2317        }
2318
2319        // Update requester with new contact
2320        let mut update = proto::UpdateContacts::default();
2321        if accept {
2322            update
2323                .contacts
2324                .push(contact_for_user(responder_id, responder_busy, &pool));
2325        }
2326        update
2327            .remove_outgoing_requests
2328            .push(responder_id.to_proto());
2329
2330        for connection_id in pool.user_connection_ids(requester_id) {
2331            session.peer.send(connection_id, update.clone())?;
2332        }
2333
2334        send_notifications(&pool, &session.peer, notifications);
2335    }
2336
2337    response.send(proto::Ack {})?;
2338    Ok(())
2339}
2340
2341/// Remove a contact.
2342async fn remove_contact(
2343    request: proto::RemoveContact,
2344    response: Response<proto::RemoveContact>,
2345    session: Session,
2346) -> Result<()> {
2347    let requester_id = session.user_id;
2348    let responder_id = UserId::from_proto(request.user_id);
2349    let db = session.db().await;
2350    let (contact_accepted, deleted_notification_id) =
2351        db.remove_contact(requester_id, responder_id).await?;
2352
2353    let pool = session.connection_pool().await;
2354    // Update outgoing contact requests of requester
2355    let mut update = proto::UpdateContacts::default();
2356    if contact_accepted {
2357        update.remove_contacts.push(responder_id.to_proto());
2358    } else {
2359        update
2360            .remove_outgoing_requests
2361            .push(responder_id.to_proto());
2362    }
2363    for connection_id in pool.user_connection_ids(requester_id) {
2364        session.peer.send(connection_id, update.clone())?;
2365    }
2366
2367    // Update incoming contact requests of responder
2368    let mut update = proto::UpdateContacts::default();
2369    if contact_accepted {
2370        update.remove_contacts.push(requester_id.to_proto());
2371    } else {
2372        update
2373            .remove_incoming_requests
2374            .push(requester_id.to_proto());
2375    }
2376    for connection_id in pool.user_connection_ids(responder_id) {
2377        session.peer.send(connection_id, update.clone())?;
2378        if let Some(notification_id) = deleted_notification_id {
2379            session.peer.send(
2380                connection_id,
2381                proto::DeleteNotification {
2382                    notification_id: notification_id.to_proto(),
2383                },
2384            )?;
2385        }
2386    }
2387
2388    response.send(proto::Ack {})?;
2389    Ok(())
2390}
2391
2392/// Creates a new channel.
2393async fn create_channel(
2394    request: proto::CreateChannel,
2395    response: Response<proto::CreateChannel>,
2396    session: Session,
2397) -> Result<()> {
2398    let db = session.db().await;
2399
2400    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2401    let (channel, owner, channel_members) = db
2402        .create_channel(&request.name, parent_id, session.user_id)
2403        .await?;
2404
2405    response.send(proto::CreateChannelResponse {
2406        channel: Some(channel.to_proto()),
2407        parent_id: request.parent_id,
2408    })?;
2409
2410    let connection_pool = session.connection_pool().await;
2411    if let Some(owner) = owner {
2412        let update = proto::UpdateUserChannels {
2413            channel_memberships: vec![proto::ChannelMembership {
2414                channel_id: owner.channel_id.to_proto(),
2415                role: owner.role.into(),
2416            }],
2417            ..Default::default()
2418        };
2419        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2420            session.peer.send(connection_id, update.clone())?;
2421        }
2422    }
2423
2424    for channel_member in channel_members {
2425        if !channel_member.role.can_see_channel(channel.visibility) {
2426            continue;
2427        }
2428
2429        let update = proto::UpdateChannels {
2430            channels: vec![channel.to_proto()],
2431            ..Default::default()
2432        };
2433        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2434            session.peer.send(connection_id, update.clone())?;
2435        }
2436    }
2437
2438    Ok(())
2439}
2440
2441/// Delete a channel
2442async fn delete_channel(
2443    request: proto::DeleteChannel,
2444    response: Response<proto::DeleteChannel>,
2445    session: Session,
2446) -> Result<()> {
2447    let db = session.db().await;
2448
2449    let channel_id = request.channel_id;
2450    let (removed_channels, member_ids) = db
2451        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2452        .await?;
2453    response.send(proto::Ack {})?;
2454
2455    // Notify members of removed channels
2456    let mut update = proto::UpdateChannels::default();
2457    update
2458        .delete_channels
2459        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2460
2461    let connection_pool = session.connection_pool().await;
2462    for member_id in member_ids {
2463        for connection_id in connection_pool.user_connection_ids(member_id) {
2464            session.peer.send(connection_id, update.clone())?;
2465        }
2466    }
2467
2468    Ok(())
2469}
2470
2471/// Invite someone to join a channel.
2472async fn invite_channel_member(
2473    request: proto::InviteChannelMember,
2474    response: Response<proto::InviteChannelMember>,
2475    session: Session,
2476) -> Result<()> {
2477    let db = session.db().await;
2478    let channel_id = ChannelId::from_proto(request.channel_id);
2479    let invitee_id = UserId::from_proto(request.user_id);
2480    let InviteMemberResult {
2481        channel,
2482        notifications,
2483    } = db
2484        .invite_channel_member(
2485            channel_id,
2486            invitee_id,
2487            session.user_id,
2488            request.role().into(),
2489        )
2490        .await?;
2491
2492    let update = proto::UpdateChannels {
2493        channel_invitations: vec![channel.to_proto()],
2494        ..Default::default()
2495    };
2496
2497    let connection_pool = session.connection_pool().await;
2498    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2499        session.peer.send(connection_id, update.clone())?;
2500    }
2501
2502    send_notifications(&connection_pool, &session.peer, notifications);
2503
2504    response.send(proto::Ack {})?;
2505    Ok(())
2506}
2507
2508/// remove someone from a channel
2509async fn remove_channel_member(
2510    request: proto::RemoveChannelMember,
2511    response: Response<proto::RemoveChannelMember>,
2512    session: Session,
2513) -> Result<()> {
2514    let db = session.db().await;
2515    let channel_id = ChannelId::from_proto(request.channel_id);
2516    let member_id = UserId::from_proto(request.user_id);
2517
2518    let RemoveChannelMemberResult {
2519        membership_update,
2520        notification_id,
2521    } = db
2522        .remove_channel_member(channel_id, member_id, session.user_id)
2523        .await?;
2524
2525    let connection_pool = &session.connection_pool().await;
2526    notify_membership_updated(
2527        &connection_pool,
2528        membership_update,
2529        member_id,
2530        &session.peer,
2531    );
2532    for connection_id in connection_pool.user_connection_ids(member_id) {
2533        if let Some(notification_id) = notification_id {
2534            session
2535                .peer
2536                .send(
2537                    connection_id,
2538                    proto::DeleteNotification {
2539                        notification_id: notification_id.to_proto(),
2540                    },
2541                )
2542                .trace_err();
2543        }
2544    }
2545
2546    response.send(proto::Ack {})?;
2547    Ok(())
2548}
2549
2550/// Toggle the channel between public and private.
2551/// Care is taken to maintain the invariant that public channels only descend from public channels,
2552/// (though members-only channels can appear at any point in the hierarchy).
2553async fn set_channel_visibility(
2554    request: proto::SetChannelVisibility,
2555    response: Response<proto::SetChannelVisibility>,
2556    session: Session,
2557) -> Result<()> {
2558    let db = session.db().await;
2559    let channel_id = ChannelId::from_proto(request.channel_id);
2560    let visibility = request.visibility().into();
2561
2562    let (channel, channel_members) = db
2563        .set_channel_visibility(channel_id, visibility, session.user_id)
2564        .await?;
2565
2566    let connection_pool = session.connection_pool().await;
2567    for member in channel_members {
2568        let update = if member.role.can_see_channel(channel.visibility) {
2569            proto::UpdateChannels {
2570                channels: vec![channel.to_proto()],
2571                ..Default::default()
2572            }
2573        } else {
2574            proto::UpdateChannels {
2575                delete_channels: vec![channel.id.to_proto()],
2576                ..Default::default()
2577            }
2578        };
2579
2580        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2581            session.peer.send(connection_id, update.clone())?;
2582        }
2583    }
2584
2585    response.send(proto::Ack {})?;
2586    Ok(())
2587}
2588
2589/// Alter the role for a user in the channel.
2590async fn set_channel_member_role(
2591    request: proto::SetChannelMemberRole,
2592    response: Response<proto::SetChannelMemberRole>,
2593    session: Session,
2594) -> Result<()> {
2595    let db = session.db().await;
2596    let channel_id = ChannelId::from_proto(request.channel_id);
2597    let member_id = UserId::from_proto(request.user_id);
2598    let result = db
2599        .set_channel_member_role(
2600            channel_id,
2601            session.user_id,
2602            member_id,
2603            request.role().into(),
2604        )
2605        .await?;
2606
2607    match result {
2608        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2609            let connection_pool = session.connection_pool().await;
2610            notify_membership_updated(
2611                &connection_pool,
2612                membership_update,
2613                member_id,
2614                &session.peer,
2615            )
2616        }
2617        db::SetMemberRoleResult::InviteUpdated(channel) => {
2618            let update = proto::UpdateChannels {
2619                channel_invitations: vec![channel.to_proto()],
2620                ..Default::default()
2621            };
2622
2623            for connection_id in session
2624                .connection_pool()
2625                .await
2626                .user_connection_ids(member_id)
2627            {
2628                session.peer.send(connection_id, update.clone())?;
2629            }
2630        }
2631    }
2632
2633    response.send(proto::Ack {})?;
2634    Ok(())
2635}
2636
2637/// Change the name of a channel
2638async fn rename_channel(
2639    request: proto::RenameChannel,
2640    response: Response<proto::RenameChannel>,
2641    session: Session,
2642) -> Result<()> {
2643    let db = session.db().await;
2644    let channel_id = ChannelId::from_proto(request.channel_id);
2645    let (channel, channel_members) = db
2646        .rename_channel(channel_id, session.user_id, &request.name)
2647        .await?;
2648
2649    response.send(proto::RenameChannelResponse {
2650        channel: Some(channel.to_proto()),
2651    })?;
2652
2653    let connection_pool = session.connection_pool().await;
2654    for channel_member in channel_members {
2655        if !channel_member.role.can_see_channel(channel.visibility) {
2656            continue;
2657        }
2658        let update = proto::UpdateChannels {
2659            channels: vec![channel.to_proto()],
2660            ..Default::default()
2661        };
2662        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2663            session.peer.send(connection_id, update.clone())?;
2664        }
2665    }
2666
2667    Ok(())
2668}
2669
2670/// Move a channel to a new parent.
2671async fn move_channel(
2672    request: proto::MoveChannel,
2673    response: Response<proto::MoveChannel>,
2674    session: Session,
2675) -> Result<()> {
2676    let channel_id = ChannelId::from_proto(request.channel_id);
2677    let to = ChannelId::from_proto(request.to);
2678
2679    let (channels, channel_members) = session
2680        .db()
2681        .await
2682        .move_channel(channel_id, to, session.user_id)
2683        .await?;
2684
2685    let connection_pool = session.connection_pool().await;
2686    for member in channel_members {
2687        let channels = channels
2688            .iter()
2689            .filter_map(|channel| {
2690                if member.role.can_see_channel(channel.visibility) {
2691                    Some(channel.to_proto())
2692                } else {
2693                    None
2694                }
2695            })
2696            .collect::<Vec<_>>();
2697        if channels.is_empty() {
2698            continue;
2699        }
2700
2701        let update = proto::UpdateChannels {
2702            channels,
2703            ..Default::default()
2704        };
2705
2706        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2707            session.peer.send(connection_id, update.clone())?;
2708        }
2709    }
2710
2711    response.send(Ack {})?;
2712    Ok(())
2713}
2714
2715/// Get the list of channel members
2716async fn get_channel_members(
2717    request: proto::GetChannelMembers,
2718    response: Response<proto::GetChannelMembers>,
2719    session: Session,
2720) -> Result<()> {
2721    let db = session.db().await;
2722    let channel_id = ChannelId::from_proto(request.channel_id);
2723    let members = db
2724        .get_channel_participant_details(channel_id, session.user_id)
2725        .await?;
2726    response.send(proto::GetChannelMembersResponse { members })?;
2727    Ok(())
2728}
2729
2730/// Accept or decline a channel invitation.
2731async fn respond_to_channel_invite(
2732    request: proto::RespondToChannelInvite,
2733    response: Response<proto::RespondToChannelInvite>,
2734    session: Session,
2735) -> Result<()> {
2736    let db = session.db().await;
2737    let channel_id = ChannelId::from_proto(request.channel_id);
2738    let RespondToChannelInvite {
2739        membership_update,
2740        notifications,
2741    } = db
2742        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2743        .await?;
2744
2745    let connection_pool = session.connection_pool().await;
2746    if let Some(membership_update) = membership_update {
2747        notify_membership_updated(
2748            &connection_pool,
2749            membership_update,
2750            session.user_id,
2751            &session.peer,
2752        );
2753    } else {
2754        let update = proto::UpdateChannels {
2755            remove_channel_invitations: vec![channel_id.to_proto()],
2756            ..Default::default()
2757        };
2758
2759        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2760            session.peer.send(connection_id, update.clone())?;
2761        }
2762    };
2763
2764    send_notifications(&connection_pool, &session.peer, notifications);
2765
2766    response.send(proto::Ack {})?;
2767
2768    Ok(())
2769}
2770
2771/// Join the channels' room
2772async fn join_channel(
2773    request: proto::JoinChannel,
2774    response: Response<proto::JoinChannel>,
2775    session: Session,
2776) -> Result<()> {
2777    let channel_id = ChannelId::from_proto(request.channel_id);
2778    join_channel_internal(channel_id, Box::new(response), session).await
2779}
2780
/// Abstracts over request types whose reply is a `proto::JoinRoomResponse`, so
/// `join_channel_internal` can answer either a `JoinChannel` or a `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// In both impls below, the fully qualified `Response::<T>::send` path resolves to
// `Response`'s own inherent `send`, because inherent associated functions take
// precedence over trait methods. That is what keeps the call from recursing into
// this trait method. Do not shorten it to a form that could pick the trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2794
/// Shared implementation for joining a channel's room.
///
/// In order:
/// 1. Calls `leave_room_for_session` (presumably leaving any room this session is
///    already in).
/// 2. Joins the channel's room in the database. This returns the joined room, an
///    optional membership change, and the caller's role.
/// 3. If a LiveKit client is configured, mints a token:
///    - `ChannelRole::Guest` gets a guest token and `can_publish = false`;
///    - every other role gets a room token and `can_publish = true`.
///    A token error is passed through `trace_err`, which yields no connection info
///    rather than failing the join.
/// 4. Replies with the room, its channel id, and the LiveKit info.
/// 5. Pushes any membership change to the user and calls `room_updated`.
/// 6. After the inner block's guards are dropped, calls `channel_updated` and
///    `update_user_contacts`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that the `db` and `connection_pool` guards acquired inside are
    // dropped before the calls that follow the block. NOTE(review): presumably this
    // avoids holding them across `update_user_contacts`; confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (the `?` below operates on the `Option` from `trace_err`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; all other roles may.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before fanning out updates to other peers.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is re-acquired here; the guard from the block above was
    // dropped when the block ended.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2870
2871/// Start editing the channel notes
2872async fn join_channel_buffer(
2873    request: proto::JoinChannelBuffer,
2874    response: Response<proto::JoinChannelBuffer>,
2875    session: Session,
2876) -> Result<()> {
2877    let db = session.db().await;
2878    let channel_id = ChannelId::from_proto(request.channel_id);
2879
2880    let open_response = db
2881        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2882        .await?;
2883
2884    let collaborators = open_response.collaborators.clone();
2885    response.send(open_response)?;
2886
2887    let update = UpdateChannelBufferCollaborators {
2888        channel_id: channel_id.to_proto(),
2889        collaborators: collaborators.clone(),
2890    };
2891    channel_buffer_updated(
2892        session.connection_id,
2893        collaborators
2894            .iter()
2895            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2896        &update,
2897        &session.peer,
2898    );
2899
2900    Ok(())
2901}
2902
2903/// Edit the channel notes
2904async fn update_channel_buffer(
2905    request: proto::UpdateChannelBuffer,
2906    session: Session,
2907) -> Result<()> {
2908    let db = session.db().await;
2909    let channel_id = ChannelId::from_proto(request.channel_id);
2910
2911    let (collaborators, non_collaborators, epoch, version) = db
2912        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2913        .await?;
2914
2915    channel_buffer_updated(
2916        session.connection_id,
2917        collaborators,
2918        &proto::UpdateChannelBuffer {
2919            channel_id: channel_id.to_proto(),
2920            operations: request.operations,
2921        },
2922        &session.peer,
2923    );
2924
2925    let pool = &*session.connection_pool().await;
2926
2927    broadcast(
2928        None,
2929        non_collaborators
2930            .iter()
2931            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2932        |peer_id| {
2933            session.peer.send(
2934                peer_id,
2935                proto::UpdateChannels {
2936                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2937                        channel_id: channel_id.to_proto(),
2938                        epoch: epoch as u64,
2939                        version: version.clone(),
2940                    }],
2941                    ..Default::default()
2942                },
2943            )
2944        },
2945    );
2946
2947    Ok(())
2948}
2949
2950/// Rejoin the channel notes after a connection blip
2951async fn rejoin_channel_buffers(
2952    request: proto::RejoinChannelBuffers,
2953    response: Response<proto::RejoinChannelBuffers>,
2954    session: Session,
2955) -> Result<()> {
2956    let db = session.db().await;
2957    let buffers = db
2958        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2959        .await?;
2960
2961    for rejoined_buffer in &buffers {
2962        let collaborators_to_notify = rejoined_buffer
2963            .buffer
2964            .collaborators
2965            .iter()
2966            .filter_map(|c| Some(c.peer_id?.into()));
2967        channel_buffer_updated(
2968            session.connection_id,
2969            collaborators_to_notify,
2970            &proto::UpdateChannelBufferCollaborators {
2971                channel_id: rejoined_buffer.buffer.channel_id,
2972                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2973            },
2974            &session.peer,
2975        );
2976    }
2977
2978    response.send(proto::RejoinChannelBuffersResponse {
2979        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2980    })?;
2981
2982    Ok(())
2983}
2984
2985/// Stop editing the channel notes
2986async fn leave_channel_buffer(
2987    request: proto::LeaveChannelBuffer,
2988    response: Response<proto::LeaveChannelBuffer>,
2989    session: Session,
2990) -> Result<()> {
2991    let db = session.db().await;
2992    let channel_id = ChannelId::from_proto(request.channel_id);
2993
2994    let left_buffer = db
2995        .leave_channel_buffer(channel_id, session.connection_id)
2996        .await?;
2997
2998    response.send(Ack {})?;
2999
3000    channel_buffer_updated(
3001        session.connection_id,
3002        left_buffer.connections,
3003        &proto::UpdateChannelBufferCollaborators {
3004            channel_id: channel_id.to_proto(),
3005            collaborators: left_buffer.collaborators,
3006        },
3007        &session.peer,
3008    );
3009
3010    Ok(())
3011}
3012
3013fn channel_buffer_updated<T: EnvelopedMessage>(
3014    sender_id: ConnectionId,
3015    collaborators: impl IntoIterator<Item = ConnectionId>,
3016    message: &T,
3017    peer: &Peer,
3018) {
3019    broadcast(Some(sender_id), collaborators, |peer_id| {
3020        peer.send(peer_id, message.clone())
3021    });
3022}
3023
3024fn send_notifications(
3025    connection_pool: &ConnectionPool,
3026    peer: &Peer,
3027    notifications: db::NotificationBatch,
3028) {
3029    for (user_id, notification) in notifications {
3030        for connection_id in connection_pool.user_connection_ids(user_id) {
3031            if let Err(error) = peer.send(
3032                connection_id,
3033                proto::AddNotification {
3034                    notification: Some(notification.clone()),
3035                },
3036            ) {
3037                tracing::error!(
3038                    "failed to send notification to {:?} {}",
3039                    connection_id,
3040                    error
3041                );
3042            }
3043        }
3044    }
3045}
3046
3047/// Send a message to the channel
3048async fn send_channel_message(
3049    request: proto::SendChannelMessage,
3050    response: Response<proto::SendChannelMessage>,
3051    session: Session,
3052) -> Result<()> {
3053    // Validate the message body.
3054    let body = request.body.trim().to_string();
3055    if body.len() > MAX_MESSAGE_LEN {
3056        return Err(anyhow!("message is too long"))?;
3057    }
3058    if body.is_empty() {
3059        return Err(anyhow!("message can't be blank"))?;
3060    }
3061
3062    // TODO: adjust mentions if body is trimmed
3063
3064    let timestamp = OffsetDateTime::now_utc();
3065    let nonce = request
3066        .nonce
3067        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3068
3069    let channel_id = ChannelId::from_proto(request.channel_id);
3070    let CreatedChannelMessage {
3071        message_id,
3072        participant_connection_ids,
3073        channel_members,
3074        notifications,
3075    } = session
3076        .db()
3077        .await
3078        .create_channel_message(
3079            channel_id,
3080            session.user_id,
3081            &body,
3082            &request.mentions,
3083            timestamp,
3084            nonce.clone().into(),
3085            match request.reply_to_message_id {
3086                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3087                None => None,
3088            },
3089        )
3090        .await?;
3091    let message = proto::ChannelMessage {
3092        sender_id: session.user_id.to_proto(),
3093        id: message_id.to_proto(),
3094        body,
3095        mentions: request.mentions,
3096        timestamp: timestamp.unix_timestamp() as u64,
3097        nonce: Some(nonce),
3098        reply_to_message_id: request.reply_to_message_id,
3099    };
3100    broadcast(
3101        Some(session.connection_id),
3102        participant_connection_ids,
3103        |connection| {
3104            session.peer.send(
3105                connection,
3106                proto::ChannelMessageSent {
3107                    channel_id: channel_id.to_proto(),
3108                    message: Some(message.clone()),
3109                },
3110            )
3111        },
3112    );
3113    response.send(proto::SendChannelMessageResponse {
3114        message: Some(message),
3115    })?;
3116
3117    let pool = &*session.connection_pool().await;
3118    broadcast(
3119        None,
3120        channel_members
3121            .iter()
3122            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3123        |peer_id| {
3124            session.peer.send(
3125                peer_id,
3126                proto::UpdateChannels {
3127                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3128                        channel_id: channel_id.to_proto(),
3129                        message_id: message_id.to_proto(),
3130                    }],
3131                    ..Default::default()
3132                },
3133            )
3134        },
3135    );
3136    send_notifications(pool, &session.peer, notifications);
3137
3138    Ok(())
3139}
3140
3141/// Delete a channel message
3142async fn remove_channel_message(
3143    request: proto::RemoveChannelMessage,
3144    response: Response<proto::RemoveChannelMessage>,
3145    session: Session,
3146) -> Result<()> {
3147    let channel_id = ChannelId::from_proto(request.channel_id);
3148    let message_id = MessageId::from_proto(request.message_id);
3149    let connection_ids = session
3150        .db()
3151        .await
3152        .remove_channel_message(channel_id, message_id, session.user_id)
3153        .await?;
3154    broadcast(Some(session.connection_id), connection_ids, |connection| {
3155        session.peer.send(connection, request.clone())
3156    });
3157    response.send(proto::Ack {})?;
3158    Ok(())
3159}
3160
3161/// Mark a channel message as read
3162async fn acknowledge_channel_message(
3163    request: proto::AckChannelMessage,
3164    session: Session,
3165) -> Result<()> {
3166    let channel_id = ChannelId::from_proto(request.channel_id);
3167    let message_id = MessageId::from_proto(request.message_id);
3168    let notifications = session
3169        .db()
3170        .await
3171        .observe_channel_message(channel_id, session.user_id, message_id)
3172        .await?;
3173    send_notifications(
3174        &*session.connection_pool().await,
3175        &session.peer,
3176        notifications,
3177    );
3178    Ok(())
3179}
3180
3181/// Mark a buffer version as synced
3182async fn acknowledge_buffer_version(
3183    request: proto::AckBufferOperation,
3184    session: Session,
3185) -> Result<()> {
3186    let buffer_id = BufferId::from_proto(request.buffer_id);
3187    session
3188        .db()
3189        .await
3190        .observe_buffer_version(
3191            buffer_id,
3192            session.user_id,
3193            request.epoch as i32,
3194            &request.version,
3195        )
3196        .await?;
3197    Ok(())
3198}
3199
3200/// Start receiving chat updates for a channel
3201async fn join_channel_chat(
3202    request: proto::JoinChannelChat,
3203    response: Response<proto::JoinChannelChat>,
3204    session: Session,
3205) -> Result<()> {
3206    let channel_id = ChannelId::from_proto(request.channel_id);
3207
3208    let db = session.db().await;
3209    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3210        .await?;
3211    let messages = db
3212        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3213        .await?;
3214    response.send(proto::JoinChannelChatResponse {
3215        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3216        messages,
3217    })?;
3218    Ok(())
3219}
3220
3221/// Stop receiving chat updates for a channel
3222async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3223    let channel_id = ChannelId::from_proto(request.channel_id);
3224    session
3225        .db()
3226        .await
3227        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3228        .await?;
3229    Ok(())
3230}
3231
3232/// Retrieve the chat history for a channel
3233async fn get_channel_messages(
3234    request: proto::GetChannelMessages,
3235    response: Response<proto::GetChannelMessages>,
3236    session: Session,
3237) -> Result<()> {
3238    let channel_id = ChannelId::from_proto(request.channel_id);
3239    let messages = session
3240        .db()
3241        .await
3242        .get_channel_messages(
3243            channel_id,
3244            session.user_id,
3245            MESSAGE_COUNT_PER_PAGE,
3246            Some(MessageId::from_proto(request.before_message_id)),
3247        )
3248        .await?;
3249    response.send(proto::GetChannelMessagesResponse {
3250        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3251        messages,
3252    })?;
3253    Ok(())
3254}
3255
3256/// Retrieve specific chat messages
3257async fn get_channel_messages_by_id(
3258    request: proto::GetChannelMessagesById,
3259    response: Response<proto::GetChannelMessagesById>,
3260    session: Session,
3261) -> Result<()> {
3262    let message_ids = request
3263        .message_ids
3264        .iter()
3265        .map(|id| MessageId::from_proto(*id))
3266        .collect::<Vec<_>>();
3267    let messages = session
3268        .db()
3269        .await
3270        .get_channel_messages_by_id(session.user_id, &message_ids)
3271        .await?;
3272    response.send(proto::GetChannelMessagesResponse {
3273        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3274        messages,
3275    })?;
3276    Ok(())
3277}
3278
3279/// Retrieve the current users notifications
3280async fn get_notifications(
3281    request: proto::GetNotifications,
3282    response: Response<proto::GetNotifications>,
3283    session: Session,
3284) -> Result<()> {
3285    let notifications = session
3286        .db()
3287        .await
3288        .get_notifications(
3289            session.user_id,
3290            NOTIFICATION_COUNT_PER_PAGE,
3291            request
3292                .before_id
3293                .map(|id| db::NotificationId::from_proto(id)),
3294        )
3295        .await?;
3296    response.send(proto::GetNotificationsResponse {
3297        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3298        notifications,
3299    })?;
3300    Ok(())
3301}
3302
3303/// Mark notifications as read
3304async fn mark_notification_as_read(
3305    request: proto::MarkNotificationRead,
3306    response: Response<proto::MarkNotificationRead>,
3307    session: Session,
3308) -> Result<()> {
3309    let database = &session.db().await;
3310    let notifications = database
3311        .mark_notification_as_read_by_id(
3312            session.user_id,
3313            NotificationId::from_proto(request.notification_id),
3314        )
3315        .await?;
3316    send_notifications(
3317        &*session.connection_pool().await,
3318        &session.peer,
3319        notifications,
3320    );
3321    response.send(proto::Ack {})?;
3322    Ok(())
3323}
3324
3325/// Get the current users information
3326async fn get_private_user_info(
3327    _request: proto::GetPrivateUserInfo,
3328    response: Response<proto::GetPrivateUserInfo>,
3329    session: Session,
3330) -> Result<()> {
3331    let db = session.db().await;
3332
3333    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3334    let user = db
3335        .get_user_by_id(session.user_id)
3336        .await?
3337        .ok_or_else(|| anyhow!("user not found"))?;
3338    let flags = db.get_user_flags(session.user_id).await?;
3339
3340    response.send(proto::GetPrivateUserInfoResponse {
3341        metrics_id,
3342        staff: user.admin,
3343        flags,
3344    })?;
3345    Ok(())
3346}
3347
3348fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3349    match message {
3350        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3351        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3352        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3353        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3354        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3355            code: frame.code.into(),
3356            reason: frame.reason,
3357        })),
3358    }
3359}
3360
3361fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3362    match message {
3363        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3364        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3365        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3366        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3367        AxumMessage::Close(frame) => {
3368            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3369                code: frame.code.into(),
3370                reason: frame.reason,
3371            }))
3372        }
3373    }
3374}
3375
3376fn notify_membership_updated(
3377    connection_pool: &ConnectionPool,
3378    result: MembershipUpdated,
3379    user_id: UserId,
3380    peer: &Peer,
3381) {
3382    let user_channels_update = proto::UpdateUserChannels {
3383        channel_memberships: result
3384            .new_channels
3385            .channel_memberships
3386            .iter()
3387            .map(|cm| proto::ChannelMembership {
3388                channel_id: cm.channel_id.to_proto(),
3389                role: cm.role.into(),
3390            })
3391            .collect(),
3392        ..Default::default()
3393    };
3394    let mut update = build_channels_update(result.new_channels, vec![]);
3395    update.delete_channels = result
3396        .removed_channels
3397        .into_iter()
3398        .map(|id| id.to_proto())
3399        .collect();
3400    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3401
3402    for connection_id in connection_pool.user_connection_ids(user_id) {
3403        peer.send(connection_id, user_channels_update.clone())
3404            .trace_err();
3405        peer.send(connection_id, update.clone()).trace_err();
3406    }
3407}
3408
3409fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3410    proto::UpdateUserChannels {
3411        channel_memberships: channels
3412            .channel_memberships
3413            .iter()
3414            .map(|m| proto::ChannelMembership {
3415                channel_id: m.channel_id.to_proto(),
3416                role: m.role.into(),
3417            })
3418            .collect(),
3419        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3420        observed_channel_message_id: channels.observed_channel_messages.clone(),
3421    }
3422}
3423
3424fn build_channels_update(
3425    channels: ChannelsForUser,
3426    channel_invites: Vec<db::Channel>,
3427) -> proto::UpdateChannels {
3428    let mut update = proto::UpdateChannels::default();
3429
3430    for channel in channels.channels {
3431        update.channels.push(channel.to_proto());
3432    }
3433
3434    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3435    update.latest_channel_message_ids = channels.latest_channel_messages;
3436
3437    for (channel_id, participants) in channels.channel_participants {
3438        update
3439            .channel_participants
3440            .push(proto::ChannelParticipants {
3441                channel_id: channel_id.to_proto(),
3442                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3443            });
3444    }
3445
3446    for channel in channel_invites {
3447        update.channel_invitations.push(channel.to_proto());
3448    }
3449    for project in channels.hosted_projects {
3450        update.hosted_projects.push(project);
3451    }
3452
3453    update
3454}
3455
3456fn build_initial_contacts_update(
3457    contacts: Vec<db::Contact>,
3458    pool: &ConnectionPool,
3459) -> proto::UpdateContacts {
3460    let mut update = proto::UpdateContacts::default();
3461
3462    for contact in contacts {
3463        match contact {
3464            db::Contact::Accepted { user_id, busy } => {
3465                update.contacts.push(contact_for_user(user_id, busy, &pool));
3466            }
3467            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3468            db::Contact::Incoming { user_id } => {
3469                update
3470                    .incoming_requests
3471                    .push(proto::IncomingContactRequest {
3472                        requester_id: user_id.to_proto(),
3473                    })
3474            }
3475        }
3476    }
3477
3478    update
3479}
3480
3481fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3482    proto::Contact {
3483        user_id: user_id.to_proto(),
3484        online: pool.is_user_online(user_id),
3485        busy,
3486    }
3487}
3488
3489fn room_updated(room: &proto::Room, peer: &Peer) {
3490    broadcast(
3491        None,
3492        room.participants
3493            .iter()
3494            .filter_map(|participant| Some(participant.peer_id?.into())),
3495        |peer_id| {
3496            peer.send(
3497                peer_id,
3498                proto::RoomUpdated {
3499                    room: Some(room.clone()),
3500                },
3501            )
3502        },
3503    );
3504}
3505
3506fn channel_updated(
3507    channel_id: ChannelId,
3508    room: &proto::Room,
3509    channel_members: &[UserId],
3510    peer: &Peer,
3511    pool: &ConnectionPool,
3512) {
3513    let participants = room
3514        .participants
3515        .iter()
3516        .map(|p| p.user_id)
3517        .collect::<Vec<_>>();
3518
3519    broadcast(
3520        None,
3521        channel_members
3522            .iter()
3523            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3524        |peer_id| {
3525            peer.send(
3526                peer_id,
3527                proto::UpdateChannels {
3528                    channel_participants: vec![proto::ChannelParticipants {
3529                        channel_id: channel_id.to_proto(),
3530                        participant_user_ids: participants.clone(),
3531                    }],
3532                    ..Default::default()
3533                },
3534            )
3535        },
3536    );
3537}
3538
3539async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3540    let db = session.db().await;
3541
3542    let contacts = db.get_contacts(user_id).await?;
3543    let busy = db.is_user_busy(user_id).await?;
3544
3545    let pool = session.connection_pool().await;
3546    let updated_contact = contact_for_user(user_id, busy, &pool);
3547    for contact in contacts {
3548        if let db::Contact::Accepted {
3549            user_id: contact_user_id,
3550            ..
3551        } = contact
3552        {
3553            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3554                session
3555                    .peer
3556                    .send(
3557                        contact_conn_id,
3558                        proto::UpdateContacts {
3559                            contacts: vec![updated_contact.clone()],
3560                            remove_contacts: Default::default(),
3561                            incoming_requests: Default::default(),
3562                            remove_incoming_requests: Default::default(),
3563                            outgoing_requests: Default::default(),
3564                            remove_outgoing_requests: Default::default(),
3565                        },
3566                    )
3567                    .trace_err();
3568            }
3569        }
3570    }
3571    Ok(())
3572}
3573
3574async fn leave_room_for_session(session: &Session) -> Result<()> {
3575    let mut contacts_to_update = HashSet::default();
3576
3577    let room_id;
3578    let canceled_calls_to_user_ids;
3579    let live_kit_room;
3580    let delete_live_kit_room;
3581    let room;
3582    let channel_members;
3583    let channel_id;
3584
3585    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3586        contacts_to_update.insert(session.user_id);
3587
3588        for project in left_room.left_projects.values() {
3589            project_left(project, session);
3590        }
3591
3592        room_id = RoomId::from_proto(left_room.room.id);
3593        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3594        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3595        delete_live_kit_room = left_room.deleted;
3596        room = mem::take(&mut left_room.room);
3597        channel_members = mem::take(&mut left_room.channel_members);
3598        channel_id = left_room.channel_id;
3599
3600        room_updated(&room, &session.peer);
3601    } else {
3602        return Ok(());
3603    }
3604
3605    if let Some(channel_id) = channel_id {
3606        channel_updated(
3607            channel_id,
3608            &room,
3609            &channel_members,
3610            &session.peer,
3611            &*session.connection_pool().await,
3612        );
3613    }
3614
3615    {
3616        let pool = session.connection_pool().await;
3617        for canceled_user_id in canceled_calls_to_user_ids {
3618            for connection_id in pool.user_connection_ids(canceled_user_id) {
3619                session
3620                    .peer
3621                    .send(
3622                        connection_id,
3623                        proto::CallCanceled {
3624                            room_id: room_id.to_proto(),
3625                        },
3626                    )
3627                    .trace_err();
3628            }
3629            contacts_to_update.insert(canceled_user_id);
3630        }
3631    }
3632
3633    for contact_user_id in contacts_to_update {
3634        update_user_contacts(contact_user_id, &session).await?;
3635    }
3636
3637    if let Some(live_kit) = session.live_kit_client.as_ref() {
3638        live_kit
3639            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3640            .await
3641            .trace_err();
3642
3643        if delete_live_kit_room {
3644            live_kit.delete_room(live_kit_room).await.trace_err();
3645        }
3646    }
3647
3648    Ok(())
3649}
3650
3651async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3652    let left_channel_buffers = session
3653        .db()
3654        .await
3655        .leave_channel_buffers(session.connection_id)
3656        .await?;
3657
3658    for left_buffer in left_channel_buffers {
3659        channel_buffer_updated(
3660            session.connection_id,
3661            left_buffer.connections,
3662            &proto::UpdateChannelBufferCollaborators {
3663                channel_id: left_buffer.channel_id.to_proto(),
3664                collaborators: left_buffer.collaborators,
3665            },
3666            &session.peer,
3667        );
3668    }
3669
3670    Ok(())
3671}
3672
3673fn project_left(project: &db::LeftProject, session: &Session) {
3674    for connection_id in &project.connection_ids {
3675        if project.host_user_id == Some(session.user_id) {
3676            session
3677                .peer
3678                .send(
3679                    *connection_id,
3680                    proto::UnshareProject {
3681                        project_id: project.id.to_proto(),
3682                    },
3683                )
3684                .trace_err();
3685        } else {
3686            session
3687                .peer
3688                .send(
3689                    *connection_id,
3690                    proto::RemoveProjectCollaborator {
3691                        project_id: project.id.to_proto(),
3692                        peer_id: Some(session.connection_id.into()),
3693                    },
3694                )
3695                .trace_err();
3696        }
3697    }
3698}
3699
3700pub trait ResultExt {
3701    type Ok;
3702
3703    fn trace_err(self) -> Option<Self::Ok>;
3704}
3705
3706impl<T, E> ResultExt for Result<T, E>
3707where
3708    E: std::fmt::Debug,
3709{
3710    type Ok = T;
3711
3712    fn trace_err(self) -> Option<T> {
3713        match self {
3714            Ok(value) => Some(value),
3715            Err(error) => {
3716                tracing::error!("{:?}", error);
3717                None
3718            }
3719        }
3720    }
3721}