rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        HostedProjectId, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
   8        ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
   9        User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, Result,
  13};
  14use anyhow::anyhow;
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use futures::{
  34    channel::oneshot,
  35    future::{self, BoxFuture},
  36    stream::FuturesUnordered,
  37    FutureExt, SinkExt, StreamExt, TryStreamExt,
  38};
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc, OnceLock,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    peer: Arc<Peer>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91}
  92
  93#[derive(Clone)]
  94struct Session {
  95    user_id: UserId,
  96    connection_id: ConnectionId,
  97    db: Arc<tokio::sync::Mutex<DbHandle>>,
  98    peer: Arc<Peer>,
  99    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 100    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 101    _executor: Executor,
 102}
 103
 104impl Session {
 105    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        let guard = self.db.lock().await;
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        guard
 112    }
 113
 114    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.connection_pool.lock();
 118        ConnectionPoolGuard {
 119            guard,
 120            _not_send: PhantomData,
 121        }
 122    }
 123}
 124
 125impl fmt::Debug for Session {
 126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 127        f.debug_struct("Session")
 128            .field("user_id", &self.user_id)
 129            .field("connection_id", &self.connection_id)
 130            .finish()
 131    }
 132}
 133
 134struct DbHandle(Arc<Database>);
 135
 136impl Deref for DbHandle {
 137    type Target = Database;
 138
 139    fn deref(&self) -> &Self::Target {
 140        self.0.as_ref()
 141    }
 142}
 143
 144pub struct Server {
 145    id: parking_lot::Mutex<ServerId>,
 146    peer: Arc<Peer>,
 147    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 148    app_state: Arc<AppState>,
 149    executor: Executor,
 150    handlers: HashMap<TypeId, MessageHandler>,
 151    teardown: watch::Sender<bool>,
 152}
 153
 154pub(crate) struct ConnectionPoolGuard<'a> {
 155    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 156    _not_send: PhantomData<Rc<()>>,
 157}
 158
 159#[derive(Serialize)]
 160pub struct ServerSnapshot<'a> {
 161    peer: &'a Peer,
 162    #[serde(serialize_with = "serialize_deref")]
 163    connection_pool: ConnectionPoolGuard<'a>,
 164}
 165
 166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 167where
 168    S: Serializer,
 169    T: Deref<Target = U>,
 170    U: Serialize,
 171{
 172    Serialize::serialize(value.deref(), serializer)
 173}
 174
 175impl Server {
 176    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 177        let mut server = Self {
 178            id: parking_lot::Mutex::new(id),
 179            peer: Peer::new(id.0 as u32),
 180            app_state,
 181            executor,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184            teardown: watch::channel(false).0,
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_request_handler(rejoin_room)
 192            .add_request_handler(leave_room)
 193            .add_request_handler(set_room_participant_role)
 194            .add_request_handler(call)
 195            .add_request_handler(cancel_call)
 196            .add_message_handler(decline_call)
 197            .add_request_handler(update_participant_location)
 198            .add_request_handler(share_project)
 199            .add_message_handler(unshare_project)
 200            .add_request_handler(join_project)
 201            .add_request_handler(join_hosted_project)
 202            .add_message_handler(leave_project)
 203            .add_request_handler(update_project)
 204            .add_request_handler(update_worktree)
 205            .add_message_handler(start_language_server)
 206            .add_message_handler(update_language_server)
 207            .add_message_handler(update_diagnostic_summary)
 208            .add_message_handler(update_worktree_settings)
 209            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 210            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 211            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 212            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 213            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 214            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 215            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 216            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 217            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 218            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 219            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 220            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 221            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 222            .add_request_handler(
 223                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 224            )
 225            .add_request_handler(
 226                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 227            )
 228            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 229            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 230            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 231            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 232            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 233            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 234            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 235            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 236            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 237            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 238            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 239            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 240            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 241            .add_message_handler(create_buffer_for_peer)
 242            .add_request_handler(update_buffer)
 243            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 244            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 245            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 246            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 247            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 248            .add_request_handler(get_users)
 249            .add_request_handler(fuzzy_search_users)
 250            .add_request_handler(request_contact)
 251            .add_request_handler(remove_contact)
 252            .add_request_handler(respond_to_contact_request)
 253            .add_request_handler(create_channel)
 254            .add_request_handler(delete_channel)
 255            .add_request_handler(invite_channel_member)
 256            .add_request_handler(remove_channel_member)
 257            .add_request_handler(set_channel_member_role)
 258            .add_request_handler(set_channel_visibility)
 259            .add_request_handler(rename_channel)
 260            .add_request_handler(join_channel_buffer)
 261            .add_request_handler(leave_channel_buffer)
 262            .add_message_handler(update_channel_buffer)
 263            .add_request_handler(rejoin_channel_buffers)
 264            .add_request_handler(get_channel_members)
 265            .add_request_handler(respond_to_channel_invite)
 266            .add_request_handler(join_channel)
 267            .add_request_handler(join_channel_chat)
 268            .add_message_handler(leave_channel_chat)
 269            .add_request_handler(send_channel_message)
 270            .add_request_handler(remove_channel_message)
 271            .add_request_handler(get_channel_messages)
 272            .add_request_handler(get_channel_messages_by_id)
 273            .add_request_handler(get_notifications)
 274            .add_request_handler(mark_notification_as_read)
 275            .add_request_handler(move_channel)
 276            .add_request_handler(follow)
 277            .add_message_handler(unfollow)
 278            .add_message_handler(update_followers)
 279            .add_request_handler(get_private_user_info)
 280            .add_message_handler(acknowledge_channel_message)
 281            .add_message_handler(acknowledge_buffer_version);
 282
 283        Arc::new(server)
 284    }
 285
 286    pub async fn start(&self) -> Result<()> {
 287        let server_id = *self.id.lock();
 288        let app_state = self.app_state.clone();
 289        let peer = self.peer.clone();
 290        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 291        let pool = self.connection_pool.clone();
 292        let live_kit_client = self.app_state.live_kit_client.clone();
 293
 294        let span = info_span!("start server");
 295        self.executor.spawn_detached(
 296            async move {
 297                tracing::info!("waiting for cleanup timeout");
 298                timeout.await;
 299                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 300                if let Some((room_ids, channel_ids)) = app_state
 301                    .db
 302                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 303                    .await
 304                    .trace_err()
 305                {
 306                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 307                    tracing::info!(
 308                        stale_channel_buffer_count = channel_ids.len(),
 309                        "retrieved stale channel buffers"
 310                    );
 311
 312                    for channel_id in channel_ids {
 313                        if let Some(refreshed_channel_buffer) = app_state
 314                            .db
 315                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 316                            .await
 317                            .trace_err()
 318                        {
 319                            for connection_id in refreshed_channel_buffer.connection_ids {
 320                                peer.send(
 321                                    connection_id,
 322                                    proto::UpdateChannelBufferCollaborators {
 323                                        channel_id: channel_id.to_proto(),
 324                                        collaborators: refreshed_channel_buffer
 325                                            .collaborators
 326                                            .clone(),
 327                                    },
 328                                )
 329                                .trace_err();
 330                            }
 331                        }
 332                    }
 333
 334                    for room_id in room_ids {
 335                        let mut contacts_to_update = HashSet::default();
 336                        let mut canceled_calls_to_user_ids = Vec::new();
 337                        let mut live_kit_room = String::new();
 338                        let mut delete_live_kit_room = false;
 339
 340                        if let Some(mut refreshed_room) = app_state
 341                            .db
 342                            .clear_stale_room_participants(room_id, server_id)
 343                            .await
 344                            .trace_err()
 345                        {
 346                            tracing::info!(
 347                                room_id = room_id.0,
 348                                new_participant_count = refreshed_room.room.participants.len(),
 349                                "refreshed room"
 350                            );
 351                            room_updated(&refreshed_room.room, &peer);
 352                            if let Some(channel_id) = refreshed_room.channel_id {
 353                                channel_updated(
 354                                    channel_id,
 355                                    &refreshed_room.room,
 356                                    &refreshed_room.channel_members,
 357                                    &peer,
 358                                    &pool.lock(),
 359                                );
 360                            }
 361                            contacts_to_update
 362                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 363                            contacts_to_update
 364                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 365                            canceled_calls_to_user_ids =
 366                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 367                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 368                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 369                        }
 370
 371                        {
 372                            let pool = pool.lock();
 373                            for canceled_user_id in canceled_calls_to_user_ids {
 374                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 375                                    peer.send(
 376                                        connection_id,
 377                                        proto::CallCanceled {
 378                                            room_id: room_id.to_proto(),
 379                                        },
 380                                    )
 381                                    .trace_err();
 382                                }
 383                            }
 384                        }
 385
 386                        for user_id in contacts_to_update {
 387                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 388                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 389                            if let Some((busy, contacts)) = busy.zip(contacts) {
 390                                let pool = pool.lock();
 391                                let updated_contact = contact_for_user(user_id, busy, &pool);
 392                                for contact in contacts {
 393                                    if let db::Contact::Accepted {
 394                                        user_id: contact_user_id,
 395                                        ..
 396                                    } = contact
 397                                    {
 398                                        for contact_conn_id in
 399                                            pool.user_connection_ids(contact_user_id)
 400                                        {
 401                                            peer.send(
 402                                                contact_conn_id,
 403                                                proto::UpdateContacts {
 404                                                    contacts: vec![updated_contact.clone()],
 405                                                    remove_contacts: Default::default(),
 406                                                    incoming_requests: Default::default(),
 407                                                    remove_incoming_requests: Default::default(),
 408                                                    outgoing_requests: Default::default(),
 409                                                    remove_outgoing_requests: Default::default(),
 410                                                },
 411                                            )
 412                                            .trace_err();
 413                                        }
 414                                    }
 415                                }
 416                            }
 417                        }
 418
 419                        if let Some(live_kit) = live_kit_client.as_ref() {
 420                            if delete_live_kit_room {
 421                                live_kit.delete_room(live_kit_room).await.trace_err();
 422                            }
 423                        }
 424                    }
 425                }
 426
 427                app_state
 428                    .db
 429                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 430                    .await
 431                    .trace_err();
 432            }
 433            .instrument(span),
 434        );
 435        Ok(())
 436    }
 437
 438    pub fn teardown(&self) {
 439        self.peer.teardown();
 440        self.connection_pool.lock().reset();
 441        let _ = self.teardown.send(true);
 442    }
 443
 444    #[cfg(test)]
 445    pub fn reset(&self, id: ServerId) {
 446        self.teardown();
 447        *self.id.lock() = id;
 448        self.peer.reset(id.0 as u32);
 449        let _ = self.teardown.send(false);
 450    }
 451
 452    #[cfg(test)]
 453    pub fn id(&self) -> ServerId {
 454        *self.id.lock()
 455    }
 456
 457    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 458    where
 459        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 460        Fut: 'static + Send + Future<Output = Result<()>>,
 461        M: EnvelopedMessage,
 462    {
 463        let prev_handler = self.handlers.insert(
 464            TypeId::of::<M>(),
 465            Box::new(move |envelope, session| {
 466                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 467                let received_at = envelope.received_at;
 468                let span = info_span!(
 469                    "handle message",
 470                    payload_type = envelope.payload_type_name()
 471                );
 472                span.in_scope(|| {
 473                    tracing::info!(
 474                        payload_type = envelope.payload_type_name(),
 475                        "message received"
 476                    );
 477                });
 478                let start_time = Instant::now();
 479                let future = (handler)(*envelope, session);
 480                async move {
 481                    let result = future.await;
 482                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 483                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 484                    let queue_duration_ms = processing_duration_ms - total_duration_ms;
 485                    match result {
 486                        Err(error) => {
 487                            tracing::error!(%error, ?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "error handling message")
 488                        }
 489                        Ok(()) => tracing::info!(?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "finished handling message"),
 490                    }
 491                }
 492                .instrument(span)
 493                .boxed()
 494            }),
 495        );
 496        if prev_handler.is_some() {
 497            panic!("registered a handler for the same message twice");
 498        }
 499        self
 500    }
 501
 502    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 503    where
 504        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 505        Fut: 'static + Send + Future<Output = Result<()>>,
 506        M: EnvelopedMessage,
 507    {
 508        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 509        self
 510    }
 511
 512    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 513    where
 514        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 515        Fut: Send + Future<Output = Result<()>>,
 516        M: RequestMessage,
 517    {
 518        let handler = Arc::new(handler);
 519        self.add_handler(move |envelope, session| {
 520            let receipt = envelope.receipt();
 521            let handler = handler.clone();
 522            async move {
 523                let peer = session.peer.clone();
 524                let responded = Arc::new(AtomicBool::default());
 525                let response = Response {
 526                    peer: peer.clone(),
 527                    responded: responded.clone(),
 528                    receipt,
 529                };
 530                match (handler)(envelope.payload, response, session).await {
 531                    Ok(()) => {
 532                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 533                            Ok(())
 534                        } else {
 535                            Err(anyhow!("handler did not send a response"))?
 536                        }
 537                    }
 538                    Err(error) => {
 539                        let proto_err = match &error {
 540                            Error::Internal(err) => err.to_proto(),
 541                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 542                        };
 543                        peer.respond_with_error(receipt, proto_err)?;
 544                        Err(error)
 545                    }
 546                }
 547            }
 548        })
 549    }
 550
 551    #[allow(clippy::too_many_arguments)]
 552    pub fn handle_connection(
 553        self: &Arc<Self>,
 554        connection: Connection,
 555        address: String,
 556        user: User,
 557        zed_version: ZedVersion,
 558        impersonator: Option<User>,
 559        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 560        executor: Executor,
 561    ) -> impl Future<Output = Result<()>> {
 562        let this = self.clone();
 563        let user_id = user.id;
 564        let login = user.github_login;
 565        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
 566        if let Some(impersonator) = impersonator {
 567            span.record("impersonator", &impersonator.github_login);
 568        }
 569        let mut teardown = self.teardown.subscribe();
 570        async move {
 571            if *teardown.borrow() {
 572                return Err(anyhow!("server is tearing down"))?;
 573            }
 574            let (connection_id, handle_io, mut incoming_rx) = this
 575                .peer
 576                .add_connection(connection, {
 577                    let executor = executor.clone();
 578                    move |duration| executor.sleep(duration)
 579                });
 580
 581            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 582            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 583            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 584
 585            if let Some(send_connection_id) = send_connection_id.take() {
 586                let _ = send_connection_id.send(connection_id);
 587            }
 588
 589            if !user.connected_once {
 590                this.peer.send(connection_id, proto::ShowContacts {})?;
 591                this.app_state.db.set_user_connected_once(user_id, true).await?;
 592            }
 593
 594            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 595                this.app_state.db.get_contacts(user_id),
 596                this.app_state.db.get_channels_for_user(user_id),
 597                this.app_state.db.get_channel_invites_for_user(user_id),
 598            ).await?;
 599
 600            {
 601                let mut pool = this.connection_pool.lock();
 602                pool.add_connection(connection_id, user_id, user.admin, zed_version);
 603                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 604                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
 605                this.peer.send(connection_id, build_channels_update(
 606                    channels_for_user,
 607                    channel_invites
 608                ))?;
 609            }
 610
 611            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 612                this.peer.send(connection_id, incoming_call)?;
 613            }
 614
 615            let session = Session {
 616                user_id,
 617                connection_id,
 618                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 619                peer: this.peer.clone(),
 620                connection_pool: this.connection_pool.clone(),
 621                live_kit_client: this.app_state.live_kit_client.clone(),
 622                _executor: executor.clone()
 623            };
 624            update_user_contacts(user_id, &session).await?;
 625
 626            let handle_io = handle_io.fuse();
 627            futures::pin_mut!(handle_io);
 628
 629            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 630            // This prevents deadlocks when e.g., client A performs a request to client B and
 631            // client B performs a request to client A. If both clients stop processing further
 632            // messages until their respective request completes, they won't have a chance to
 633            // respond to the other client's request and cause a deadlock.
 634            //
 635            // This arrangement ensures we will attempt to process earlier messages first, but fall
 636            // back to processing messages arrived later in the spirit of making progress.
 637            let mut foreground_message_handlers = FuturesUnordered::new();
 638            let concurrent_handlers = Arc::new(Semaphore::new(256));
 639            loop {
 640                let next_message = async {
 641                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 642                    let message = incoming_rx.next().await;
 643                    (permit, message)
 644                }.fuse();
 645                futures::pin_mut!(next_message);
 646                futures::select_biased! {
 647                    _ = teardown.changed().fuse() => return Ok(()),
 648                    result = handle_io => {
 649                        if let Err(error) = result {
 650                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 651                        }
 652                        break;
 653                    }
 654                    _ = foreground_message_handlers.next() => {}
 655                    next_message = next_message => {
 656                        let (permit, message) = next_message;
 657                        if let Some(message) = message {
 658                            let type_name = message.payload_type_name();
 659                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 660                            let span_enter = span.enter();
 661                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 662                                let is_background = message.is_background();
 663                                let handle_message = (handler)(message, session.clone());
 664                                drop(span_enter);
 665
 666                                let handle_message = async move {
 667                                    handle_message.await;
 668                                    drop(permit);
 669                                }.instrument(span);
 670                                if is_background {
 671                                    executor.spawn_detached(handle_message);
 672                                } else {
 673                                    foreground_message_handlers.push(handle_message);
 674                                }
 675                            } else {
 676                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 677                            }
 678                        } else {
 679                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 680                            break;
 681                        }
 682                    }
 683                }
 684            }
 685
 686            drop(foreground_message_handlers);
 687            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 688            if let Err(error) = connection_lost(session, teardown, executor).await {
 689                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 690            }
 691
 692            Ok(())
 693        }.instrument(span)
 694    }
 695
 696    pub async fn invite_code_redeemed(
 697        self: &Arc<Self>,
 698        inviter_id: UserId,
 699        invitee_id: UserId,
 700    ) -> Result<()> {
 701        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 702            if let Some(code) = &user.invite_code {
 703                let pool = self.connection_pool.lock();
 704                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 705                for connection_id in pool.user_connection_ids(inviter_id) {
 706                    self.peer.send(
 707                        connection_id,
 708                        proto::UpdateContacts {
 709                            contacts: vec![invitee_contact.clone()],
 710                            ..Default::default()
 711                        },
 712                    )?;
 713                    self.peer.send(
 714                        connection_id,
 715                        proto::UpdateInviteInfo {
 716                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 717                            count: user.invite_count as u32,
 718                        },
 719                    )?;
 720                }
 721            }
 722        }
 723        Ok(())
 724    }
 725
 726    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 727        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 728            if let Some(invite_code) = &user.invite_code {
 729                let pool = self.connection_pool.lock();
 730                for connection_id in pool.user_connection_ids(user_id) {
 731                    self.peer.send(
 732                        connection_id,
 733                        proto::UpdateInviteInfo {
 734                            url: format!(
 735                                "{}{}",
 736                                self.app_state.config.invite_link_prefix, invite_code
 737                            ),
 738                            count: user.invite_count as u32,
 739                        },
 740                    )?;
 741                }
 742            }
 743        }
 744        Ok(())
 745    }
 746
 747    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 748        ServerSnapshot {
 749            connection_pool: ConnectionPoolGuard {
 750                guard: self.connection_pool.lock(),
 751                _not_send: PhantomData,
 752            },
 753            peer: &self.peer,
 754        }
 755    }
 756}
 757
 758impl<'a> Deref for ConnectionPoolGuard<'a> {
 759    type Target = ConnectionPool;
 760
 761    fn deref(&self) -> &Self::Target {
 762        &self.guard
 763    }
 764}
 765
 766impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 767    fn deref_mut(&mut self) -> &mut Self::Target {
 768        &mut self.guard
 769    }
 770}
 771
/// Runs a consistency check whenever a `ConnectionPoolGuard` goes away (test builds only).
///
/// `drop` runs before the struct's fields are dropped, so the check sees the pool
/// while the inner lock guard is still held.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; in other builds this body is empty.
        #[cfg(test)]
        self.check_invariants();
    }
}
 778
 779fn broadcast<F>(
 780    sender_id: Option<ConnectionId>,
 781    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 782    mut f: F,
 783) where
 784    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 785{
 786    for receiver_id in receiver_ids {
 787        if Some(receiver_id) != sender_id {
 788            if let Err(error) = f(receiver_id) {
 789                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 790            }
 791        }
 792    }
 793}
 794
 795pub struct ProtocolVersion(u32);
 796
 797impl Header for ProtocolVersion {
 798    fn name() -> &'static HeaderName {
 799        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 800        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 801    }
 802
 803    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 804    where
 805        Self: Sized,
 806        I: Iterator<Item = &'i axum::http::HeaderValue>,
 807    {
 808        let version = values
 809            .next()
 810            .ok_or_else(axum::headers::Error::invalid)?
 811            .to_str()
 812            .map_err(|_| axum::headers::Error::invalid())?
 813            .parse()
 814            .map_err(|_| axum::headers::Error::invalid())?;
 815        Ok(Self(version))
 816    }
 817
 818    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 819        values.extend([self.0.to_string().parse().unwrap()]);
 820    }
 821}
 822
 823pub struct AppVersionHeader(SemanticVersion);
 824impl Header for AppVersionHeader {
 825    fn name() -> &'static HeaderName {
 826        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 827        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 828    }
 829
 830    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 831    where
 832        Self: Sized,
 833        I: Iterator<Item = &'i axum::http::HeaderValue>,
 834    {
 835        let version = values
 836            .next()
 837            .ok_or_else(axum::headers::Error::invalid)?
 838            .to_str()
 839            .map_err(|_| axum::headers::Error::invalid())?
 840            .parse()
 841            .map_err(|_| axum::headers::Error::invalid())?;
 842        Ok(Self(version))
 843    }
 844
 845    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 846        values.extend([self.0.to_string().parse().unwrap()]);
 847    }
 848}
 849
/// Builds the collab RPC router with two routes: `/rpc` (websocket upgrade) and `/metrics`.
///
/// The order of `.route` and `.layer` calls matters. `Router::layer` only wraps the
/// routes registered before it, so the app-state extension and the
/// `auth::validate_header` middleware apply to `/rpc` but not to `/metrics`. The
/// final `Extension(server)` layer comes after both routes and reaches both handlers.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Wraps only `/rpc`, the single route registered so far.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Added after both routes, so both handlers can extract `Arc<Server>`.
        .layer(Extension(server))
}
 861
 862pub async fn handle_websocket_request(
 863    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 864    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 865    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 866    Extension(server): Extension<Arc<Server>>,
 867    Extension(user): Extension<User>,
 868    Extension(impersonator): Extension<Impersonator>,
 869    ws: WebSocketUpgrade,
 870) -> axum::response::Response {
 871    if protocol_version != rpc::PROTOCOL_VERSION {
 872        return (
 873            StatusCode::UPGRADE_REQUIRED,
 874            "client must be upgraded".to_string(),
 875        )
 876            .into_response();
 877    }
 878
 879    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 880        return (
 881            StatusCode::UPGRADE_REQUIRED,
 882            "no version header found".to_string(),
 883        )
 884            .into_response();
 885    };
 886
 887    if !version.is_supported() {
 888        return (
 889            StatusCode::UPGRADE_REQUIRED,
 890            "client must be upgraded".to_string(),
 891        )
 892            .into_response();
 893    }
 894
 895    let socket_address = socket_address.to_string();
 896    ws.on_upgrade(move |socket| {
 897        use util::ResultExt;
 898        let socket = socket
 899            .map_ok(to_tungstenite_message)
 900            .err_into()
 901            .with(|message| async move { Ok(to_axum_message(message)) });
 902        let connection = Connection::new(Box::pin(socket));
 903        async move {
 904            server
 905                .handle_connection(
 906                    connection,
 907                    socket_address,
 908                    user,
 909                    version,
 910                    impersonator.0,
 911                    None,
 912                    Executor::Production,
 913                )
 914                .await
 915                .log_err();
 916        }
 917    })
 918}
 919
 920pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 921    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 922    let connections_metric = CONNECTIONS_METRIC
 923        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 924
 925    let connections = server
 926        .connection_pool
 927        .lock()
 928        .connections()
 929        .filter(|connection| !connection.admin)
 930        .count();
 931    connections_metric.set(connections as _);
 932
 933    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 934    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 935        register_int_gauge!(
 936            "shared_projects",
 937            "number of open projects with one or more guests"
 938        )
 939        .unwrap()
 940    });
 941
 942    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 943    shared_projects_metric.set(shared_projects as _);
 944
 945    let encoder = prometheus::TextEncoder::new();
 946    let metric_families = prometheus::gather();
 947    let encoded_metrics = encoder
 948        .encode_to_string(&metric_families)
 949        .map_err(|err| anyhow!("{}", err))?;
 950    Ok(encoded_metrics)
 951}
 952
/// Tears down a session whose connection has gone away.
///
/// The cleanup happens in four steps:
/// 1. Disconnect the peer.
/// 2. Remove the connection from the pool. A failure here is returned immediately.
/// 3. Record the lost connection in the database. Errors are only logged.
/// 4. Wait for whichever comes first:
///    - `RECONNECT_TIMEOUT` elapses: run the full cleanup (leave the room and channel
///      buffers, decline any pending call if the user has no other connection online,
///      and run `update_user_contacts`);
///    - `teardown` changes: return without further cleanup.
///
/// `select_biased!` polls the branches in the order written, so the timeout branch
/// is checked first when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is traced but does not abort teardown.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user
            // remains online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 998
 999/// Acknowledges a ping from a client, used to keep the connection alive.
1000async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1001    response.send(proto::Ack {})?;
1002    Ok(())
1003}
1004
1005/// Creates a new room for calling (outside of channels)
1006async fn create_room(
1007    _request: proto::CreateRoom,
1008    response: Response<proto::CreateRoom>,
1009    session: Session,
1010) -> Result<()> {
1011    let live_kit_room = nanoid::nanoid!(30);
1012
1013    let live_kit_connection_info = {
1014        let live_kit_room = live_kit_room.clone();
1015        let live_kit = session.live_kit_client.as_ref();
1016
1017        util::async_maybe!({
1018            let live_kit = live_kit?;
1019
1020            let token = live_kit
1021                .room_token(&live_kit_room, &session.user_id.to_string())
1022                .trace_err()?;
1023
1024            Some(proto::LiveKitConnectionInfo {
1025                server_url: live_kit.url().into(),
1026                token,
1027                can_publish: true,
1028            })
1029        })
1030    }
1031    .await;
1032
1033    let room = session
1034        .db()
1035        .await
1036        .create_room(session.user_id, session.connection_id, &live_kit_room)
1037        .await?;
1038
1039    response.send(proto::CreateRoomResponse {
1040        room: Some(room.clone()),
1041        live_kit_connection_info,
1042    })?;
1043
1044    update_user_contacts(session.user_id, &session).await?;
1045    Ok(())
1046}
1047
1048/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1049async fn join_room(
1050    request: proto::JoinRoom,
1051    response: Response<proto::JoinRoom>,
1052    session: Session,
1053) -> Result<()> {
1054    let room_id = RoomId::from_proto(request.id);
1055
1056    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1057
1058    if let Some(channel_id) = channel_id {
1059        return join_channel_internal(channel_id, Box::new(response), session).await;
1060    }
1061
1062    let joined_room = {
1063        let room = session
1064            .db()
1065            .await
1066            .join_room(room_id, session.user_id, session.connection_id)
1067            .await?;
1068        room_updated(&room.room, &session.peer);
1069        room.into_inner()
1070    };
1071
1072    for connection_id in session
1073        .connection_pool()
1074        .await
1075        .user_connection_ids(session.user_id)
1076    {
1077        session
1078            .peer
1079            .send(
1080                connection_id,
1081                proto::CallCanceled {
1082                    room_id: room_id.to_proto(),
1083                },
1084            )
1085            .trace_err();
1086    }
1087
1088    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1089        if let Some(token) = live_kit
1090            .room_token(
1091                &joined_room.room.live_kit_room,
1092                &session.user_id.to_string(),
1093            )
1094            .trace_err()
1095        {
1096            Some(proto::LiveKitConnectionInfo {
1097                server_url: live_kit.url().into(),
1098                token,
1099                can_publish: true,
1100            })
1101        } else {
1102            None
1103        }
1104    } else {
1105        None
1106    };
1107
1108    response.send(proto::JoinRoomResponse {
1109        room: Some(joined_room.room),
1110        channel_id: None,
1111        live_kit_connection_info,
1112    })?;
1113
1114    update_user_contacts(session.user_id, &session).await?;
1115    Ok(())
1116}
1117
1118/// Rejoin room is used to reconnect to a room after connection errors.
1119async fn rejoin_room(
1120    request: proto::RejoinRoom,
1121    response: Response<proto::RejoinRoom>,
1122    session: Session,
1123) -> Result<()> {
1124    let room;
1125    let channel_id;
1126    let channel_members;
1127    {
1128        let mut rejoined_room = session
1129            .db()
1130            .await
1131            .rejoin_room(request, session.user_id, session.connection_id)
1132            .await?;
1133
1134        response.send(proto::RejoinRoomResponse {
1135            room: Some(rejoined_room.room.clone()),
1136            reshared_projects: rejoined_room
1137                .reshared_projects
1138                .iter()
1139                .map(|project| proto::ResharedProject {
1140                    id: project.id.to_proto(),
1141                    collaborators: project
1142                        .collaborators
1143                        .iter()
1144                        .map(|collaborator| collaborator.to_proto())
1145                        .collect(),
1146                })
1147                .collect(),
1148            rejoined_projects: rejoined_room
1149                .rejoined_projects
1150                .iter()
1151                .map(|rejoined_project| proto::RejoinedProject {
1152                    id: rejoined_project.id.to_proto(),
1153                    worktrees: rejoined_project
1154                        .worktrees
1155                        .iter()
1156                        .map(|worktree| proto::WorktreeMetadata {
1157                            id: worktree.id,
1158                            root_name: worktree.root_name.clone(),
1159                            visible: worktree.visible,
1160                            abs_path: worktree.abs_path.clone(),
1161                        })
1162                        .collect(),
1163                    collaborators: rejoined_project
1164                        .collaborators
1165                        .iter()
1166                        .map(|collaborator| collaborator.to_proto())
1167                        .collect(),
1168                    language_servers: rejoined_project.language_servers.clone(),
1169                })
1170                .collect(),
1171        })?;
1172        room_updated(&rejoined_room.room, &session.peer);
1173
1174        for project in &rejoined_room.reshared_projects {
1175            for collaborator in &project.collaborators {
1176                session
1177                    .peer
1178                    .send(
1179                        collaborator.connection_id,
1180                        proto::UpdateProjectCollaborator {
1181                            project_id: project.id.to_proto(),
1182                            old_peer_id: Some(project.old_connection_id.into()),
1183                            new_peer_id: Some(session.connection_id.into()),
1184                        },
1185                    )
1186                    .trace_err();
1187            }
1188
1189            broadcast(
1190                Some(session.connection_id),
1191                project
1192                    .collaborators
1193                    .iter()
1194                    .map(|collaborator| collaborator.connection_id),
1195                |connection_id| {
1196                    session.peer.forward_send(
1197                        session.connection_id,
1198                        connection_id,
1199                        proto::UpdateProject {
1200                            project_id: project.id.to_proto(),
1201                            worktrees: project.worktrees.clone(),
1202                        },
1203                    )
1204                },
1205            );
1206        }
1207
1208        for project in &rejoined_room.rejoined_projects {
1209            for collaborator in &project.collaborators {
1210                session
1211                    .peer
1212                    .send(
1213                        collaborator.connection_id,
1214                        proto::UpdateProjectCollaborator {
1215                            project_id: project.id.to_proto(),
1216                            old_peer_id: Some(project.old_connection_id.into()),
1217                            new_peer_id: Some(session.connection_id.into()),
1218                        },
1219                    )
1220                    .trace_err();
1221            }
1222        }
1223
1224        for project in &mut rejoined_room.rejoined_projects {
1225            for worktree in mem::take(&mut project.worktrees) {
1226                #[cfg(any(test, feature = "test-support"))]
1227                const MAX_CHUNK_SIZE: usize = 2;
1228                #[cfg(not(any(test, feature = "test-support")))]
1229                const MAX_CHUNK_SIZE: usize = 256;
1230
1231                // Stream this worktree's entries.
1232                let message = proto::UpdateWorktree {
1233                    project_id: project.id.to_proto(),
1234                    worktree_id: worktree.id,
1235                    abs_path: worktree.abs_path.clone(),
1236                    root_name: worktree.root_name,
1237                    updated_entries: worktree.updated_entries,
1238                    removed_entries: worktree.removed_entries,
1239                    scan_id: worktree.scan_id,
1240                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1241                    updated_repositories: worktree.updated_repositories,
1242                    removed_repositories: worktree.removed_repositories,
1243                };
1244                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1245                    session.peer.send(session.connection_id, update.clone())?;
1246                }
1247
1248                // Stream this worktree's diagnostics.
1249                for summary in worktree.diagnostic_summaries {
1250                    session.peer.send(
1251                        session.connection_id,
1252                        proto::UpdateDiagnosticSummary {
1253                            project_id: project.id.to_proto(),
1254                            worktree_id: worktree.id,
1255                            summary: Some(summary),
1256                        },
1257                    )?;
1258                }
1259
1260                for settings_file in worktree.settings_files {
1261                    session.peer.send(
1262                        session.connection_id,
1263                        proto::UpdateWorktreeSettings {
1264                            project_id: project.id.to_proto(),
1265                            worktree_id: worktree.id,
1266                            path: settings_file.path,
1267                            content: Some(settings_file.content),
1268                        },
1269                    )?;
1270                }
1271            }
1272
1273            for language_server in &project.language_servers {
1274                session.peer.send(
1275                    session.connection_id,
1276                    proto::UpdateLanguageServer {
1277                        project_id: project.id.to_proto(),
1278                        language_server_id: language_server.id,
1279                        variant: Some(
1280                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1281                                proto::LspDiskBasedDiagnosticsUpdated {},
1282                            ),
1283                        ),
1284                    },
1285                )?;
1286            }
1287        }
1288
1289        let rejoined_room = rejoined_room.into_inner();
1290
1291        room = rejoined_room.room;
1292        channel_id = rejoined_room.channel_id;
1293        channel_members = rejoined_room.channel_members;
1294    }
1295
1296    if let Some(channel_id) = channel_id {
1297        channel_updated(
1298            channel_id,
1299            &room,
1300            &channel_members,
1301            &session.peer,
1302            &*session.connection_pool().await,
1303        );
1304    }
1305
1306    update_user_contacts(session.user_id, &session).await?;
1307    Ok(())
1308}
1309
1310/// leave room disconnects from the room.
1311async fn leave_room(
1312    _: proto::LeaveRoom,
1313    response: Response<proto::LeaveRoom>,
1314    session: Session,
1315) -> Result<()> {
1316    leave_room_for_session(&session).await?;
1317    response.send(proto::Ack {})?;
1318    Ok(())
1319}
1320
1321/// Updates the permissions of someone else in the room.
1322async fn set_room_participant_role(
1323    request: proto::SetRoomParticipantRole,
1324    response: Response<proto::SetRoomParticipantRole>,
1325    session: Session,
1326) -> Result<()> {
1327    let user_id = UserId::from_proto(request.user_id);
1328    let role = ChannelRole::from(request.role());
1329
1330    if role == ChannelRole::Talker {
1331        let pool = session.connection_pool().await;
1332
1333        for connection in pool.user_connections(user_id) {
1334            if !connection.zed_version.supports_talker_role() {
1335                Err(anyhow!(
1336                    "This user is on zed {} which does not support unmute",
1337                    connection.zed_version
1338                ))?;
1339            }
1340        }
1341    }
1342
1343    let (live_kit_room, can_publish) = {
1344        let room = session
1345            .db()
1346            .await
1347            .set_room_participant_role(
1348                session.user_id,
1349                RoomId::from_proto(request.room_id),
1350                user_id,
1351                role,
1352            )
1353            .await?;
1354
1355        let live_kit_room = room.live_kit_room.clone();
1356        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1357        room_updated(&room, &session.peer);
1358        (live_kit_room, can_publish)
1359    };
1360
1361    if let Some(live_kit) = session.live_kit_client.as_ref() {
1362        live_kit
1363            .update_participant(
1364                live_kit_room.clone(),
1365                request.user_id.to_string(),
1366                live_kit_server::proto::ParticipantPermission {
1367                    can_subscribe: true,
1368                    can_publish,
1369                    can_publish_data: can_publish,
1370                    hidden: false,
1371                    recorder: false,
1372                },
1373            )
1374            .await
1375            .trace_err();
1376    }
1377
1378    response.send(proto::Ack {})?;
1379    Ok(())
1380}
1381
1382/// Call someone else into the current room
1383async fn call(
1384    request: proto::Call,
1385    response: Response<proto::Call>,
1386    session: Session,
1387) -> Result<()> {
1388    let room_id = RoomId::from_proto(request.room_id);
1389    let calling_user_id = session.user_id;
1390    let calling_connection_id = session.connection_id;
1391    let called_user_id = UserId::from_proto(request.called_user_id);
1392    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1393    if !session
1394        .db()
1395        .await
1396        .has_contact(calling_user_id, called_user_id)
1397        .await?
1398    {
1399        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1400    }
1401
1402    let incoming_call = {
1403        let (room, incoming_call) = &mut *session
1404            .db()
1405            .await
1406            .call(
1407                room_id,
1408                calling_user_id,
1409                calling_connection_id,
1410                called_user_id,
1411                initial_project_id,
1412            )
1413            .await?;
1414        room_updated(&room, &session.peer);
1415        mem::take(incoming_call)
1416    };
1417    update_user_contacts(called_user_id, &session).await?;
1418
1419    let mut calls = session
1420        .connection_pool()
1421        .await
1422        .user_connection_ids(called_user_id)
1423        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1424        .collect::<FuturesUnordered<_>>();
1425
1426    while let Some(call_response) = calls.next().await {
1427        match call_response.as_ref() {
1428            Ok(_) => {
1429                response.send(proto::Ack {})?;
1430                return Ok(());
1431            }
1432            Err(_) => {
1433                call_response.trace_err();
1434            }
1435        }
1436    }
1437
1438    {
1439        let room = session
1440            .db()
1441            .await
1442            .call_failed(room_id, called_user_id)
1443            .await?;
1444        room_updated(&room, &session.peer);
1445    }
1446    update_user_contacts(called_user_id, &session).await?;
1447
1448    Err(anyhow!("failed to ring user"))?
1449}
1450
1451/// Cancel an outgoing call.
1452async fn cancel_call(
1453    request: proto::CancelCall,
1454    response: Response<proto::CancelCall>,
1455    session: Session,
1456) -> Result<()> {
1457    let called_user_id = UserId::from_proto(request.called_user_id);
1458    let room_id = RoomId::from_proto(request.room_id);
1459    {
1460        let room = session
1461            .db()
1462            .await
1463            .cancel_call(room_id, session.connection_id, called_user_id)
1464            .await?;
1465        room_updated(&room, &session.peer);
1466    }
1467
1468    for connection_id in session
1469        .connection_pool()
1470        .await
1471        .user_connection_ids(called_user_id)
1472    {
1473        session
1474            .peer
1475            .send(
1476                connection_id,
1477                proto::CallCanceled {
1478                    room_id: room_id.to_proto(),
1479                },
1480            )
1481            .trace_err();
1482    }
1483    response.send(proto::Ack {})?;
1484
1485    update_user_contacts(called_user_id, &session).await?;
1486    Ok(())
1487}
1488
1489/// Decline an incoming call.
1490async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1491    let room_id = RoomId::from_proto(message.room_id);
1492    {
1493        let room = session
1494            .db()
1495            .await
1496            .decline_call(Some(room_id), session.user_id)
1497            .await?
1498            .ok_or_else(|| anyhow!("failed to decline call"))?;
1499        room_updated(&room, &session.peer);
1500    }
1501
1502    for connection_id in session
1503        .connection_pool()
1504        .await
1505        .user_connection_ids(session.user_id)
1506    {
1507        session
1508            .peer
1509            .send(
1510                connection_id,
1511                proto::CallCanceled {
1512                    room_id: room_id.to_proto(),
1513                },
1514            )
1515            .trace_err();
1516    }
1517    update_user_contacts(session.user_id, &session).await?;
1518    Ok(())
1519}
1520
1521/// Updates other participants in the room with your current location.
1522async fn update_participant_location(
1523    request: proto::UpdateParticipantLocation,
1524    response: Response<proto::UpdateParticipantLocation>,
1525    session: Session,
1526) -> Result<()> {
1527    let room_id = RoomId::from_proto(request.room_id);
1528    let location = request
1529        .location
1530        .ok_or_else(|| anyhow!("invalid location"))?;
1531
1532    let db = session.db().await;
1533    let room = db
1534        .update_room_participant_location(room_id, session.connection_id, location)
1535        .await?;
1536
1537    room_updated(&room, &session.peer);
1538    response.send(proto::Ack {})?;
1539    Ok(())
1540}
1541
1542/// Share a project into the room.
1543async fn share_project(
1544    request: proto::ShareProject,
1545    response: Response<proto::ShareProject>,
1546    session: Session,
1547) -> Result<()> {
1548    let (project_id, room) = &*session
1549        .db()
1550        .await
1551        .share_project(
1552            RoomId::from_proto(request.room_id),
1553            session.connection_id,
1554            &request.worktrees,
1555        )
1556        .await?;
1557    response.send(proto::ShareProjectResponse {
1558        project_id: project_id.to_proto(),
1559    })?;
1560    room_updated(&room, &session.peer);
1561
1562    Ok(())
1563}
1564
1565/// Unshare a project from the room.
1566async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1567    let project_id = ProjectId::from_proto(message.project_id);
1568
1569    let (room, guest_connection_ids) = &*session
1570        .db()
1571        .await
1572        .unshare_project(project_id, session.connection_id)
1573        .await?;
1574
1575    broadcast(
1576        Some(session.connection_id),
1577        guest_connection_ids.iter().copied(),
1578        |conn_id| session.peer.send(conn_id, message.clone()),
1579    );
1580    room_updated(&room, &session.peer);
1581
1582    Ok(())
1583}
1584
1585/// Join someone elses shared project.
1586async fn join_project(
1587    request: proto::JoinProject,
1588    response: Response<proto::JoinProject>,
1589    session: Session,
1590) -> Result<()> {
1591    let project_id = ProjectId::from_proto(request.project_id);
1592
1593    tracing::info!(%project_id, "join project");
1594
1595    let (project, replica_id) = &mut *session
1596        .db()
1597        .await
1598        .join_project_in_room(project_id, session.connection_id)
1599        .await?;
1600
1601    join_project_internal(response, session, project, replica_id)
1602}
1603
1604trait JoinProjectInternalResponse {
1605    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1606}
1607impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1608    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1609        Response::<proto::JoinProject>::send(self, result)
1610    }
1611}
1612impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1613    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1614        Response::<proto::JoinHostedProject>::send(self, result)
1615    }
1616}
1617
1618fn join_project_internal(
1619    response: impl JoinProjectInternalResponse,
1620    session: Session,
1621    project: &mut Project,
1622    replica_id: &ReplicaId,
1623) -> Result<()> {
1624    let collaborators = project
1625        .collaborators
1626        .iter()
1627        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1628        .map(|collaborator| collaborator.to_proto())
1629        .collect::<Vec<_>>();
1630    let project_id = project.id;
1631    let guest_user_id = session.user_id;
1632
1633    let worktrees = project
1634        .worktrees
1635        .iter()
1636        .map(|(id, worktree)| proto::WorktreeMetadata {
1637            id: *id,
1638            root_name: worktree.root_name.clone(),
1639            visible: worktree.visible,
1640            abs_path: worktree.abs_path.clone(),
1641        })
1642        .collect::<Vec<_>>();
1643
1644    for collaborator in &collaborators {
1645        session
1646            .peer
1647            .send(
1648                collaborator.peer_id.unwrap().into(),
1649                proto::AddProjectCollaborator {
1650                    project_id: project_id.to_proto(),
1651                    collaborator: Some(proto::Collaborator {
1652                        peer_id: Some(session.connection_id.into()),
1653                        replica_id: replica_id.0 as u32,
1654                        user_id: guest_user_id.to_proto(),
1655                    }),
1656                },
1657            )
1658            .trace_err();
1659    }
1660
1661    // First, we send the metadata associated with each worktree.
1662    response.send(proto::JoinProjectResponse {
1663        project_id: project.id.0 as u64,
1664        worktrees: worktrees.clone(),
1665        replica_id: replica_id.0 as u32,
1666        collaborators: collaborators.clone(),
1667        language_servers: project.language_servers.clone(),
1668        role: project.role.into(), // todo
1669    })?;
1670
1671    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1672        #[cfg(any(test, feature = "test-support"))]
1673        const MAX_CHUNK_SIZE: usize = 2;
1674        #[cfg(not(any(test, feature = "test-support")))]
1675        const MAX_CHUNK_SIZE: usize = 256;
1676
1677        // Stream this worktree's entries.
1678        let message = proto::UpdateWorktree {
1679            project_id: project_id.to_proto(),
1680            worktree_id,
1681            abs_path: worktree.abs_path.clone(),
1682            root_name: worktree.root_name,
1683            updated_entries: worktree.entries,
1684            removed_entries: Default::default(),
1685            scan_id: worktree.scan_id,
1686            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1687            updated_repositories: worktree.repository_entries.into_values().collect(),
1688            removed_repositories: Default::default(),
1689        };
1690        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1691            session.peer.send(session.connection_id, update.clone())?;
1692        }
1693
1694        // Stream this worktree's diagnostics.
1695        for summary in worktree.diagnostic_summaries {
1696            session.peer.send(
1697                session.connection_id,
1698                proto::UpdateDiagnosticSummary {
1699                    project_id: project_id.to_proto(),
1700                    worktree_id: worktree.id,
1701                    summary: Some(summary),
1702                },
1703            )?;
1704        }
1705
1706        for settings_file in worktree.settings_files {
1707            session.peer.send(
1708                session.connection_id,
1709                proto::UpdateWorktreeSettings {
1710                    project_id: project_id.to_proto(),
1711                    worktree_id: worktree.id,
1712                    path: settings_file.path,
1713                    content: Some(settings_file.content),
1714                },
1715            )?;
1716        }
1717    }
1718
1719    for language_server in &project.language_servers {
1720        session.peer.send(
1721            session.connection_id,
1722            proto::UpdateLanguageServer {
1723                project_id: project_id.to_proto(),
1724                language_server_id: language_server.id,
1725                variant: Some(
1726                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1727                        proto::LspDiskBasedDiagnosticsUpdated {},
1728                    ),
1729                ),
1730            },
1731        )?;
1732    }
1733
1734    Ok(())
1735}
1736
1737/// Leave someone elses shared project.
1738async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1739    let sender_id = session.connection_id;
1740    let project_id = ProjectId::from_proto(request.project_id);
1741    let db = session.db().await;
1742    if db.is_hosted_project(project_id).await? {
1743        let project = db.leave_hosted_project(project_id, sender_id).await?;
1744        project_left(&project, &session);
1745        return Ok(());
1746    }
1747
1748    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1749    tracing::info!(
1750        %project_id,
1751        host_user_id = ?project.host_user_id,
1752        host_connection_id = ?project.host_connection_id,
1753        "leave project"
1754    );
1755
1756    project_left(&project, &session);
1757    room_updated(&room, &session.peer);
1758
1759    Ok(())
1760}
1761
1762async fn join_hosted_project(
1763    request: proto::JoinHostedProject,
1764    response: Response<proto::JoinHostedProject>,
1765    session: Session,
1766) -> Result<()> {
1767    let (mut project, replica_id) = session
1768        .db()
1769        .await
1770        .join_hosted_project(
1771            HostedProjectId(request.id as i32),
1772            session.user_id,
1773            session.connection_id,
1774        )
1775        .await?;
1776
1777    join_project_internal(response, session, &mut project, &replica_id)
1778}
1779
1780/// Updates other participants with changes to the project
1781async fn update_project(
1782    request: proto::UpdateProject,
1783    response: Response<proto::UpdateProject>,
1784    session: Session,
1785) -> Result<()> {
1786    let project_id = ProjectId::from_proto(request.project_id);
1787    let (room, guest_connection_ids) = &*session
1788        .db()
1789        .await
1790        .update_project(project_id, session.connection_id, &request.worktrees)
1791        .await?;
1792    broadcast(
1793        Some(session.connection_id),
1794        guest_connection_ids.iter().copied(),
1795        |connection_id| {
1796            session
1797                .peer
1798                .forward_send(session.connection_id, connection_id, request.clone())
1799        },
1800    );
1801    room_updated(&room, &session.peer);
1802    response.send(proto::Ack {})?;
1803
1804    Ok(())
1805}
1806
1807/// Updates other participants with changes to the worktree
1808async fn update_worktree(
1809    request: proto::UpdateWorktree,
1810    response: Response<proto::UpdateWorktree>,
1811    session: Session,
1812) -> Result<()> {
1813    let guest_connection_ids = session
1814        .db()
1815        .await
1816        .update_worktree(&request, session.connection_id)
1817        .await?;
1818
1819    broadcast(
1820        Some(session.connection_id),
1821        guest_connection_ids.iter().copied(),
1822        |connection_id| {
1823            session
1824                .peer
1825                .forward_send(session.connection_id, connection_id, request.clone())
1826        },
1827    );
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832/// Updates other participants with changes to the diagnostics
1833async fn update_diagnostic_summary(
1834    message: proto::UpdateDiagnosticSummary,
1835    session: Session,
1836) -> Result<()> {
1837    let guest_connection_ids = session
1838        .db()
1839        .await
1840        .update_diagnostic_summary(&message, session.connection_id)
1841        .await?;
1842
1843    broadcast(
1844        Some(session.connection_id),
1845        guest_connection_ids.iter().copied(),
1846        |connection_id| {
1847            session
1848                .peer
1849                .forward_send(session.connection_id, connection_id, message.clone())
1850        },
1851    );
1852
1853    Ok(())
1854}
1855
1856/// Updates other participants with changes to the worktree settings
1857async fn update_worktree_settings(
1858    message: proto::UpdateWorktreeSettings,
1859    session: Session,
1860) -> Result<()> {
1861    let guest_connection_ids = session
1862        .db()
1863        .await
1864        .update_worktree_settings(&message, session.connection_id)
1865        .await?;
1866
1867    broadcast(
1868        Some(session.connection_id),
1869        guest_connection_ids.iter().copied(),
1870        |connection_id| {
1871            session
1872                .peer
1873                .forward_send(session.connection_id, connection_id, message.clone())
1874        },
1875    );
1876
1877    Ok(())
1878}
1879
1880/// Notify other participants that a  language server has started.
1881async fn start_language_server(
1882    request: proto::StartLanguageServer,
1883    session: Session,
1884) -> Result<()> {
1885    let guest_connection_ids = session
1886        .db()
1887        .await
1888        .start_language_server(&request, session.connection_id)
1889        .await?;
1890
1891    broadcast(
1892        Some(session.connection_id),
1893        guest_connection_ids.iter().copied(),
1894        |connection_id| {
1895            session
1896                .peer
1897                .forward_send(session.connection_id, connection_id, request.clone())
1898        },
1899    );
1900    Ok(())
1901}
1902
1903/// Notify other participants that a language server has changed.
1904async fn update_language_server(
1905    request: proto::UpdateLanguageServer,
1906    session: Session,
1907) -> Result<()> {
1908    let project_id = ProjectId::from_proto(request.project_id);
1909    let project_connection_ids = session
1910        .db()
1911        .await
1912        .project_connection_ids(project_id, session.connection_id)
1913        .await?;
1914    broadcast(
1915        Some(session.connection_id),
1916        project_connection_ids.iter().copied(),
1917        |connection_id| {
1918            session
1919                .peer
1920                .forward_send(session.connection_id, connection_id, request.clone())
1921        },
1922    );
1923    Ok(())
1924}
1925
1926/// forward a project request to the host. These requests should be read only
1927/// as guests are allowed to send them.
1928async fn forward_read_only_project_request<T>(
1929    request: T,
1930    response: Response<T>,
1931    session: Session,
1932) -> Result<()>
1933where
1934    T: EntityMessage + RequestMessage,
1935{
1936    let project_id = ProjectId::from_proto(request.remote_entity_id());
1937    let host_connection_id = session
1938        .db()
1939        .await
1940        .host_for_read_only_project_request(project_id, session.connection_id)
1941        .await?;
1942    let payload = session
1943        .peer
1944        .forward_request(session.connection_id, host_connection_id, request)
1945        .await?;
1946    response.send(payload)?;
1947    Ok(())
1948}
1949
1950/// forward a project request to the host. These requests are disallowed
1951/// for guests.
1952async fn forward_mutating_project_request<T>(
1953    request: T,
1954    response: Response<T>,
1955    session: Session,
1956) -> Result<()>
1957where
1958    T: EntityMessage + RequestMessage,
1959{
1960    let project_id = ProjectId::from_proto(request.remote_entity_id());
1961    let host_connection_id = session
1962        .db()
1963        .await
1964        .host_for_mutating_project_request(project_id, session.connection_id)
1965        .await?;
1966    let payload = session
1967        .peer
1968        .forward_request(session.connection_id, host_connection_id, request)
1969        .await?;
1970    response.send(payload)?;
1971    Ok(())
1972}
1973
1974/// Notify other participants that a new buffer has been created
1975async fn create_buffer_for_peer(
1976    request: proto::CreateBufferForPeer,
1977    session: Session,
1978) -> Result<()> {
1979    session
1980        .db()
1981        .await
1982        .check_user_is_project_host(
1983            ProjectId::from_proto(request.project_id),
1984            session.connection_id,
1985        )
1986        .await?;
1987    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1988    session
1989        .peer
1990        .forward_send(session.connection_id, peer_id.into(), request)?;
1991    Ok(())
1992}
1993
1994/// Notify other participants that a buffer has been updated. This is
1995/// allowed for guests as long as the update is limited to selections.
1996async fn update_buffer(
1997    request: proto::UpdateBuffer,
1998    response: Response<proto::UpdateBuffer>,
1999    session: Session,
2000) -> Result<()> {
2001    let project_id = ProjectId::from_proto(request.project_id);
2002    let mut guest_connection_ids;
2003    let mut host_connection_id = None;
2004
2005    let mut requires_write_permission = false;
2006
2007    for op in request.operations.iter() {
2008        match op.variant {
2009            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2010            Some(_) => requires_write_permission = true,
2011        }
2012    }
2013
2014    {
2015        let collaborators = session
2016            .db()
2017            .await
2018            .project_collaborators_for_buffer_update(
2019                project_id,
2020                session.connection_id,
2021                requires_write_permission,
2022            )
2023            .await?;
2024        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2025        for collaborator in collaborators.iter() {
2026            if collaborator.is_host {
2027                host_connection_id = Some(collaborator.connection_id);
2028            } else {
2029                guest_connection_ids.push(collaborator.connection_id);
2030            }
2031        }
2032    }
2033    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2034
2035    broadcast(
2036        Some(session.connection_id),
2037        guest_connection_ids,
2038        |connection_id| {
2039            session
2040                .peer
2041                .forward_send(session.connection_id, connection_id, request.clone())
2042        },
2043    );
2044    if host_connection_id != session.connection_id {
2045        session
2046            .peer
2047            .forward_request(session.connection_id, host_connection_id, request.clone())
2048            .await?;
2049    }
2050
2051    response.send(proto::Ack {})?;
2052    Ok(())
2053}
2054
2055/// Notify other participants that a project has been updated.
2056async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2057    request: T,
2058    session: Session,
2059) -> Result<()> {
2060    let project_id = ProjectId::from_proto(request.remote_entity_id());
2061    let project_connection_ids = session
2062        .db()
2063        .await
2064        .project_connection_ids(project_id, session.connection_id)
2065        .await?;
2066
2067    broadcast(
2068        Some(session.connection_id),
2069        project_connection_ids.iter().copied(),
2070        |connection_id| {
2071            session
2072                .peer
2073                .forward_send(session.connection_id, connection_id, request.clone())
2074        },
2075    );
2076    Ok(())
2077}
2078
2079/// Start following another user in a call.
2080async fn follow(
2081    request: proto::Follow,
2082    response: Response<proto::Follow>,
2083    session: Session,
2084) -> Result<()> {
2085    let room_id = RoomId::from_proto(request.room_id);
2086    let project_id = request.project_id.map(ProjectId::from_proto);
2087    let leader_id = request
2088        .leader_id
2089        .ok_or_else(|| anyhow!("invalid leader id"))?
2090        .into();
2091    let follower_id = session.connection_id;
2092
2093    session
2094        .db()
2095        .await
2096        .check_room_participants(room_id, leader_id, session.connection_id)
2097        .await?;
2098
2099    let response_payload = session
2100        .peer
2101        .forward_request(session.connection_id, leader_id, request)
2102        .await?;
2103    response.send(response_payload)?;
2104
2105    if let Some(project_id) = project_id {
2106        let room = session
2107            .db()
2108            .await
2109            .follow(room_id, project_id, leader_id, follower_id)
2110            .await?;
2111        room_updated(&room, &session.peer);
2112    }
2113
2114    Ok(())
2115}
2116
2117/// Stop following another user in a call.
2118async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2119    let room_id = RoomId::from_proto(request.room_id);
2120    let project_id = request.project_id.map(ProjectId::from_proto);
2121    let leader_id = request
2122        .leader_id
2123        .ok_or_else(|| anyhow!("invalid leader id"))?
2124        .into();
2125    let follower_id = session.connection_id;
2126
2127    session
2128        .db()
2129        .await
2130        .check_room_participants(room_id, leader_id, session.connection_id)
2131        .await?;
2132
2133    session
2134        .peer
2135        .forward_send(session.connection_id, leader_id, request)?;
2136
2137    if let Some(project_id) = project_id {
2138        let room = session
2139            .db()
2140            .await
2141            .unfollow(room_id, project_id, leader_id, follower_id)
2142            .await?;
2143        room_updated(&room, &session.peer);
2144    }
2145
2146    Ok(())
2147}
2148
2149/// Notify everyone following you of your current location.
2150async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2151    let room_id = RoomId::from_proto(request.room_id);
2152    let database = session.db.lock().await;
2153
2154    let connection_ids = if let Some(project_id) = request.project_id {
2155        let project_id = ProjectId::from_proto(project_id);
2156        database
2157            .project_connection_ids(project_id, session.connection_id)
2158            .await?
2159    } else {
2160        database
2161            .room_connection_ids(room_id, session.connection_id)
2162            .await?
2163    };
2164
2165    // For now, don't send view update messages back to that view's current leader.
2166    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2167        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2168        _ => None,
2169    });
2170
2171    for connection_id in connection_ids.iter().cloned() {
2172        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2173            session
2174                .peer
2175                .forward_send(session.connection_id, connection_id, request.clone())?;
2176        }
2177    }
2178    Ok(())
2179}
2180
2181/// Get public data about users.
2182async fn get_users(
2183    request: proto::GetUsers,
2184    response: Response<proto::GetUsers>,
2185    session: Session,
2186) -> Result<()> {
2187    let user_ids = request
2188        .user_ids
2189        .into_iter()
2190        .map(UserId::from_proto)
2191        .collect();
2192    let users = session
2193        .db()
2194        .await
2195        .get_users_by_ids(user_ids)
2196        .await?
2197        .into_iter()
2198        .map(|user| proto::User {
2199            id: user.id.to_proto(),
2200            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2201            github_login: user.github_login,
2202        })
2203        .collect();
2204    response.send(proto::UsersResponse { users })?;
2205    Ok(())
2206}
2207
2208/// Search for users (to invite) buy Github login
2209async fn fuzzy_search_users(
2210    request: proto::FuzzySearchUsers,
2211    response: Response<proto::FuzzySearchUsers>,
2212    session: Session,
2213) -> Result<()> {
2214    let query = request.query;
2215    let users = match query.len() {
2216        0 => vec![],
2217        1 | 2 => session
2218            .db()
2219            .await
2220            .get_user_by_github_login(&query)
2221            .await?
2222            .into_iter()
2223            .collect(),
2224        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2225    };
2226    let users = users
2227        .into_iter()
2228        .filter(|user| user.id != session.user_id)
2229        .map(|user| proto::User {
2230            id: user.id.to_proto(),
2231            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2232            github_login: user.github_login,
2233        })
2234        .collect();
2235    response.send(proto::UsersResponse { users })?;
2236    Ok(())
2237}
2238
2239/// Send a contact request to another user.
2240async fn request_contact(
2241    request: proto::RequestContact,
2242    response: Response<proto::RequestContact>,
2243    session: Session,
2244) -> Result<()> {
2245    let requester_id = session.user_id;
2246    let responder_id = UserId::from_proto(request.responder_id);
2247    if requester_id == responder_id {
2248        return Err(anyhow!("cannot add yourself as a contact"))?;
2249    }
2250
2251    let notifications = session
2252        .db()
2253        .await
2254        .send_contact_request(requester_id, responder_id)
2255        .await?;
2256
2257    // Update outgoing contact requests of requester
2258    let mut update = proto::UpdateContacts::default();
2259    update.outgoing_requests.push(responder_id.to_proto());
2260    for connection_id in session
2261        .connection_pool()
2262        .await
2263        .user_connection_ids(requester_id)
2264    {
2265        session.peer.send(connection_id, update.clone())?;
2266    }
2267
2268    // Update incoming contact requests of responder
2269    let mut update = proto::UpdateContacts::default();
2270    update
2271        .incoming_requests
2272        .push(proto::IncomingContactRequest {
2273            requester_id: requester_id.to_proto(),
2274        });
2275    let connection_pool = session.connection_pool().await;
2276    for connection_id in connection_pool.user_connection_ids(responder_id) {
2277        session.peer.send(connection_id, update.clone())?;
2278    }
2279
2280    send_notifications(&connection_pool, &session.peer, notifications);
2281
2282    response.send(proto::Ack {})?;
2283    Ok(())
2284}
2285
2286/// Accept or decline a contact request
2287async fn respond_to_contact_request(
2288    request: proto::RespondToContactRequest,
2289    response: Response<proto::RespondToContactRequest>,
2290    session: Session,
2291) -> Result<()> {
2292    let responder_id = session.user_id;
2293    let requester_id = UserId::from_proto(request.requester_id);
2294    let db = session.db().await;
2295    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2296        db.dismiss_contact_notification(responder_id, requester_id)
2297            .await?;
2298    } else {
2299        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2300
2301        let notifications = db
2302            .respond_to_contact_request(responder_id, requester_id, accept)
2303            .await?;
2304        let requester_busy = db.is_user_busy(requester_id).await?;
2305        let responder_busy = db.is_user_busy(responder_id).await?;
2306
2307        let pool = session.connection_pool().await;
2308        // Update responder with new contact
2309        let mut update = proto::UpdateContacts::default();
2310        if accept {
2311            update
2312                .contacts
2313                .push(contact_for_user(requester_id, requester_busy, &pool));
2314        }
2315        update
2316            .remove_incoming_requests
2317            .push(requester_id.to_proto());
2318        for connection_id in pool.user_connection_ids(responder_id) {
2319            session.peer.send(connection_id, update.clone())?;
2320        }
2321
2322        // Update requester with new contact
2323        let mut update = proto::UpdateContacts::default();
2324        if accept {
2325            update
2326                .contacts
2327                .push(contact_for_user(responder_id, responder_busy, &pool));
2328        }
2329        update
2330            .remove_outgoing_requests
2331            .push(responder_id.to_proto());
2332
2333        for connection_id in pool.user_connection_ids(requester_id) {
2334            session.peer.send(connection_id, update.clone())?;
2335        }
2336
2337        send_notifications(&pool, &session.peer, notifications);
2338    }
2339
2340    response.send(proto::Ack {})?;
2341    Ok(())
2342}
2343
2344/// Remove a contact.
2345async fn remove_contact(
2346    request: proto::RemoveContact,
2347    response: Response<proto::RemoveContact>,
2348    session: Session,
2349) -> Result<()> {
2350    let requester_id = session.user_id;
2351    let responder_id = UserId::from_proto(request.user_id);
2352    let db = session.db().await;
2353    let (contact_accepted, deleted_notification_id) =
2354        db.remove_contact(requester_id, responder_id).await?;
2355
2356    let pool = session.connection_pool().await;
2357    // Update outgoing contact requests of requester
2358    let mut update = proto::UpdateContacts::default();
2359    if contact_accepted {
2360        update.remove_contacts.push(responder_id.to_proto());
2361    } else {
2362        update
2363            .remove_outgoing_requests
2364            .push(responder_id.to_proto());
2365    }
2366    for connection_id in pool.user_connection_ids(requester_id) {
2367        session.peer.send(connection_id, update.clone())?;
2368    }
2369
2370    // Update incoming contact requests of responder
2371    let mut update = proto::UpdateContacts::default();
2372    if contact_accepted {
2373        update.remove_contacts.push(requester_id.to_proto());
2374    } else {
2375        update
2376            .remove_incoming_requests
2377            .push(requester_id.to_proto());
2378    }
2379    for connection_id in pool.user_connection_ids(responder_id) {
2380        session.peer.send(connection_id, update.clone())?;
2381        if let Some(notification_id) = deleted_notification_id {
2382            session.peer.send(
2383                connection_id,
2384                proto::DeleteNotification {
2385                    notification_id: notification_id.to_proto(),
2386                },
2387            )?;
2388        }
2389    }
2390
2391    response.send(proto::Ack {})?;
2392    Ok(())
2393}
2394
2395/// Creates a new channel.
2396async fn create_channel(
2397    request: proto::CreateChannel,
2398    response: Response<proto::CreateChannel>,
2399    session: Session,
2400) -> Result<()> {
2401    let db = session.db().await;
2402
2403    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2404    let (channel, owner, channel_members) = db
2405        .create_channel(&request.name, parent_id, session.user_id)
2406        .await?;
2407
2408    response.send(proto::CreateChannelResponse {
2409        channel: Some(channel.to_proto()),
2410        parent_id: request.parent_id,
2411    })?;
2412
2413    let connection_pool = session.connection_pool().await;
2414    if let Some(owner) = owner {
2415        let update = proto::UpdateUserChannels {
2416            channel_memberships: vec![proto::ChannelMembership {
2417                channel_id: owner.channel_id.to_proto(),
2418                role: owner.role.into(),
2419            }],
2420            ..Default::default()
2421        };
2422        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2423            session.peer.send(connection_id, update.clone())?;
2424        }
2425    }
2426
2427    for channel_member in channel_members {
2428        if !channel_member.role.can_see_channel(channel.visibility) {
2429            continue;
2430        }
2431
2432        let update = proto::UpdateChannels {
2433            channels: vec![channel.to_proto()],
2434            ..Default::default()
2435        };
2436        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2437            session.peer.send(connection_id, update.clone())?;
2438        }
2439    }
2440
2441    Ok(())
2442}
2443
2444/// Delete a channel
2445async fn delete_channel(
2446    request: proto::DeleteChannel,
2447    response: Response<proto::DeleteChannel>,
2448    session: Session,
2449) -> Result<()> {
2450    let db = session.db().await;
2451
2452    let channel_id = request.channel_id;
2453    let (removed_channels, member_ids) = db
2454        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2455        .await?;
2456    response.send(proto::Ack {})?;
2457
2458    // Notify members of removed channels
2459    let mut update = proto::UpdateChannels::default();
2460    update
2461        .delete_channels
2462        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2463
2464    let connection_pool = session.connection_pool().await;
2465    for member_id in member_ids {
2466        for connection_id in connection_pool.user_connection_ids(member_id) {
2467            session.peer.send(connection_id, update.clone())?;
2468        }
2469    }
2470
2471    Ok(())
2472}
2473
2474/// Invite someone to join a channel.
2475async fn invite_channel_member(
2476    request: proto::InviteChannelMember,
2477    response: Response<proto::InviteChannelMember>,
2478    session: Session,
2479) -> Result<()> {
2480    let db = session.db().await;
2481    let channel_id = ChannelId::from_proto(request.channel_id);
2482    let invitee_id = UserId::from_proto(request.user_id);
2483    let InviteMemberResult {
2484        channel,
2485        notifications,
2486    } = db
2487        .invite_channel_member(
2488            channel_id,
2489            invitee_id,
2490            session.user_id,
2491            request.role().into(),
2492        )
2493        .await?;
2494
2495    let update = proto::UpdateChannels {
2496        channel_invitations: vec![channel.to_proto()],
2497        ..Default::default()
2498    };
2499
2500    let connection_pool = session.connection_pool().await;
2501    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2502        session.peer.send(connection_id, update.clone())?;
2503    }
2504
2505    send_notifications(&connection_pool, &session.peer, notifications);
2506
2507    response.send(proto::Ack {})?;
2508    Ok(())
2509}
2510
2511/// remove someone from a channel
2512async fn remove_channel_member(
2513    request: proto::RemoveChannelMember,
2514    response: Response<proto::RemoveChannelMember>,
2515    session: Session,
2516) -> Result<()> {
2517    let db = session.db().await;
2518    let channel_id = ChannelId::from_proto(request.channel_id);
2519    let member_id = UserId::from_proto(request.user_id);
2520
2521    let RemoveChannelMemberResult {
2522        membership_update,
2523        notification_id,
2524    } = db
2525        .remove_channel_member(channel_id, member_id, session.user_id)
2526        .await?;
2527
2528    let connection_pool = &session.connection_pool().await;
2529    notify_membership_updated(
2530        &connection_pool,
2531        membership_update,
2532        member_id,
2533        &session.peer,
2534    );
2535    for connection_id in connection_pool.user_connection_ids(member_id) {
2536        if let Some(notification_id) = notification_id {
2537            session
2538                .peer
2539                .send(
2540                    connection_id,
2541                    proto::DeleteNotification {
2542                        notification_id: notification_id.to_proto(),
2543                    },
2544                )
2545                .trace_err();
2546        }
2547    }
2548
2549    response.send(proto::Ack {})?;
2550    Ok(())
2551}
2552
2553/// Toggle the channel between public and private.
2554/// Care is taken to maintain the invariant that public channels only descend from public channels,
2555/// (though members-only channels can appear at any point in the hierarchy).
2556async fn set_channel_visibility(
2557    request: proto::SetChannelVisibility,
2558    response: Response<proto::SetChannelVisibility>,
2559    session: Session,
2560) -> Result<()> {
2561    let db = session.db().await;
2562    let channel_id = ChannelId::from_proto(request.channel_id);
2563    let visibility = request.visibility().into();
2564
2565    let (channel, channel_members) = db
2566        .set_channel_visibility(channel_id, visibility, session.user_id)
2567        .await?;
2568
2569    let connection_pool = session.connection_pool().await;
2570    for member in channel_members {
2571        let update = if member.role.can_see_channel(channel.visibility) {
2572            proto::UpdateChannels {
2573                channels: vec![channel.to_proto()],
2574                ..Default::default()
2575            }
2576        } else {
2577            proto::UpdateChannels {
2578                delete_channels: vec![channel.id.to_proto()],
2579                ..Default::default()
2580            }
2581        };
2582
2583        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2584            session.peer.send(connection_id, update.clone())?;
2585        }
2586    }
2587
2588    response.send(proto::Ack {})?;
2589    Ok(())
2590}
2591
2592/// Alter the role for a user in the channel.
2593async fn set_channel_member_role(
2594    request: proto::SetChannelMemberRole,
2595    response: Response<proto::SetChannelMemberRole>,
2596    session: Session,
2597) -> Result<()> {
2598    let db = session.db().await;
2599    let channel_id = ChannelId::from_proto(request.channel_id);
2600    let member_id = UserId::from_proto(request.user_id);
2601    let result = db
2602        .set_channel_member_role(
2603            channel_id,
2604            session.user_id,
2605            member_id,
2606            request.role().into(),
2607        )
2608        .await?;
2609
2610    match result {
2611        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2612            let connection_pool = session.connection_pool().await;
2613            notify_membership_updated(
2614                &connection_pool,
2615                membership_update,
2616                member_id,
2617                &session.peer,
2618            )
2619        }
2620        db::SetMemberRoleResult::InviteUpdated(channel) => {
2621            let update = proto::UpdateChannels {
2622                channel_invitations: vec![channel.to_proto()],
2623                ..Default::default()
2624            };
2625
2626            for connection_id in session
2627                .connection_pool()
2628                .await
2629                .user_connection_ids(member_id)
2630            {
2631                session.peer.send(connection_id, update.clone())?;
2632            }
2633        }
2634    }
2635
2636    response.send(proto::Ack {})?;
2637    Ok(())
2638}
2639
2640/// Change the name of a channel
2641async fn rename_channel(
2642    request: proto::RenameChannel,
2643    response: Response<proto::RenameChannel>,
2644    session: Session,
2645) -> Result<()> {
2646    let db = session.db().await;
2647    let channel_id = ChannelId::from_proto(request.channel_id);
2648    let (channel, channel_members) = db
2649        .rename_channel(channel_id, session.user_id, &request.name)
2650        .await?;
2651
2652    response.send(proto::RenameChannelResponse {
2653        channel: Some(channel.to_proto()),
2654    })?;
2655
2656    let connection_pool = session.connection_pool().await;
2657    for channel_member in channel_members {
2658        if !channel_member.role.can_see_channel(channel.visibility) {
2659            continue;
2660        }
2661        let update = proto::UpdateChannels {
2662            channels: vec![channel.to_proto()],
2663            ..Default::default()
2664        };
2665        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2666            session.peer.send(connection_id, update.clone())?;
2667        }
2668    }
2669
2670    Ok(())
2671}
2672
2673/// Move a channel to a new parent.
2674async fn move_channel(
2675    request: proto::MoveChannel,
2676    response: Response<proto::MoveChannel>,
2677    session: Session,
2678) -> Result<()> {
2679    let channel_id = ChannelId::from_proto(request.channel_id);
2680    let to = ChannelId::from_proto(request.to);
2681
2682    let (channels, channel_members) = session
2683        .db()
2684        .await
2685        .move_channel(channel_id, to, session.user_id)
2686        .await?;
2687
2688    let connection_pool = session.connection_pool().await;
2689    for member in channel_members {
2690        let channels = channels
2691            .iter()
2692            .filter_map(|channel| {
2693                if member.role.can_see_channel(channel.visibility) {
2694                    Some(channel.to_proto())
2695                } else {
2696                    None
2697                }
2698            })
2699            .collect::<Vec<_>>();
2700        if channels.is_empty() {
2701            continue;
2702        }
2703
2704        let update = proto::UpdateChannels {
2705            channels,
2706            ..Default::default()
2707        };
2708
2709        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2710            session.peer.send(connection_id, update.clone())?;
2711        }
2712    }
2713
2714    response.send(Ack {})?;
2715    Ok(())
2716}
2717
2718/// Get the list of channel members
2719async fn get_channel_members(
2720    request: proto::GetChannelMembers,
2721    response: Response<proto::GetChannelMembers>,
2722    session: Session,
2723) -> Result<()> {
2724    let db = session.db().await;
2725    let channel_id = ChannelId::from_proto(request.channel_id);
2726    let members = db
2727        .get_channel_participant_details(channel_id, session.user_id)
2728        .await?;
2729    response.send(proto::GetChannelMembersResponse { members })?;
2730    Ok(())
2731}
2732
2733/// Accept or decline a channel invitation.
2734async fn respond_to_channel_invite(
2735    request: proto::RespondToChannelInvite,
2736    response: Response<proto::RespondToChannelInvite>,
2737    session: Session,
2738) -> Result<()> {
2739    let db = session.db().await;
2740    let channel_id = ChannelId::from_proto(request.channel_id);
2741    let RespondToChannelInvite {
2742        membership_update,
2743        notifications,
2744    } = db
2745        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2746        .await?;
2747
2748    let connection_pool = session.connection_pool().await;
2749    if let Some(membership_update) = membership_update {
2750        notify_membership_updated(
2751            &connection_pool,
2752            membership_update,
2753            session.user_id,
2754            &session.peer,
2755        );
2756    } else {
2757        let update = proto::UpdateChannels {
2758            remove_channel_invitations: vec![channel_id.to_proto()],
2759            ..Default::default()
2760        };
2761
2762        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2763            session.peer.send(connection_id, update.clone())?;
2764        }
2765    };
2766
2767    send_notifications(&connection_pool, &session.peer, notifications);
2768
2769    response.send(proto::Ack {})?;
2770
2771    Ok(())
2772}
2773
2774/// Join the channels' room
2775async fn join_channel(
2776    request: proto::JoinChannel,
2777    response: Response<proto::JoinChannel>,
2778    session: Session,
2779) -> Result<()> {
2780    let channel_id = ChannelId::from_proto(request.channel_id);
2781    join_channel_internal(channel_id, Box::new(response), session).await
2782}
2783
2784trait JoinChannelInternalResponse {
2785    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2786}
2787impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2788    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2789        Response::<proto::JoinChannel>::send(self, result)
2790    }
2791}
2792impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2793    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2794        Response::<proto::JoinRoom>::send(self, result)
2795    }
2796}
2797
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// `response` may be any `JoinChannelInternalResponse`. There are impls for
/// both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`.
///
/// Steps, in order:
/// 1. `leave_room_for_session`, which presumably leaves any room this session
///    is already in (TODO confirm).
/// 2. Record the join in the database. This yields the joined room, an
///    optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, build connection info. Guests get a
///    guest token with `can_publish: false`; every other role gets a room
///    token with `can_publish: true`.
/// 4. Reply with a `JoinRoomResponse`, push any membership update to the
///    user, and broadcast the room update.
/// 5. Broadcast the channel update, then refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` and `connection_pool` guards: both are
    // dropped before `session.connection_pool()` is awaited again below.
    // NOTE(review): presumably this avoids re-acquiring a lock that is still
    // held; confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `trace_err()` turns a token-creation error into `None` (presumably
        // logging it). The `?` then makes this closure return `None`, so the
        // response simply carries no LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is acquired again here because the inner block
    // released its guard.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2873
2874/// Start editing the channel notes
2875async fn join_channel_buffer(
2876    request: proto::JoinChannelBuffer,
2877    response: Response<proto::JoinChannelBuffer>,
2878    session: Session,
2879) -> Result<()> {
2880    let db = session.db().await;
2881    let channel_id = ChannelId::from_proto(request.channel_id);
2882
2883    let open_response = db
2884        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2885        .await?;
2886
2887    let collaborators = open_response.collaborators.clone();
2888    response.send(open_response)?;
2889
2890    let update = UpdateChannelBufferCollaborators {
2891        channel_id: channel_id.to_proto(),
2892        collaborators: collaborators.clone(),
2893    };
2894    channel_buffer_updated(
2895        session.connection_id,
2896        collaborators
2897            .iter()
2898            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2899        &update,
2900        &session.peer,
2901    );
2902
2903    Ok(())
2904}
2905
2906/// Edit the channel notes
2907async fn update_channel_buffer(
2908    request: proto::UpdateChannelBuffer,
2909    session: Session,
2910) -> Result<()> {
2911    let db = session.db().await;
2912    let channel_id = ChannelId::from_proto(request.channel_id);
2913
2914    let (collaborators, non_collaborators, epoch, version) = db
2915        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2916        .await?;
2917
2918    channel_buffer_updated(
2919        session.connection_id,
2920        collaborators,
2921        &proto::UpdateChannelBuffer {
2922            channel_id: channel_id.to_proto(),
2923            operations: request.operations,
2924        },
2925        &session.peer,
2926    );
2927
2928    let pool = &*session.connection_pool().await;
2929
2930    broadcast(
2931        None,
2932        non_collaborators
2933            .iter()
2934            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2935        |peer_id| {
2936            session.peer.send(
2937                peer_id,
2938                proto::UpdateChannels {
2939                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2940                        channel_id: channel_id.to_proto(),
2941                        epoch: epoch as u64,
2942                        version: version.clone(),
2943                    }],
2944                    ..Default::default()
2945                },
2946            )
2947        },
2948    );
2949
2950    Ok(())
2951}
2952
2953/// Rejoin the channel notes after a connection blip
2954async fn rejoin_channel_buffers(
2955    request: proto::RejoinChannelBuffers,
2956    response: Response<proto::RejoinChannelBuffers>,
2957    session: Session,
2958) -> Result<()> {
2959    let db = session.db().await;
2960    let buffers = db
2961        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2962        .await?;
2963
2964    for rejoined_buffer in &buffers {
2965        let collaborators_to_notify = rejoined_buffer
2966            .buffer
2967            .collaborators
2968            .iter()
2969            .filter_map(|c| Some(c.peer_id?.into()));
2970        channel_buffer_updated(
2971            session.connection_id,
2972            collaborators_to_notify,
2973            &proto::UpdateChannelBufferCollaborators {
2974                channel_id: rejoined_buffer.buffer.channel_id,
2975                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2976            },
2977            &session.peer,
2978        );
2979    }
2980
2981    response.send(proto::RejoinChannelBuffersResponse {
2982        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2983    })?;
2984
2985    Ok(())
2986}
2987
2988/// Stop editing the channel notes
2989async fn leave_channel_buffer(
2990    request: proto::LeaveChannelBuffer,
2991    response: Response<proto::LeaveChannelBuffer>,
2992    session: Session,
2993) -> Result<()> {
2994    let db = session.db().await;
2995    let channel_id = ChannelId::from_proto(request.channel_id);
2996
2997    let left_buffer = db
2998        .leave_channel_buffer(channel_id, session.connection_id)
2999        .await?;
3000
3001    response.send(Ack {})?;
3002
3003    channel_buffer_updated(
3004        session.connection_id,
3005        left_buffer.connections,
3006        &proto::UpdateChannelBufferCollaborators {
3007            channel_id: channel_id.to_proto(),
3008            collaborators: left_buffer.collaborators,
3009        },
3010        &session.peer,
3011    );
3012
3013    Ok(())
3014}
3015
3016fn channel_buffer_updated<T: EnvelopedMessage>(
3017    sender_id: ConnectionId,
3018    collaborators: impl IntoIterator<Item = ConnectionId>,
3019    message: &T,
3020    peer: &Peer,
3021) {
3022    broadcast(Some(sender_id), collaborators, |peer_id| {
3023        peer.send(peer_id, message.clone())
3024    });
3025}
3026
3027fn send_notifications(
3028    connection_pool: &ConnectionPool,
3029    peer: &Peer,
3030    notifications: db::NotificationBatch,
3031) {
3032    for (user_id, notification) in notifications {
3033        for connection_id in connection_pool.user_connection_ids(user_id) {
3034            if let Err(error) = peer.send(
3035                connection_id,
3036                proto::AddNotification {
3037                    notification: Some(notification.clone()),
3038                },
3039            ) {
3040                tracing::error!(
3041                    "failed to send notification to {:?} {}",
3042                    connection_id,
3043                    error
3044                );
3045            }
3046        }
3047    }
3048}
3049
3050/// Send a message to the channel
3051async fn send_channel_message(
3052    request: proto::SendChannelMessage,
3053    response: Response<proto::SendChannelMessage>,
3054    session: Session,
3055) -> Result<()> {
3056    // Validate the message body.
3057    let body = request.body.trim().to_string();
3058    if body.len() > MAX_MESSAGE_LEN {
3059        return Err(anyhow!("message is too long"))?;
3060    }
3061    if body.is_empty() {
3062        return Err(anyhow!("message can't be blank"))?;
3063    }
3064
3065    // TODO: adjust mentions if body is trimmed
3066
3067    let timestamp = OffsetDateTime::now_utc();
3068    let nonce = request
3069        .nonce
3070        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3071
3072    let channel_id = ChannelId::from_proto(request.channel_id);
3073    let CreatedChannelMessage {
3074        message_id,
3075        participant_connection_ids,
3076        channel_members,
3077        notifications,
3078    } = session
3079        .db()
3080        .await
3081        .create_channel_message(
3082            channel_id,
3083            session.user_id,
3084            &body,
3085            &request.mentions,
3086            timestamp,
3087            nonce.clone().into(),
3088            match request.reply_to_message_id {
3089                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3090                None => None,
3091            },
3092        )
3093        .await?;
3094    let message = proto::ChannelMessage {
3095        sender_id: session.user_id.to_proto(),
3096        id: message_id.to_proto(),
3097        body,
3098        mentions: request.mentions,
3099        timestamp: timestamp.unix_timestamp() as u64,
3100        nonce: Some(nonce),
3101        reply_to_message_id: request.reply_to_message_id,
3102    };
3103    broadcast(
3104        Some(session.connection_id),
3105        participant_connection_ids,
3106        |connection| {
3107            session.peer.send(
3108                connection,
3109                proto::ChannelMessageSent {
3110                    channel_id: channel_id.to_proto(),
3111                    message: Some(message.clone()),
3112                },
3113            )
3114        },
3115    );
3116    response.send(proto::SendChannelMessageResponse {
3117        message: Some(message),
3118    })?;
3119
3120    let pool = &*session.connection_pool().await;
3121    broadcast(
3122        None,
3123        channel_members
3124            .iter()
3125            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3126        |peer_id| {
3127            session.peer.send(
3128                peer_id,
3129                proto::UpdateChannels {
3130                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3131                        channel_id: channel_id.to_proto(),
3132                        message_id: message_id.to_proto(),
3133                    }],
3134                    ..Default::default()
3135                },
3136            )
3137        },
3138    );
3139    send_notifications(pool, &session.peer, notifications);
3140
3141    Ok(())
3142}
3143
3144/// Delete a channel message
3145async fn remove_channel_message(
3146    request: proto::RemoveChannelMessage,
3147    response: Response<proto::RemoveChannelMessage>,
3148    session: Session,
3149) -> Result<()> {
3150    let channel_id = ChannelId::from_proto(request.channel_id);
3151    let message_id = MessageId::from_proto(request.message_id);
3152    let connection_ids = session
3153        .db()
3154        .await
3155        .remove_channel_message(channel_id, message_id, session.user_id)
3156        .await?;
3157    broadcast(Some(session.connection_id), connection_ids, |connection| {
3158        session.peer.send(connection, request.clone())
3159    });
3160    response.send(proto::Ack {})?;
3161    Ok(())
3162}
3163
3164/// Mark a channel message as read
3165async fn acknowledge_channel_message(
3166    request: proto::AckChannelMessage,
3167    session: Session,
3168) -> Result<()> {
3169    let channel_id = ChannelId::from_proto(request.channel_id);
3170    let message_id = MessageId::from_proto(request.message_id);
3171    let notifications = session
3172        .db()
3173        .await
3174        .observe_channel_message(channel_id, session.user_id, message_id)
3175        .await?;
3176    send_notifications(
3177        &*session.connection_pool().await,
3178        &session.peer,
3179        notifications,
3180    );
3181    Ok(())
3182}
3183
3184/// Mark a buffer version as synced
3185async fn acknowledge_buffer_version(
3186    request: proto::AckBufferOperation,
3187    session: Session,
3188) -> Result<()> {
3189    let buffer_id = BufferId::from_proto(request.buffer_id);
3190    session
3191        .db()
3192        .await
3193        .observe_buffer_version(
3194            buffer_id,
3195            session.user_id,
3196            request.epoch as i32,
3197            &request.version,
3198        )
3199        .await?;
3200    Ok(())
3201}
3202
3203/// Start receiving chat updates for a channel
3204async fn join_channel_chat(
3205    request: proto::JoinChannelChat,
3206    response: Response<proto::JoinChannelChat>,
3207    session: Session,
3208) -> Result<()> {
3209    let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211    let db = session.db().await;
3212    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3213        .await?;
3214    let messages = db
3215        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3216        .await?;
3217    response.send(proto::JoinChannelChatResponse {
3218        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3219        messages,
3220    })?;
3221    Ok(())
3222}
3223
3224/// Stop receiving chat updates for a channel
3225async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3226    let channel_id = ChannelId::from_proto(request.channel_id);
3227    session
3228        .db()
3229        .await
3230        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3231        .await?;
3232    Ok(())
3233}
3234
3235/// Retrieve the chat history for a channel
3236async fn get_channel_messages(
3237    request: proto::GetChannelMessages,
3238    response: Response<proto::GetChannelMessages>,
3239    session: Session,
3240) -> Result<()> {
3241    let channel_id = ChannelId::from_proto(request.channel_id);
3242    let messages = session
3243        .db()
3244        .await
3245        .get_channel_messages(
3246            channel_id,
3247            session.user_id,
3248            MESSAGE_COUNT_PER_PAGE,
3249            Some(MessageId::from_proto(request.before_message_id)),
3250        )
3251        .await?;
3252    response.send(proto::GetChannelMessagesResponse {
3253        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3254        messages,
3255    })?;
3256    Ok(())
3257}
3258
3259/// Retrieve specific chat messages
3260async fn get_channel_messages_by_id(
3261    request: proto::GetChannelMessagesById,
3262    response: Response<proto::GetChannelMessagesById>,
3263    session: Session,
3264) -> Result<()> {
3265    let message_ids = request
3266        .message_ids
3267        .iter()
3268        .map(|id| MessageId::from_proto(*id))
3269        .collect::<Vec<_>>();
3270    let messages = session
3271        .db()
3272        .await
3273        .get_channel_messages_by_id(session.user_id, &message_ids)
3274        .await?;
3275    response.send(proto::GetChannelMessagesResponse {
3276        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3277        messages,
3278    })?;
3279    Ok(())
3280}
3281
3282/// Retrieve the current users notifications
3283async fn get_notifications(
3284    request: proto::GetNotifications,
3285    response: Response<proto::GetNotifications>,
3286    session: Session,
3287) -> Result<()> {
3288    let notifications = session
3289        .db()
3290        .await
3291        .get_notifications(
3292            session.user_id,
3293            NOTIFICATION_COUNT_PER_PAGE,
3294            request
3295                .before_id
3296                .map(|id| db::NotificationId::from_proto(id)),
3297        )
3298        .await?;
3299    response.send(proto::GetNotificationsResponse {
3300        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3301        notifications,
3302    })?;
3303    Ok(())
3304}
3305
3306/// Mark notifications as read
3307async fn mark_notification_as_read(
3308    request: proto::MarkNotificationRead,
3309    response: Response<proto::MarkNotificationRead>,
3310    session: Session,
3311) -> Result<()> {
3312    let database = &session.db().await;
3313    let notifications = database
3314        .mark_notification_as_read_by_id(
3315            session.user_id,
3316            NotificationId::from_proto(request.notification_id),
3317        )
3318        .await?;
3319    send_notifications(
3320        &*session.connection_pool().await,
3321        &session.peer,
3322        notifications,
3323    );
3324    response.send(proto::Ack {})?;
3325    Ok(())
3326}
3327
3328/// Get the current users information
3329async fn get_private_user_info(
3330    _request: proto::GetPrivateUserInfo,
3331    response: Response<proto::GetPrivateUserInfo>,
3332    session: Session,
3333) -> Result<()> {
3334    let db = session.db().await;
3335
3336    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3337    let user = db
3338        .get_user_by_id(session.user_id)
3339        .await?
3340        .ok_or_else(|| anyhow!("user not found"))?;
3341    let flags = db.get_user_flags(session.user_id).await?;
3342
3343    response.send(proto::GetPrivateUserInfoResponse {
3344        metrics_id,
3345        staff: user.admin,
3346        flags,
3347    })?;
3348    Ok(())
3349}
3350
3351fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3352    match message {
3353        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3354        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3355        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3356        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3357        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3358            code: frame.code.into(),
3359            reason: frame.reason,
3360        })),
3361    }
3362}
3363
3364fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3365    match message {
3366        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3367        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3368        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3369        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3370        AxumMessage::Close(frame) => {
3371            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3372                code: frame.code.into(),
3373                reason: frame.reason,
3374            }))
3375        }
3376    }
3377}
3378
3379fn notify_membership_updated(
3380    connection_pool: &ConnectionPool,
3381    result: MembershipUpdated,
3382    user_id: UserId,
3383    peer: &Peer,
3384) {
3385    let user_channels_update = proto::UpdateUserChannels {
3386        channel_memberships: result
3387            .new_channels
3388            .channel_memberships
3389            .iter()
3390            .map(|cm| proto::ChannelMembership {
3391                channel_id: cm.channel_id.to_proto(),
3392                role: cm.role.into(),
3393            })
3394            .collect(),
3395        ..Default::default()
3396    };
3397    let mut update = build_channels_update(result.new_channels, vec![]);
3398    update.delete_channels = result
3399        .removed_channels
3400        .into_iter()
3401        .map(|id| id.to_proto())
3402        .collect();
3403    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3404
3405    for connection_id in connection_pool.user_connection_ids(user_id) {
3406        peer.send(connection_id, user_channels_update.clone())
3407            .trace_err();
3408        peer.send(connection_id, update.clone()).trace_err();
3409    }
3410}
3411
3412fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3413    proto::UpdateUserChannels {
3414        channel_memberships: channels
3415            .channel_memberships
3416            .iter()
3417            .map(|m| proto::ChannelMembership {
3418                channel_id: m.channel_id.to_proto(),
3419                role: m.role.into(),
3420            })
3421            .collect(),
3422        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3423        observed_channel_message_id: channels.observed_channel_messages.clone(),
3424    }
3425}
3426
3427fn build_channels_update(
3428    channels: ChannelsForUser,
3429    channel_invites: Vec<db::Channel>,
3430) -> proto::UpdateChannels {
3431    let mut update = proto::UpdateChannels::default();
3432
3433    for channel in channels.channels {
3434        update.channels.push(channel.to_proto());
3435    }
3436
3437    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3438    update.latest_channel_message_ids = channels.latest_channel_messages;
3439
3440    for (channel_id, participants) in channels.channel_participants {
3441        update
3442            .channel_participants
3443            .push(proto::ChannelParticipants {
3444                channel_id: channel_id.to_proto(),
3445                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3446            });
3447    }
3448
3449    for channel in channel_invites {
3450        update.channel_invitations.push(channel.to_proto());
3451    }
3452    for project in channels.hosted_projects {
3453        update.hosted_projects.push(project);
3454    }
3455
3456    update
3457}
3458
3459fn build_initial_contacts_update(
3460    contacts: Vec<db::Contact>,
3461    pool: &ConnectionPool,
3462) -> proto::UpdateContacts {
3463    let mut update = proto::UpdateContacts::default();
3464
3465    for contact in contacts {
3466        match contact {
3467            db::Contact::Accepted { user_id, busy } => {
3468                update.contacts.push(contact_for_user(user_id, busy, &pool));
3469            }
3470            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3471            db::Contact::Incoming { user_id } => {
3472                update
3473                    .incoming_requests
3474                    .push(proto::IncomingContactRequest {
3475                        requester_id: user_id.to_proto(),
3476                    })
3477            }
3478        }
3479    }
3480
3481    update
3482}
3483
3484fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3485    proto::Contact {
3486        user_id: user_id.to_proto(),
3487        online: pool.is_user_online(user_id),
3488        busy,
3489    }
3490}
3491
3492fn room_updated(room: &proto::Room, peer: &Peer) {
3493    broadcast(
3494        None,
3495        room.participants
3496            .iter()
3497            .filter_map(|participant| Some(participant.peer_id?.into())),
3498        |peer_id| {
3499            peer.send(
3500                peer_id,
3501                proto::RoomUpdated {
3502                    room: Some(room.clone()),
3503                },
3504            )
3505        },
3506    );
3507}
3508
3509fn channel_updated(
3510    channel_id: ChannelId,
3511    room: &proto::Room,
3512    channel_members: &[UserId],
3513    peer: &Peer,
3514    pool: &ConnectionPool,
3515) {
3516    let participants = room
3517        .participants
3518        .iter()
3519        .map(|p| p.user_id)
3520        .collect::<Vec<_>>();
3521
3522    broadcast(
3523        None,
3524        channel_members
3525            .iter()
3526            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3527        |peer_id| {
3528            peer.send(
3529                peer_id,
3530                proto::UpdateChannels {
3531                    channel_participants: vec![proto::ChannelParticipants {
3532                        channel_id: channel_id.to_proto(),
3533                        participant_user_ids: participants.clone(),
3534                    }],
3535                    ..Default::default()
3536                },
3537            )
3538        },
3539    );
3540}
3541
3542async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3543    let db = session.db().await;
3544
3545    let contacts = db.get_contacts(user_id).await?;
3546    let busy = db.is_user_busy(user_id).await?;
3547
3548    let pool = session.connection_pool().await;
3549    let updated_contact = contact_for_user(user_id, busy, &pool);
3550    for contact in contacts {
3551        if let db::Contact::Accepted {
3552            user_id: contact_user_id,
3553            ..
3554        } = contact
3555        {
3556            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3557                session
3558                    .peer
3559                    .send(
3560                        contact_conn_id,
3561                        proto::UpdateContacts {
3562                            contacts: vec![updated_contact.clone()],
3563                            remove_contacts: Default::default(),
3564                            incoming_requests: Default::default(),
3565                            remove_incoming_requests: Default::default(),
3566                            outgoing_requests: Default::default(),
3567                            remove_outgoing_requests: Default::default(),
3568                        },
3569                    )
3570                    .trace_err();
3571            }
3572        }
3573    }
3574    Ok(())
3575}
3576
3577async fn leave_room_for_session(session: &Session) -> Result<()> {
3578    let mut contacts_to_update = HashSet::default();
3579
3580    let room_id;
3581    let canceled_calls_to_user_ids;
3582    let live_kit_room;
3583    let delete_live_kit_room;
3584    let room;
3585    let channel_members;
3586    let channel_id;
3587
3588    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3589        contacts_to_update.insert(session.user_id);
3590
3591        for project in left_room.left_projects.values() {
3592            project_left(project, session);
3593        }
3594
3595        room_id = RoomId::from_proto(left_room.room.id);
3596        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3597        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3598        delete_live_kit_room = left_room.deleted;
3599        room = mem::take(&mut left_room.room);
3600        channel_members = mem::take(&mut left_room.channel_members);
3601        channel_id = left_room.channel_id;
3602
3603        room_updated(&room, &session.peer);
3604    } else {
3605        return Ok(());
3606    }
3607
3608    if let Some(channel_id) = channel_id {
3609        channel_updated(
3610            channel_id,
3611            &room,
3612            &channel_members,
3613            &session.peer,
3614            &*session.connection_pool().await,
3615        );
3616    }
3617
3618    {
3619        let pool = session.connection_pool().await;
3620        for canceled_user_id in canceled_calls_to_user_ids {
3621            for connection_id in pool.user_connection_ids(canceled_user_id) {
3622                session
3623                    .peer
3624                    .send(
3625                        connection_id,
3626                        proto::CallCanceled {
3627                            room_id: room_id.to_proto(),
3628                        },
3629                    )
3630                    .trace_err();
3631            }
3632            contacts_to_update.insert(canceled_user_id);
3633        }
3634    }
3635
3636    for contact_user_id in contacts_to_update {
3637        update_user_contacts(contact_user_id, &session).await?;
3638    }
3639
3640    if let Some(live_kit) = session.live_kit_client.as_ref() {
3641        live_kit
3642            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3643            .await
3644            .trace_err();
3645
3646        if delete_live_kit_room {
3647            live_kit.delete_room(live_kit_room).await.trace_err();
3648        }
3649    }
3650
3651    Ok(())
3652}
3653
3654async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3655    let left_channel_buffers = session
3656        .db()
3657        .await
3658        .leave_channel_buffers(session.connection_id)
3659        .await?;
3660
3661    for left_buffer in left_channel_buffers {
3662        channel_buffer_updated(
3663            session.connection_id,
3664            left_buffer.connections,
3665            &proto::UpdateChannelBufferCollaborators {
3666                channel_id: left_buffer.channel_id.to_proto(),
3667                collaborators: left_buffer.collaborators,
3668            },
3669            &session.peer,
3670        );
3671    }
3672
3673    Ok(())
3674}
3675
3676fn project_left(project: &db::LeftProject, session: &Session) {
3677    for connection_id in &project.connection_ids {
3678        if project.host_user_id == Some(session.user_id) {
3679            session
3680                .peer
3681                .send(
3682                    *connection_id,
3683                    proto::UnshareProject {
3684                        project_id: project.id.to_proto(),
3685                    },
3686                )
3687                .trace_err();
3688        } else {
3689            session
3690                .peer
3691                .send(
3692                    *connection_id,
3693                    proto::RemoveProjectCollaborator {
3694                        project_id: project.id.to_proto(),
3695                        peer_id: Some(session.connection_id.into()),
3696                    },
3697                )
3698                .trace_err();
3699        }
3700    }
3701}
3702
3703pub trait ResultExt {
3704    type Ok;
3705
3706    fn trace_err(self) -> Option<Self::Ok>;
3707}
3708
3709impl<T, E> ResultExt for Result<T, E>
3710where
3711    E: std::fmt::Debug,
3712{
3713    type Ok = T;
3714
3715    fn trace_err(self) -> Option<T> {
3716        match self {
3717            Ok(value) => Some(value),
3718            Err(error) => {
3719                tracing::error!("{:?}", error);
3720                None
3721            }
3722        }
3723    }
3724}