rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::{ConnectionPool, ZedVersion};
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
// Process-wide Prometheus gauges. Each is registered lazily on first access; the
// `unwrap` means a failed registration panics at that point.
lazy_static! {
    /// Gauge registered under the name "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge registered under the name "shared_projects".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
/// Per-connection context handed to every message handler.
///
/// `handle_connection` builds one `Session` per connection and clones it for each
/// dispatched message; the shared pieces are behind `Arc`s, so clones refer to the
/// same database handle, peer and connection pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Async mutex created per connection in `handle_connection`; lock it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}

impl Session {
    /// Locks this session's database handle and returns the guard.
    ///
    /// In test builds the task yields both before and after acquiring the lock,
    /// presumably to shake out interleavings between concurrent handlers.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool (a blocking `parking_lot` lock).
    ///
    /// The returned `ConnectionPoolGuard` is `!Send` (via `PhantomData<Rc<()>>`), so it
    /// cannot be held across an `.await` in a future that must be `Send`. In test builds
    /// the task yields once before locking.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
 134
 135impl fmt::Debug for Session {
 136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 137        f.debug_struct("Session")
 138            .field("user_id", &self.user_id)
 139            .field("connection_id", &self.connection_id)
 140            .finish()
 141    }
 142}
 143
 144struct DbHandle(Arc<Database>);
 145
 146impl Deref for DbHandle {
 147    type Target = Database;
 148
 149    fn deref(&self) -> &Self::Target {
 150        self.0.as_ref()
 151    }
 152}
 153
/// The RPC server: owns the peer, the shared connection pool and the message handler table.
pub struct Server {
    // Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // One handler per message type, keyed by the message's `TypeId` (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` sends on this; every `handle_connection` loop subscribes and exits on change.
    teardown: watch::Sender<()>,
}
 163
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the whole guard `!Send`: it cannot be held across an
    // `.await` in a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 168
/// Serializable view of the server's peer and connection pool.
///
/// The pool lock is held (through `ConnectionPoolGuard`) for as long as the snapshot
/// lives. Fields are serialized in declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the `ConnectionPool` the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 175
 176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 177where
 178    S: Serializer,
 179    T: Deref<Target = U>,
 180    U: Serialize,
 181{
 182    Serialize::serialize(value.deref(), serializer)
 183}
 184
 185impl Server {
    /// Creates the server and registers every RPC handler it serves.
    ///
    /// The `Peer` is constructed with `id.0 as u32`. Handlers registered through
    /// `add_request_handler` must reply via their `Response`. Handlers registered through
    /// `add_message_handler` do not reply. Registering two handlers for the same message
    /// type panics (see `add_handler`), so each message type may appear only once below.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects and worktrees.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. The two forwarders are defined
            // elsewhere; presumably they differ in the permission the sender needs.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffers and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications, following and misc.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 294
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task that performs the following steps:
    ///
    /// 1. Waits `CLEANUP_TIMEOUT`.
    /// 2. Asks the database for stale room and channel-buffer ids, given this environment
    ///    and this server's id.
    /// 3. For each stale channel buffer, broadcasts the refreshed collaborator list.
    /// 4. For each stale room, it:
    ///    - broadcasts the room update, plus `channel_updated` when the room has a channel;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes updated contact info to the accepted contacts of affected users;
    ///    - deletes the LiveKit room if no participants remain.
    /// 5. Deletes stale servers. This runs even if step 2 failed.
    ///
    /// Failures inside the task are logged via `trace_err` rather than propagated; this
    /// function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from channel buffers and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults are used when clearing the room fails (nothing is sent).
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all connections
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Taken only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 446
 447    pub fn teardown(&self) {
 448        self.peer.teardown();
 449        self.connection_pool.lock().reset();
 450        let _ = self.teardown.send(());
 451    }
 452
 453    #[cfg(test)]
 454    pub fn reset(&self, id: ServerId) {
 455        self.teardown();
 456        *self.id.lock() = id;
 457        self.peer.reset(id.0 as u32);
 458    }
 459
 460    #[cfg(test)]
 461    pub fn id(&self) -> ServerId {
 462        *self.id.lock()
 463    }
 464
 465    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 466    where
 467        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 468        Fut: 'static + Send + Future<Output = Result<()>>,
 469        M: EnvelopedMessage,
 470    {
 471        let prev_handler = self.handlers.insert(
 472            TypeId::of::<M>(),
 473            Box::new(move |envelope, session| {
 474                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 475                let span = info_span!(
 476                    "handle message",
 477                    payload_type = envelope.payload_type_name()
 478                );
 479                span.in_scope(|| {
 480                    tracing::info!(
 481                        payload_type = envelope.payload_type_name(),
 482                        "message received"
 483                    );
 484                });
 485                let start_time = Instant::now();
 486                let future = (handler)(*envelope, session);
 487                async move {
 488                    let result = future.await;
 489                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 490                    match result {
 491                        Err(error) => {
 492                            tracing::error!(%error, ?duration_ms, "error handling message")
 493                        }
 494                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 495                    }
 496                }
 497                .instrument(span)
 498                .boxed()
 499            }),
 500        );
 501        if prev_handler.is_some() {
 502            panic!("registered a handler for the same message twice");
 503        }
 504        self
 505    }
 506
 507    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 510        Fut: 'static + Send + Future<Output = Result<()>>,
 511        M: EnvelopedMessage,
 512    {
 513        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 514        self
 515    }
 516
 517    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 518    where
 519        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 520        Fut: Send + Future<Output = Result<()>>,
 521        M: RequestMessage,
 522    {
 523        let handler = Arc::new(handler);
 524        self.add_handler(move |envelope, session| {
 525            let receipt = envelope.receipt();
 526            let handler = handler.clone();
 527            async move {
 528                let peer = session.peer.clone();
 529                let responded = Arc::new(AtomicBool::default());
 530                let response = Response {
 531                    peer: peer.clone(),
 532                    responded: responded.clone(),
 533                    receipt,
 534                };
 535                match (handler)(envelope.payload, response, session).await {
 536                    Ok(()) => {
 537                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 538                            Ok(())
 539                        } else {
 540                            Err(anyhow!("handler did not send a response"))?
 541                        }
 542                    }
 543                    Err(error) => {
 544                        let proto_err = match &error {
 545                            Error::Internal(err) => err.to_proto(),
 546                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 547                        };
 548                        peer.respond_with_error(receipt, proto_err)?;
 549                        Err(error)
 550                    }
 551                }
 552            }
 553        })
 554    }
 555
    /// Drives one client connection from handshake until it closes.
    ///
    /// Setup steps, in order:
    ///
    /// - register the connection with the peer and send `Hello`;
    /// - report the new connection id through `send_connection_id`, if one was given;
    /// - on the user's first-ever connection, send `ShowContacts` and record the flag;
    /// - add the connection to the pool;
    /// - send the initial contacts and channel state, plus any pending incoming call;
    /// - refresh the user's contacts.
    ///
    /// It then dispatches incoming messages to the registered handlers. Dispatch stops when
    /// the I/O task finishes, the incoming stream ends, or the server is torn down. On a
    /// teardown the future returns `Ok(())` immediately. Otherwise `connection_lost` runs,
    /// and its errors are logged rather than returned.
    ///
    /// Errors during setup are returned from the future.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed before the future starts, so a later `teardown()` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Scoped so the pool lock is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: reading the next
            // message waits for a permit, and each permit is released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown, then I/O completion, then progress on running foreground
                // handlers, and only then the next incoming message.
                futures::select_biased! {
                    // Teardown skips `connection_lost` entirely.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span guard before building the future; the span is
                                // re-attached below via `instrument`.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // `permit` is dropped at the end of this arm.
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 696
 697    pub async fn invite_code_redeemed(
 698        self: &Arc<Self>,
 699        inviter_id: UserId,
 700        invitee_id: UserId,
 701    ) -> Result<()> {
 702        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 703            if let Some(code) = &user.invite_code {
 704                let pool = self.connection_pool.lock();
 705                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 706                for connection_id in pool.user_connection_ids(inviter_id) {
 707                    self.peer.send(
 708                        connection_id,
 709                        proto::UpdateContacts {
 710                            contacts: vec![invitee_contact.clone()],
 711                            ..Default::default()
 712                        },
 713                    )?;
 714                    self.peer.send(
 715                        connection_id,
 716                        proto::UpdateInviteInfo {
 717                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 718                            count: user.invite_count as u32,
 719                        },
 720                    )?;
 721                }
 722            }
 723        }
 724        Ok(())
 725    }
 726
 727    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 728        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 729            if let Some(invite_code) = &user.invite_code {
 730                let pool = self.connection_pool.lock();
 731                for connection_id in pool.user_connection_ids(user_id) {
 732                    self.peer.send(
 733                        connection_id,
 734                        proto::UpdateInviteInfo {
 735                            url: format!(
 736                                "{}{}",
 737                                self.app_state.config.invite_link_prefix, invite_code
 738                            ),
 739                            count: user.invite_count as u32,
 740                        },
 741                    )?;
 742                }
 743            }
 744        }
 745        Ok(())
 746    }
 747
    /// Returns a [`ServerSnapshot`] holding a locked view of the connection
    /// pool together with a reference to the server's peer.
    ///
    /// The connection-pool lock is held for as long as the snapshot lives, so
    /// drop it promptly. The `_not_send` marker presumably makes the guard
    /// `!Send` so it cannot be held across an await on another thread —
    /// confirm against the `ConnectionPoolGuard` definition.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
 757}
 758
 759impl<'a> Deref for ConnectionPoolGuard<'a> {
 760    type Target = ConnectionPool;
 761
 762    fn deref(&self) -> &Self::Target {
 763        &*self.guard
 764    }
 765}
 766
 767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 768    fn deref_mut(&mut self) -> &mut Self::Target {
 769        &mut *self.guard
 770    }
 771}
 772
 773impl<'a> Drop for ConnectionPoolGuard<'a> {
 774    fn drop(&mut self) {
 775        #[cfg(test)]
 776        self.check_invariants();
 777    }
 778}
 779
 780fn broadcast<F>(
 781    sender_id: Option<ConnectionId>,
 782    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 783    mut f: F,
 784) where
 785    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 786{
 787    for receiver_id in receiver_ids {
 788        if Some(receiver_id) != sender_id {
 789            if let Err(error) = f(receiver_id) {
 790                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 791            }
 792        }
 793    }
 794}
 795
// Lazily-initialized names of the custom request headers decoded by the
// `ProtocolVersion` and `AppVersionHeader` typed headers below.
lazy_static! {
    // Header carrying the client's RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
    // Header carrying the client's Zed application (semantic) version.
    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
}
 800
 801pub struct ProtocolVersion(u32);
 802
 803impl Header for ProtocolVersion {
 804    fn name() -> &'static HeaderName {
 805        &ZED_PROTOCOL_VERSION
 806    }
 807
 808    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 809    where
 810        Self: Sized,
 811        I: Iterator<Item = &'i axum::http::HeaderValue>,
 812    {
 813        let version = values
 814            .next()
 815            .ok_or_else(axum::headers::Error::invalid)?
 816            .to_str()
 817            .map_err(|_| axum::headers::Error::invalid())?
 818            .parse()
 819            .map_err(|_| axum::headers::Error::invalid())?;
 820        Ok(Self(version))
 821    }
 822
 823    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 824        values.extend([self.0.to_string().parse().unwrap()]);
 825    }
 826}
 827
 828pub struct AppVersionHeader(SemanticVersion);
 829impl Header for AppVersionHeader {
 830    fn name() -> &'static HeaderName {
 831        &ZED_APP_VERSION
 832    }
 833
 834    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 835    where
 836        Self: Sized,
 837        I: Iterator<Item = &'i axum::http::HeaderValue>,
 838    {
 839        let version = values
 840            .next()
 841            .ok_or_else(axum::headers::Error::invalid)?
 842            .to_str()
 843            .map_err(|_| axum::headers::Error::invalid())?
 844            .parse()
 845            .map_err(|_| axum::headers::Error::invalid())?;
 846        Ok(Self(version))
 847    }
 848
 849    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 850        values.extend([self.0.to_string().parse().unwrap()]);
 851    }
 852}
 853
/// Builds the RPC router.
///
/// `/rpc` upgrades to a WebSocket via `handle_websocket_request` and is wrapped
/// by the app-state extension and the `auth::validate_header` middleware.
/// axum's `Router::layer` only wraps routes registered before the call, so
/// `/metrics` (registered afterwards) is not behind that middleware. The final
/// `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 865
 866pub async fn handle_websocket_request(
 867    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 868    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 869    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 870    Extension(server): Extension<Arc<Server>>,
 871    Extension(user): Extension<User>,
 872    Extension(impersonator): Extension<Impersonator>,
 873    ws: WebSocketUpgrade,
 874) -> axum::response::Response {
 875    if protocol_version != rpc::PROTOCOL_VERSION {
 876        return (
 877            StatusCode::UPGRADE_REQUIRED,
 878            "client must be upgraded".to_string(),
 879        )
 880            .into_response();
 881    }
 882
 883    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 884        return (
 885            StatusCode::UPGRADE_REQUIRED,
 886            "no version header found".to_string(),
 887        )
 888            .into_response();
 889    };
 890
 891    if !version.is_supported() {
 892        return (
 893            StatusCode::UPGRADE_REQUIRED,
 894            "client must be upgraded".to_string(),
 895        )
 896            .into_response();
 897    }
 898
 899    let socket_address = socket_address.to_string();
 900    ws.on_upgrade(move |socket| {
 901        use util::ResultExt;
 902        let socket = socket
 903            .map_ok(to_tungstenite_message)
 904            .err_into()
 905            .with(|message| async move { Ok(to_axum_message(message)) });
 906        let connection = Connection::new(Box::pin(socket));
 907        async move {
 908            server
 909                .handle_connection(
 910                    connection,
 911                    socket_address,
 912                    user,
 913                    version,
 914                    impersonator.0,
 915                    None,
 916                    Executor::Production,
 917                )
 918                .await
 919                .log_err();
 920        }
 921    })
 922}
 923
 924pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 925    let connections = server
 926        .connection_pool
 927        .lock()
 928        .connections()
 929        .filter(|connection| !connection.admin)
 930        .count();
 931
 932    METRIC_CONNECTIONS.set(connections as _);
 933
 934    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 935    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 936
 937    let encoder = prometheus::TextEncoder::new();
 938    let metric_families = prometheus::gather();
 939    let encoded_metrics = encoder
 940        .encode_to_string(&metric_families)
 941        .map_err(|err| anyhow!("{}", err))?;
 942    Ok(encoded_metrics)
 943}
 944
/// Cleans up after a client connection is lost.
///
/// Immediately disconnects the peer, removes the connection from the
/// connection pool (a failure there is returned as an error), and records the
/// lost connection in the database (errors there are only logged).
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses on `executor`. The session leaves its room
///   and channel buffers. If the user then has no online connection, a call
///   to them is declined via `decline_call(None, ..)` and the room is
///   re-broadcast. Finally the user's contacts are updated.
/// - `teardown` changes, in which case the remaining cleanup is skipped.
///
/// `select_biased!` polls the timeout branch first, so it wins if both are ready.
/// NOTE(review): the grace period presumably lets the client reconnect (e.g.
/// via `rejoin_room`) before its resources are released — confirm.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here must not prevent the cleanup below.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Server teardown: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 990
 991/// Acknowledges a ping from a client, used to keep the connection alive.
 992async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 993    response.send(proto::Ack {})?;
 994    Ok(())
 995}
 996
 997/// Creates a new room for calling (outside of channels)
 998async fn create_room(
 999    _request: proto::CreateRoom,
1000    response: Response<proto::CreateRoom>,
1001    session: Session,
1002) -> Result<()> {
1003    let live_kit_room = nanoid::nanoid!(30);
1004
1005    let live_kit_connection_info = {
1006        let live_kit_room = live_kit_room.clone();
1007        let live_kit = session.live_kit_client.as_ref();
1008
1009        util::async_maybe!({
1010            let live_kit = live_kit?;
1011
1012            let token = live_kit
1013                .room_token(&live_kit_room, &session.user_id.to_string())
1014                .trace_err()?;
1015
1016            Some(proto::LiveKitConnectionInfo {
1017                server_url: live_kit.url().into(),
1018                token,
1019                can_publish: true,
1020            })
1021        })
1022    }
1023    .await;
1024
1025    let room = session
1026        .db()
1027        .await
1028        .create_room(session.user_id, session.connection_id, &live_kit_room)
1029        .await?;
1030
1031    response.send(proto::CreateRoomResponse {
1032        room: Some(room.clone()),
1033        live_kit_connection_info,
1034    })?;
1035
1036    update_user_contacts(session.user_id, &session).await?;
1037    Ok(())
1038}
1039
1040/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1041async fn join_room(
1042    request: proto::JoinRoom,
1043    response: Response<proto::JoinRoom>,
1044    session: Session,
1045) -> Result<()> {
1046    let room_id = RoomId::from_proto(request.id);
1047
1048    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1049
1050    if let Some(channel_id) = channel_id {
1051        return join_channel_internal(channel_id, Box::new(response), session).await;
1052    }
1053
1054    let joined_room = {
1055        let room = session
1056            .db()
1057            .await
1058            .join_room(room_id, session.user_id, session.connection_id)
1059            .await?;
1060        room_updated(&room.room, &session.peer);
1061        room.into_inner()
1062    };
1063
1064    for connection_id in session
1065        .connection_pool()
1066        .await
1067        .user_connection_ids(session.user_id)
1068    {
1069        session
1070            .peer
1071            .send(
1072                connection_id,
1073                proto::CallCanceled {
1074                    room_id: room_id.to_proto(),
1075                },
1076            )
1077            .trace_err();
1078    }
1079
1080    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1081        if let Some(token) = live_kit
1082            .room_token(
1083                &joined_room.room.live_kit_room,
1084                &session.user_id.to_string(),
1085            )
1086            .trace_err()
1087        {
1088            Some(proto::LiveKitConnectionInfo {
1089                server_url: live_kit.url().into(),
1090                token,
1091                can_publish: true,
1092            })
1093        } else {
1094            None
1095        }
1096    } else {
1097        None
1098    };
1099
1100    response.send(proto::JoinRoomResponse {
1101        room: Some(joined_room.room),
1102        channel_id: None,
1103        live_kit_connection_info,
1104    })?;
1105
1106    update_user_contacts(session.user_id, &session).await?;
1107    Ok(())
1108}
1109
1110/// Rejoin room is used to reconnect to a room after connection errors.
1111async fn rejoin_room(
1112    request: proto::RejoinRoom,
1113    response: Response<proto::RejoinRoom>,
1114    session: Session,
1115) -> Result<()> {
1116    let room;
1117    let channel_id;
1118    let channel_members;
1119    {
1120        let mut rejoined_room = session
1121            .db()
1122            .await
1123            .rejoin_room(request, session.user_id, session.connection_id)
1124            .await?;
1125
1126        response.send(proto::RejoinRoomResponse {
1127            room: Some(rejoined_room.room.clone()),
1128            reshared_projects: rejoined_room
1129                .reshared_projects
1130                .iter()
1131                .map(|project| proto::ResharedProject {
1132                    id: project.id.to_proto(),
1133                    collaborators: project
1134                        .collaborators
1135                        .iter()
1136                        .map(|collaborator| collaborator.to_proto())
1137                        .collect(),
1138                })
1139                .collect(),
1140            rejoined_projects: rejoined_room
1141                .rejoined_projects
1142                .iter()
1143                .map(|rejoined_project| proto::RejoinedProject {
1144                    id: rejoined_project.id.to_proto(),
1145                    worktrees: rejoined_project
1146                        .worktrees
1147                        .iter()
1148                        .map(|worktree| proto::WorktreeMetadata {
1149                            id: worktree.id,
1150                            root_name: worktree.root_name.clone(),
1151                            visible: worktree.visible,
1152                            abs_path: worktree.abs_path.clone(),
1153                        })
1154                        .collect(),
1155                    collaborators: rejoined_project
1156                        .collaborators
1157                        .iter()
1158                        .map(|collaborator| collaborator.to_proto())
1159                        .collect(),
1160                    language_servers: rejoined_project.language_servers.clone(),
1161                })
1162                .collect(),
1163        })?;
1164        room_updated(&rejoined_room.room, &session.peer);
1165
1166        for project in &rejoined_room.reshared_projects {
1167            for collaborator in &project.collaborators {
1168                session
1169                    .peer
1170                    .send(
1171                        collaborator.connection_id,
1172                        proto::UpdateProjectCollaborator {
1173                            project_id: project.id.to_proto(),
1174                            old_peer_id: Some(project.old_connection_id.into()),
1175                            new_peer_id: Some(session.connection_id.into()),
1176                        },
1177                    )
1178                    .trace_err();
1179            }
1180
1181            broadcast(
1182                Some(session.connection_id),
1183                project
1184                    .collaborators
1185                    .iter()
1186                    .map(|collaborator| collaborator.connection_id),
1187                |connection_id| {
1188                    session.peer.forward_send(
1189                        session.connection_id,
1190                        connection_id,
1191                        proto::UpdateProject {
1192                            project_id: project.id.to_proto(),
1193                            worktrees: project.worktrees.clone(),
1194                        },
1195                    )
1196                },
1197            );
1198        }
1199
1200        for project in &rejoined_room.rejoined_projects {
1201            for collaborator in &project.collaborators {
1202                session
1203                    .peer
1204                    .send(
1205                        collaborator.connection_id,
1206                        proto::UpdateProjectCollaborator {
1207                            project_id: project.id.to_proto(),
1208                            old_peer_id: Some(project.old_connection_id.into()),
1209                            new_peer_id: Some(session.connection_id.into()),
1210                        },
1211                    )
1212                    .trace_err();
1213            }
1214        }
1215
1216        for project in &mut rejoined_room.rejoined_projects {
1217            for worktree in mem::take(&mut project.worktrees) {
1218                #[cfg(any(test, feature = "test-support"))]
1219                const MAX_CHUNK_SIZE: usize = 2;
1220                #[cfg(not(any(test, feature = "test-support")))]
1221                const MAX_CHUNK_SIZE: usize = 256;
1222
1223                // Stream this worktree's entries.
1224                let message = proto::UpdateWorktree {
1225                    project_id: project.id.to_proto(),
1226                    worktree_id: worktree.id,
1227                    abs_path: worktree.abs_path.clone(),
1228                    root_name: worktree.root_name,
1229                    updated_entries: worktree.updated_entries,
1230                    removed_entries: worktree.removed_entries,
1231                    scan_id: worktree.scan_id,
1232                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1233                    updated_repositories: worktree.updated_repositories,
1234                    removed_repositories: worktree.removed_repositories,
1235                };
1236                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1237                    session.peer.send(session.connection_id, update.clone())?;
1238                }
1239
1240                // Stream this worktree's diagnostics.
1241                for summary in worktree.diagnostic_summaries {
1242                    session.peer.send(
1243                        session.connection_id,
1244                        proto::UpdateDiagnosticSummary {
1245                            project_id: project.id.to_proto(),
1246                            worktree_id: worktree.id,
1247                            summary: Some(summary),
1248                        },
1249                    )?;
1250                }
1251
1252                for settings_file in worktree.settings_files {
1253                    session.peer.send(
1254                        session.connection_id,
1255                        proto::UpdateWorktreeSettings {
1256                            project_id: project.id.to_proto(),
1257                            worktree_id: worktree.id,
1258                            path: settings_file.path,
1259                            content: Some(settings_file.content),
1260                        },
1261                    )?;
1262                }
1263            }
1264
1265            for language_server in &project.language_servers {
1266                session.peer.send(
1267                    session.connection_id,
1268                    proto::UpdateLanguageServer {
1269                        project_id: project.id.to_proto(),
1270                        language_server_id: language_server.id,
1271                        variant: Some(
1272                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1273                                proto::LspDiskBasedDiagnosticsUpdated {},
1274                            ),
1275                        ),
1276                    },
1277                )?;
1278            }
1279        }
1280
1281        let rejoined_room = rejoined_room.into_inner();
1282
1283        room = rejoined_room.room;
1284        channel_id = rejoined_room.channel_id;
1285        channel_members = rejoined_room.channel_members;
1286    }
1287
1288    if let Some(channel_id) = channel_id {
1289        channel_updated(
1290            channel_id,
1291            &room,
1292            &channel_members,
1293            &session.peer,
1294            &*session.connection_pool().await,
1295        );
1296    }
1297
1298    update_user_contacts(session.user_id, &session).await?;
1299    Ok(())
1300}
1301
1302/// leave room disconnects from the room.
1303async fn leave_room(
1304    _: proto::LeaveRoom,
1305    response: Response<proto::LeaveRoom>,
1306    session: Session,
1307) -> Result<()> {
1308    leave_room_for_session(&session).await?;
1309    response.send(proto::Ack {})?;
1310    Ok(())
1311}
1312
1313/// Updates the permissions of someone else in the room.
1314async fn set_room_participant_role(
1315    request: proto::SetRoomParticipantRole,
1316    response: Response<proto::SetRoomParticipantRole>,
1317    session: Session,
1318) -> Result<()> {
1319    let user_id = UserId::from_proto(request.user_id);
1320    let role = ChannelRole::from(request.role());
1321
1322    if role == ChannelRole::Talker {
1323        let pool = session.connection_pool().await;
1324
1325        for connection in pool.user_connections(user_id) {
1326            if !connection.zed_version.supports_talker_role() {
1327                Err(anyhow!(
1328                    "This user is on zed {} which does not support unmute",
1329                    connection.zed_version
1330                ))?;
1331            }
1332        }
1333    }
1334
1335    let (live_kit_room, can_publish) = {
1336        let room = session
1337            .db()
1338            .await
1339            .set_room_participant_role(
1340                session.user_id,
1341                RoomId::from_proto(request.room_id),
1342                user_id,
1343                role,
1344            )
1345            .await?;
1346
1347        let live_kit_room = room.live_kit_room.clone();
1348        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1349        room_updated(&room, &session.peer);
1350        (live_kit_room, can_publish)
1351    };
1352
1353    if let Some(live_kit) = session.live_kit_client.as_ref() {
1354        live_kit
1355            .update_participant(
1356                live_kit_room.clone(),
1357                request.user_id.to_string(),
1358                live_kit_server::proto::ParticipantPermission {
1359                    can_subscribe: true,
1360                    can_publish,
1361                    can_publish_data: can_publish,
1362                    hidden: false,
1363                    recorder: false,
1364                },
1365            )
1366            .await
1367            .trace_err();
1368    }
1369
1370    response.send(proto::Ack {})?;
1371    Ok(())
1372}
1373
1374/// Call someone else into the current room
1375async fn call(
1376    request: proto::Call,
1377    response: Response<proto::Call>,
1378    session: Session,
1379) -> Result<()> {
1380    let room_id = RoomId::from_proto(request.room_id);
1381    let calling_user_id = session.user_id;
1382    let calling_connection_id = session.connection_id;
1383    let called_user_id = UserId::from_proto(request.called_user_id);
1384    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1385    if !session
1386        .db()
1387        .await
1388        .has_contact(calling_user_id, called_user_id)
1389        .await?
1390    {
1391        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1392    }
1393
1394    let incoming_call = {
1395        let (room, incoming_call) = &mut *session
1396            .db()
1397            .await
1398            .call(
1399                room_id,
1400                calling_user_id,
1401                calling_connection_id,
1402                called_user_id,
1403                initial_project_id,
1404            )
1405            .await?;
1406        room_updated(&room, &session.peer);
1407        mem::take(incoming_call)
1408    };
1409    update_user_contacts(called_user_id, &session).await?;
1410
1411    let mut calls = session
1412        .connection_pool()
1413        .await
1414        .user_connection_ids(called_user_id)
1415        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1416        .collect::<FuturesUnordered<_>>();
1417
1418    while let Some(call_response) = calls.next().await {
1419        match call_response.as_ref() {
1420            Ok(_) => {
1421                response.send(proto::Ack {})?;
1422                return Ok(());
1423            }
1424            Err(_) => {
1425                call_response.trace_err();
1426            }
1427        }
1428    }
1429
1430    {
1431        let room = session
1432            .db()
1433            .await
1434            .call_failed(room_id, called_user_id)
1435            .await?;
1436        room_updated(&room, &session.peer);
1437    }
1438    update_user_contacts(called_user_id, &session).await?;
1439
1440    Err(anyhow!("failed to ring user"))?
1441}
1442
1443/// Cancel an outgoing call.
1444async fn cancel_call(
1445    request: proto::CancelCall,
1446    response: Response<proto::CancelCall>,
1447    session: Session,
1448) -> Result<()> {
1449    let called_user_id = UserId::from_proto(request.called_user_id);
1450    let room_id = RoomId::from_proto(request.room_id);
1451    {
1452        let room = session
1453            .db()
1454            .await
1455            .cancel_call(room_id, session.connection_id, called_user_id)
1456            .await?;
1457        room_updated(&room, &session.peer);
1458    }
1459
1460    for connection_id in session
1461        .connection_pool()
1462        .await
1463        .user_connection_ids(called_user_id)
1464    {
1465        session
1466            .peer
1467            .send(
1468                connection_id,
1469                proto::CallCanceled {
1470                    room_id: room_id.to_proto(),
1471                },
1472            )
1473            .trace_err();
1474    }
1475    response.send(proto::Ack {})?;
1476
1477    update_user_contacts(called_user_id, &session).await?;
1478    Ok(())
1479}
1480
1481/// Decline an incoming call.
1482async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1483    let room_id = RoomId::from_proto(message.room_id);
1484    {
1485        let room = session
1486            .db()
1487            .await
1488            .decline_call(Some(room_id), session.user_id)
1489            .await?
1490            .ok_or_else(|| anyhow!("failed to decline call"))?;
1491        room_updated(&room, &session.peer);
1492    }
1493
1494    for connection_id in session
1495        .connection_pool()
1496        .await
1497        .user_connection_ids(session.user_id)
1498    {
1499        session
1500            .peer
1501            .send(
1502                connection_id,
1503                proto::CallCanceled {
1504                    room_id: room_id.to_proto(),
1505                },
1506            )
1507            .trace_err();
1508    }
1509    update_user_contacts(session.user_id, &session).await?;
1510    Ok(())
1511}
1512
1513/// Updates other participants in the room with your current location.
1514async fn update_participant_location(
1515    request: proto::UpdateParticipantLocation,
1516    response: Response<proto::UpdateParticipantLocation>,
1517    session: Session,
1518) -> Result<()> {
1519    let room_id = RoomId::from_proto(request.room_id);
1520    let location = request
1521        .location
1522        .ok_or_else(|| anyhow!("invalid location"))?;
1523
1524    let db = session.db().await;
1525    let room = db
1526        .update_room_participant_location(room_id, session.connection_id, location)
1527        .await?;
1528
1529    room_updated(&room, &session.peer);
1530    response.send(proto::Ack {})?;
1531    Ok(())
1532}
1533
1534/// Share a project into the room.
1535async fn share_project(
1536    request: proto::ShareProject,
1537    response: Response<proto::ShareProject>,
1538    session: Session,
1539) -> Result<()> {
1540    let (project_id, room) = &*session
1541        .db()
1542        .await
1543        .share_project(
1544            RoomId::from_proto(request.room_id),
1545            session.connection_id,
1546            &request.worktrees,
1547        )
1548        .await?;
1549    response.send(proto::ShareProjectResponse {
1550        project_id: project_id.to_proto(),
1551    })?;
1552    room_updated(&room, &session.peer);
1553
1554    Ok(())
1555}
1556
1557/// Unshare a project from the room.
1558async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1559    let project_id = ProjectId::from_proto(message.project_id);
1560
1561    let (room, guest_connection_ids) = &*session
1562        .db()
1563        .await
1564        .unshare_project(project_id, session.connection_id)
1565        .await?;
1566
1567    broadcast(
1568        Some(session.connection_id),
1569        guest_connection_ids.iter().copied(),
1570        |conn_id| session.peer.send(conn_id, message.clone()),
1571    );
1572    room_updated(&room, &session.peer);
1573
1574    Ok(())
1575}
1576
1577/// Join someone elses shared project.
1578async fn join_project(
1579    request: proto::JoinProject,
1580    response: Response<proto::JoinProject>,
1581    session: Session,
1582) -> Result<()> {
1583    let project_id = ProjectId::from_proto(request.project_id);
1584    let guest_user_id = session.user_id;
1585
1586    tracing::info!(%project_id, "join project");
1587
1588    let (project, replica_id) = &mut *session
1589        .db()
1590        .await
1591        .join_project(project_id, session.connection_id)
1592        .await?;
1593
1594    let collaborators = project
1595        .collaborators
1596        .iter()
1597        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1598        .map(|collaborator| collaborator.to_proto())
1599        .collect::<Vec<_>>();
1600
1601    let worktrees = project
1602        .worktrees
1603        .iter()
1604        .map(|(id, worktree)| proto::WorktreeMetadata {
1605            id: *id,
1606            root_name: worktree.root_name.clone(),
1607            visible: worktree.visible,
1608            abs_path: worktree.abs_path.clone(),
1609        })
1610        .collect::<Vec<_>>();
1611
1612    for collaborator in &collaborators {
1613        session
1614            .peer
1615            .send(
1616                collaborator.peer_id.unwrap().into(),
1617                proto::AddProjectCollaborator {
1618                    project_id: project_id.to_proto(),
1619                    collaborator: Some(proto::Collaborator {
1620                        peer_id: Some(session.connection_id.into()),
1621                        replica_id: replica_id.0 as u32,
1622                        user_id: guest_user_id.to_proto(),
1623                    }),
1624                },
1625            )
1626            .trace_err();
1627    }
1628
1629    // First, we send the metadata associated with each worktree.
1630    response.send(proto::JoinProjectResponse {
1631        worktrees: worktrees.clone(),
1632        replica_id: replica_id.0 as u32,
1633        collaborators: collaborators.clone(),
1634        language_servers: project.language_servers.clone(),
1635    })?;
1636
1637    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1638        #[cfg(any(test, feature = "test-support"))]
1639        const MAX_CHUNK_SIZE: usize = 2;
1640        #[cfg(not(any(test, feature = "test-support")))]
1641        const MAX_CHUNK_SIZE: usize = 256;
1642
1643        // Stream this worktree's entries.
1644        let message = proto::UpdateWorktree {
1645            project_id: project_id.to_proto(),
1646            worktree_id,
1647            abs_path: worktree.abs_path.clone(),
1648            root_name: worktree.root_name,
1649            updated_entries: worktree.entries,
1650            removed_entries: Default::default(),
1651            scan_id: worktree.scan_id,
1652            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1653            updated_repositories: worktree.repository_entries.into_values().collect(),
1654            removed_repositories: Default::default(),
1655        };
1656        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1657            session.peer.send(session.connection_id, update.clone())?;
1658        }
1659
1660        // Stream this worktree's diagnostics.
1661        for summary in worktree.diagnostic_summaries {
1662            session.peer.send(
1663                session.connection_id,
1664                proto::UpdateDiagnosticSummary {
1665                    project_id: project_id.to_proto(),
1666                    worktree_id: worktree.id,
1667                    summary: Some(summary),
1668                },
1669            )?;
1670        }
1671
1672        for settings_file in worktree.settings_files {
1673            session.peer.send(
1674                session.connection_id,
1675                proto::UpdateWorktreeSettings {
1676                    project_id: project_id.to_proto(),
1677                    worktree_id: worktree.id,
1678                    path: settings_file.path,
1679                    content: Some(settings_file.content),
1680                },
1681            )?;
1682        }
1683    }
1684
1685    for language_server in &project.language_servers {
1686        session.peer.send(
1687            session.connection_id,
1688            proto::UpdateLanguageServer {
1689                project_id: project_id.to_proto(),
1690                language_server_id: language_server.id,
1691                variant: Some(
1692                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1693                        proto::LspDiskBasedDiagnosticsUpdated {},
1694                    ),
1695                ),
1696            },
1697        )?;
1698    }
1699
1700    Ok(())
1701}
1702
1703/// Leave someone elses shared project.
1704async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1705    let sender_id = session.connection_id;
1706    let project_id = ProjectId::from_proto(request.project_id);
1707
1708    let (room, project) = &*session
1709        .db()
1710        .await
1711        .leave_project(project_id, sender_id)
1712        .await?;
1713    tracing::info!(
1714        %project_id,
1715        host_user_id = %project.host_user_id,
1716        host_connection_id = ?project.host_connection_id,
1717        "leave project"
1718    );
1719
1720    project_left(&project, &session);
1721    room_updated(&room, &session.peer);
1722
1723    Ok(())
1724}
1725
1726/// Updates other participants with changes to the project
1727async fn update_project(
1728    request: proto::UpdateProject,
1729    response: Response<proto::UpdateProject>,
1730    session: Session,
1731) -> Result<()> {
1732    let project_id = ProjectId::from_proto(request.project_id);
1733    let (room, guest_connection_ids) = &*session
1734        .db()
1735        .await
1736        .update_project(project_id, session.connection_id, &request.worktrees)
1737        .await?;
1738    broadcast(
1739        Some(session.connection_id),
1740        guest_connection_ids.iter().copied(),
1741        |connection_id| {
1742            session
1743                .peer
1744                .forward_send(session.connection_id, connection_id, request.clone())
1745        },
1746    );
1747    room_updated(&room, &session.peer);
1748    response.send(proto::Ack {})?;
1749
1750    Ok(())
1751}
1752
1753/// Updates other participants with changes to the worktree
1754async fn update_worktree(
1755    request: proto::UpdateWorktree,
1756    response: Response<proto::UpdateWorktree>,
1757    session: Session,
1758) -> Result<()> {
1759    let guest_connection_ids = session
1760        .db()
1761        .await
1762        .update_worktree(&request, session.connection_id)
1763        .await?;
1764
1765    broadcast(
1766        Some(session.connection_id),
1767        guest_connection_ids.iter().copied(),
1768        |connection_id| {
1769            session
1770                .peer
1771                .forward_send(session.connection_id, connection_id, request.clone())
1772        },
1773    );
1774    response.send(proto::Ack {})?;
1775    Ok(())
1776}
1777
1778/// Updates other participants with changes to the diagnostics
1779async fn update_diagnostic_summary(
1780    message: proto::UpdateDiagnosticSummary,
1781    session: Session,
1782) -> Result<()> {
1783    let guest_connection_ids = session
1784        .db()
1785        .await
1786        .update_diagnostic_summary(&message, session.connection_id)
1787        .await?;
1788
1789    broadcast(
1790        Some(session.connection_id),
1791        guest_connection_ids.iter().copied(),
1792        |connection_id| {
1793            session
1794                .peer
1795                .forward_send(session.connection_id, connection_id, message.clone())
1796        },
1797    );
1798
1799    Ok(())
1800}
1801
1802/// Updates other participants with changes to the worktree settings
1803async fn update_worktree_settings(
1804    message: proto::UpdateWorktreeSettings,
1805    session: Session,
1806) -> Result<()> {
1807    let guest_connection_ids = session
1808        .db()
1809        .await
1810        .update_worktree_settings(&message, session.connection_id)
1811        .await?;
1812
1813    broadcast(
1814        Some(session.connection_id),
1815        guest_connection_ids.iter().copied(),
1816        |connection_id| {
1817            session
1818                .peer
1819                .forward_send(session.connection_id, connection_id, message.clone())
1820        },
1821    );
1822
1823    Ok(())
1824}
1825
1826/// Notify other participants that a  language server has started.
1827async fn start_language_server(
1828    request: proto::StartLanguageServer,
1829    session: Session,
1830) -> Result<()> {
1831    let guest_connection_ids = session
1832        .db()
1833        .await
1834        .start_language_server(&request, session.connection_id)
1835        .await?;
1836
1837    broadcast(
1838        Some(session.connection_id),
1839        guest_connection_ids.iter().copied(),
1840        |connection_id| {
1841            session
1842                .peer
1843                .forward_send(session.connection_id, connection_id, request.clone())
1844        },
1845    );
1846    Ok(())
1847}
1848
1849/// Notify other participants that a language server has changed.
1850async fn update_language_server(
1851    request: proto::UpdateLanguageServer,
1852    session: Session,
1853) -> Result<()> {
1854    let project_id = ProjectId::from_proto(request.project_id);
1855    let project_connection_ids = session
1856        .db()
1857        .await
1858        .project_connection_ids(project_id, session.connection_id)
1859        .await?;
1860    broadcast(
1861        Some(session.connection_id),
1862        project_connection_ids.iter().copied(),
1863        |connection_id| {
1864            session
1865                .peer
1866                .forward_send(session.connection_id, connection_id, request.clone())
1867        },
1868    );
1869    Ok(())
1870}
1871
1872/// forward a project request to the host. These requests should be read only
1873/// as guests are allowed to send them.
1874async fn forward_read_only_project_request<T>(
1875    request: T,
1876    response: Response<T>,
1877    session: Session,
1878) -> Result<()>
1879where
1880    T: EntityMessage + RequestMessage,
1881{
1882    let project_id = ProjectId::from_proto(request.remote_entity_id());
1883    let host_connection_id = session
1884        .db()
1885        .await
1886        .host_for_read_only_project_request(project_id, session.connection_id)
1887        .await?;
1888    let payload = session
1889        .peer
1890        .forward_request(session.connection_id, host_connection_id, request)
1891        .await?;
1892    response.send(payload)?;
1893    Ok(())
1894}
1895
1896/// forward a project request to the host. These requests are disallowed
1897/// for guests.
1898async fn forward_mutating_project_request<T>(
1899    request: T,
1900    response: Response<T>,
1901    session: Session,
1902) -> Result<()>
1903where
1904    T: EntityMessage + RequestMessage,
1905{
1906    let project_id = ProjectId::from_proto(request.remote_entity_id());
1907    let host_connection_id = session
1908        .db()
1909        .await
1910        .host_for_mutating_project_request(project_id, session.connection_id)
1911        .await?;
1912    let payload = session
1913        .peer
1914        .forward_request(session.connection_id, host_connection_id, request)
1915        .await?;
1916    response.send(payload)?;
1917    Ok(())
1918}
1919
1920/// Notify other participants that a new buffer has been created
1921async fn create_buffer_for_peer(
1922    request: proto::CreateBufferForPeer,
1923    session: Session,
1924) -> Result<()> {
1925    session
1926        .db()
1927        .await
1928        .check_user_is_project_host(
1929            ProjectId::from_proto(request.project_id),
1930            session.connection_id,
1931        )
1932        .await?;
1933    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1934    session
1935        .peer
1936        .forward_send(session.connection_id, peer_id.into(), request)?;
1937    Ok(())
1938}
1939
1940/// Notify other participants that a buffer has been updated. This is
1941/// allowed for guests as long as the update is limited to selections.
1942async fn update_buffer(
1943    request: proto::UpdateBuffer,
1944    response: Response<proto::UpdateBuffer>,
1945    session: Session,
1946) -> Result<()> {
1947    let project_id = ProjectId::from_proto(request.project_id);
1948    let mut guest_connection_ids;
1949    let mut host_connection_id = None;
1950
1951    let mut requires_write_permission = false;
1952
1953    for op in request.operations.iter() {
1954        match op.variant {
1955            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1956            Some(_) => requires_write_permission = true,
1957        }
1958    }
1959
1960    {
1961        let collaborators = session
1962            .db()
1963            .await
1964            .project_collaborators_for_buffer_update(
1965                project_id,
1966                session.connection_id,
1967                requires_write_permission,
1968            )
1969            .await?;
1970        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1971        for collaborator in collaborators.iter() {
1972            if collaborator.is_host {
1973                host_connection_id = Some(collaborator.connection_id);
1974            } else {
1975                guest_connection_ids.push(collaborator.connection_id);
1976            }
1977        }
1978    }
1979    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1980
1981    broadcast(
1982        Some(session.connection_id),
1983        guest_connection_ids,
1984        |connection_id| {
1985            session
1986                .peer
1987                .forward_send(session.connection_id, connection_id, request.clone())
1988        },
1989    );
1990    if host_connection_id != session.connection_id {
1991        session
1992            .peer
1993            .forward_request(session.connection_id, host_connection_id, request.clone())
1994            .await?;
1995    }
1996
1997    response.send(proto::Ack {})?;
1998    Ok(())
1999}
2000
2001/// Notify other participants that a project has been updated.
2002async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2003    request: T,
2004    session: Session,
2005) -> Result<()> {
2006    let project_id = ProjectId::from_proto(request.remote_entity_id());
2007    let project_connection_ids = session
2008        .db()
2009        .await
2010        .project_connection_ids(project_id, session.connection_id)
2011        .await?;
2012
2013    broadcast(
2014        Some(session.connection_id),
2015        project_connection_ids.iter().copied(),
2016        |connection_id| {
2017            session
2018                .peer
2019                .forward_send(session.connection_id, connection_id, request.clone())
2020        },
2021    );
2022    Ok(())
2023}
2024
2025/// Start following another user in a call.
2026async fn follow(
2027    request: proto::Follow,
2028    response: Response<proto::Follow>,
2029    session: Session,
2030) -> Result<()> {
2031    let room_id = RoomId::from_proto(request.room_id);
2032    let project_id = request.project_id.map(ProjectId::from_proto);
2033    let leader_id = request
2034        .leader_id
2035        .ok_or_else(|| anyhow!("invalid leader id"))?
2036        .into();
2037    let follower_id = session.connection_id;
2038
2039    session
2040        .db()
2041        .await
2042        .check_room_participants(room_id, leader_id, session.connection_id)
2043        .await?;
2044
2045    let response_payload = session
2046        .peer
2047        .forward_request(session.connection_id, leader_id, request)
2048        .await?;
2049    response.send(response_payload)?;
2050
2051    if let Some(project_id) = project_id {
2052        let room = session
2053            .db()
2054            .await
2055            .follow(room_id, project_id, leader_id, follower_id)
2056            .await?;
2057        room_updated(&room, &session.peer);
2058    }
2059
2060    Ok(())
2061}
2062
2063/// Stop following another user in a call.
2064async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2065    let room_id = RoomId::from_proto(request.room_id);
2066    let project_id = request.project_id.map(ProjectId::from_proto);
2067    let leader_id = request
2068        .leader_id
2069        .ok_or_else(|| anyhow!("invalid leader id"))?
2070        .into();
2071    let follower_id = session.connection_id;
2072
2073    session
2074        .db()
2075        .await
2076        .check_room_participants(room_id, leader_id, session.connection_id)
2077        .await?;
2078
2079    session
2080        .peer
2081        .forward_send(session.connection_id, leader_id, request)?;
2082
2083    if let Some(project_id) = project_id {
2084        let room = session
2085            .db()
2086            .await
2087            .unfollow(room_id, project_id, leader_id, follower_id)
2088            .await?;
2089        room_updated(&room, &session.peer);
2090    }
2091
2092    Ok(())
2093}
2094
2095/// Notify everyone following you of your current location.
2096async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2097    let room_id = RoomId::from_proto(request.room_id);
2098    let database = session.db.lock().await;
2099
2100    let connection_ids = if let Some(project_id) = request.project_id {
2101        let project_id = ProjectId::from_proto(project_id);
2102        database
2103            .project_connection_ids(project_id, session.connection_id)
2104            .await?
2105    } else {
2106        database
2107            .room_connection_ids(room_id, session.connection_id)
2108            .await?
2109    };
2110
2111    // For now, don't send view update messages back to that view's current leader.
2112    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2113        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2114        _ => None,
2115    });
2116
2117    for connection_id in connection_ids.iter().cloned() {
2118        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2119            session
2120                .peer
2121                .forward_send(session.connection_id, connection_id, request.clone())?;
2122        }
2123    }
2124    Ok(())
2125}
2126
2127/// Get public data about users.
2128async fn get_users(
2129    request: proto::GetUsers,
2130    response: Response<proto::GetUsers>,
2131    session: Session,
2132) -> Result<()> {
2133    let user_ids = request
2134        .user_ids
2135        .into_iter()
2136        .map(UserId::from_proto)
2137        .collect();
2138    let users = session
2139        .db()
2140        .await
2141        .get_users_by_ids(user_ids)
2142        .await?
2143        .into_iter()
2144        .map(|user| proto::User {
2145            id: user.id.to_proto(),
2146            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2147            github_login: user.github_login,
2148        })
2149        .collect();
2150    response.send(proto::UsersResponse { users })?;
2151    Ok(())
2152}
2153
2154/// Search for users (to invite) buy Github login
2155async fn fuzzy_search_users(
2156    request: proto::FuzzySearchUsers,
2157    response: Response<proto::FuzzySearchUsers>,
2158    session: Session,
2159) -> Result<()> {
2160    let query = request.query;
2161    let users = match query.len() {
2162        0 => vec![],
2163        1 | 2 => session
2164            .db()
2165            .await
2166            .get_user_by_github_login(&query)
2167            .await?
2168            .into_iter()
2169            .collect(),
2170        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2171    };
2172    let users = users
2173        .into_iter()
2174        .filter(|user| user.id != session.user_id)
2175        .map(|user| proto::User {
2176            id: user.id.to_proto(),
2177            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2178            github_login: user.github_login,
2179        })
2180        .collect();
2181    response.send(proto::UsersResponse { users })?;
2182    Ok(())
2183}
2184
2185/// Send a contact request to another user.
2186async fn request_contact(
2187    request: proto::RequestContact,
2188    response: Response<proto::RequestContact>,
2189    session: Session,
2190) -> Result<()> {
2191    let requester_id = session.user_id;
2192    let responder_id = UserId::from_proto(request.responder_id);
2193    if requester_id == responder_id {
2194        return Err(anyhow!("cannot add yourself as a contact"))?;
2195    }
2196
2197    let notifications = session
2198        .db()
2199        .await
2200        .send_contact_request(requester_id, responder_id)
2201        .await?;
2202
2203    // Update outgoing contact requests of requester
2204    let mut update = proto::UpdateContacts::default();
2205    update.outgoing_requests.push(responder_id.to_proto());
2206    for connection_id in session
2207        .connection_pool()
2208        .await
2209        .user_connection_ids(requester_id)
2210    {
2211        session.peer.send(connection_id, update.clone())?;
2212    }
2213
2214    // Update incoming contact requests of responder
2215    let mut update = proto::UpdateContacts::default();
2216    update
2217        .incoming_requests
2218        .push(proto::IncomingContactRequest {
2219            requester_id: requester_id.to_proto(),
2220        });
2221    let connection_pool = session.connection_pool().await;
2222    for connection_id in connection_pool.user_connection_ids(responder_id) {
2223        session.peer.send(connection_id, update.clone())?;
2224    }
2225
2226    send_notifications(&*connection_pool, &session.peer, notifications);
2227
2228    response.send(proto::Ack {})?;
2229    Ok(())
2230}
2231
2232/// Accept or decline a contact request
2233async fn respond_to_contact_request(
2234    request: proto::RespondToContactRequest,
2235    response: Response<proto::RespondToContactRequest>,
2236    session: Session,
2237) -> Result<()> {
2238    let responder_id = session.user_id;
2239    let requester_id = UserId::from_proto(request.requester_id);
2240    let db = session.db().await;
2241    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2242        db.dismiss_contact_notification(responder_id, requester_id)
2243            .await?;
2244    } else {
2245        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2246
2247        let notifications = db
2248            .respond_to_contact_request(responder_id, requester_id, accept)
2249            .await?;
2250        let requester_busy = db.is_user_busy(requester_id).await?;
2251        let responder_busy = db.is_user_busy(responder_id).await?;
2252
2253        let pool = session.connection_pool().await;
2254        // Update responder with new contact
2255        let mut update = proto::UpdateContacts::default();
2256        if accept {
2257            update
2258                .contacts
2259                .push(contact_for_user(requester_id, requester_busy, &pool));
2260        }
2261        update
2262            .remove_incoming_requests
2263            .push(requester_id.to_proto());
2264        for connection_id in pool.user_connection_ids(responder_id) {
2265            session.peer.send(connection_id, update.clone())?;
2266        }
2267
2268        // Update requester with new contact
2269        let mut update = proto::UpdateContacts::default();
2270        if accept {
2271            update
2272                .contacts
2273                .push(contact_for_user(responder_id, responder_busy, &pool));
2274        }
2275        update
2276            .remove_outgoing_requests
2277            .push(responder_id.to_proto());
2278
2279        for connection_id in pool.user_connection_ids(requester_id) {
2280            session.peer.send(connection_id, update.clone())?;
2281        }
2282
2283        send_notifications(&*pool, &session.peer, notifications);
2284    }
2285
2286    response.send(proto::Ack {})?;
2287    Ok(())
2288}
2289
2290/// Remove a contact.
2291async fn remove_contact(
2292    request: proto::RemoveContact,
2293    response: Response<proto::RemoveContact>,
2294    session: Session,
2295) -> Result<()> {
2296    let requester_id = session.user_id;
2297    let responder_id = UserId::from_proto(request.user_id);
2298    let db = session.db().await;
2299    let (contact_accepted, deleted_notification_id) =
2300        db.remove_contact(requester_id, responder_id).await?;
2301
2302    let pool = session.connection_pool().await;
2303    // Update outgoing contact requests of requester
2304    let mut update = proto::UpdateContacts::default();
2305    if contact_accepted {
2306        update.remove_contacts.push(responder_id.to_proto());
2307    } else {
2308        update
2309            .remove_outgoing_requests
2310            .push(responder_id.to_proto());
2311    }
2312    for connection_id in pool.user_connection_ids(requester_id) {
2313        session.peer.send(connection_id, update.clone())?;
2314    }
2315
2316    // Update incoming contact requests of responder
2317    let mut update = proto::UpdateContacts::default();
2318    if contact_accepted {
2319        update.remove_contacts.push(requester_id.to_proto());
2320    } else {
2321        update
2322            .remove_incoming_requests
2323            .push(requester_id.to_proto());
2324    }
2325    for connection_id in pool.user_connection_ids(responder_id) {
2326        session.peer.send(connection_id, update.clone())?;
2327        if let Some(notification_id) = deleted_notification_id {
2328            session.peer.send(
2329                connection_id,
2330                proto::DeleteNotification {
2331                    notification_id: notification_id.to_proto(),
2332                },
2333            )?;
2334        }
2335    }
2336
2337    response.send(proto::Ack {})?;
2338    Ok(())
2339}
2340
2341/// Creates a new channel.
2342async fn create_channel(
2343    request: proto::CreateChannel,
2344    response: Response<proto::CreateChannel>,
2345    session: Session,
2346) -> Result<()> {
2347    let db = session.db().await;
2348
2349    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2350    let (channel, owner, channel_members) = db
2351        .create_channel(&request.name, parent_id, session.user_id)
2352        .await?;
2353
2354    response.send(proto::CreateChannelResponse {
2355        channel: Some(channel.to_proto()),
2356        parent_id: request.parent_id,
2357    })?;
2358
2359    let connection_pool = session.connection_pool().await;
2360    if let Some(owner) = owner {
2361        let update = proto::UpdateUserChannels {
2362            channel_memberships: vec![proto::ChannelMembership {
2363                channel_id: owner.channel_id.to_proto(),
2364                role: owner.role.into(),
2365            }],
2366            ..Default::default()
2367        };
2368        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2369            session.peer.send(connection_id, update.clone())?;
2370        }
2371    }
2372
2373    for channel_member in channel_members {
2374        if !channel_member.role.can_see_channel(channel.visibility) {
2375            continue;
2376        }
2377
2378        let update = proto::UpdateChannels {
2379            channels: vec![channel.to_proto()],
2380            ..Default::default()
2381        };
2382        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2383            session.peer.send(connection_id, update.clone())?;
2384        }
2385    }
2386
2387    Ok(())
2388}
2389
2390/// Delete a channel
2391async fn delete_channel(
2392    request: proto::DeleteChannel,
2393    response: Response<proto::DeleteChannel>,
2394    session: Session,
2395) -> Result<()> {
2396    let db = session.db().await;
2397
2398    let channel_id = request.channel_id;
2399    let (removed_channels, member_ids) = db
2400        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2401        .await?;
2402    response.send(proto::Ack {})?;
2403
2404    // Notify members of removed channels
2405    let mut update = proto::UpdateChannels::default();
2406    update
2407        .delete_channels
2408        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2409
2410    let connection_pool = session.connection_pool().await;
2411    for member_id in member_ids {
2412        for connection_id in connection_pool.user_connection_ids(member_id) {
2413            session.peer.send(connection_id, update.clone())?;
2414        }
2415    }
2416
2417    Ok(())
2418}
2419
2420/// Invite someone to join a channel.
2421async fn invite_channel_member(
2422    request: proto::InviteChannelMember,
2423    response: Response<proto::InviteChannelMember>,
2424    session: Session,
2425) -> Result<()> {
2426    let db = session.db().await;
2427    let channel_id = ChannelId::from_proto(request.channel_id);
2428    let invitee_id = UserId::from_proto(request.user_id);
2429    let InviteMemberResult {
2430        channel,
2431        notifications,
2432    } = db
2433        .invite_channel_member(
2434            channel_id,
2435            invitee_id,
2436            session.user_id,
2437            request.role().into(),
2438        )
2439        .await?;
2440
2441    let update = proto::UpdateChannels {
2442        channel_invitations: vec![channel.to_proto()],
2443        ..Default::default()
2444    };
2445
2446    let connection_pool = session.connection_pool().await;
2447    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2448        session.peer.send(connection_id, update.clone())?;
2449    }
2450
2451    send_notifications(&*connection_pool, &session.peer, notifications);
2452
2453    response.send(proto::Ack {})?;
2454    Ok(())
2455}
2456
2457/// remove someone from a channel
2458async fn remove_channel_member(
2459    request: proto::RemoveChannelMember,
2460    response: Response<proto::RemoveChannelMember>,
2461    session: Session,
2462) -> Result<()> {
2463    let db = session.db().await;
2464    let channel_id = ChannelId::from_proto(request.channel_id);
2465    let member_id = UserId::from_proto(request.user_id);
2466
2467    let RemoveChannelMemberResult {
2468        membership_update,
2469        notification_id,
2470    } = db
2471        .remove_channel_member(channel_id, member_id, session.user_id)
2472        .await?;
2473
2474    let connection_pool = &session.connection_pool().await;
2475    notify_membership_updated(
2476        &connection_pool,
2477        membership_update,
2478        member_id,
2479        &session.peer,
2480    );
2481    for connection_id in connection_pool.user_connection_ids(member_id) {
2482        if let Some(notification_id) = notification_id {
2483            session
2484                .peer
2485                .send(
2486                    connection_id,
2487                    proto::DeleteNotification {
2488                        notification_id: notification_id.to_proto(),
2489                    },
2490                )
2491                .trace_err();
2492        }
2493    }
2494
2495    response.send(proto::Ack {})?;
2496    Ok(())
2497}
2498
2499/// Toggle the channel between public and private.
2500/// Care is taken to maintain the invariant that public channels only descend from public channels,
2501/// (though members-only channels can appear at any point in the hierarchy).
2502async fn set_channel_visibility(
2503    request: proto::SetChannelVisibility,
2504    response: Response<proto::SetChannelVisibility>,
2505    session: Session,
2506) -> Result<()> {
2507    let db = session.db().await;
2508    let channel_id = ChannelId::from_proto(request.channel_id);
2509    let visibility = request.visibility().into();
2510
2511    let (channel, channel_members) = db
2512        .set_channel_visibility(channel_id, visibility, session.user_id)
2513        .await?;
2514
2515    let connection_pool = session.connection_pool().await;
2516    for member in channel_members {
2517        let update = if member.role.can_see_channel(channel.visibility) {
2518            proto::UpdateChannels {
2519                channels: vec![channel.to_proto()],
2520                ..Default::default()
2521            }
2522        } else {
2523            proto::UpdateChannels {
2524                delete_channels: vec![channel.id.to_proto()],
2525                ..Default::default()
2526            }
2527        };
2528
2529        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2530            session.peer.send(connection_id, update.clone())?;
2531        }
2532    }
2533
2534    response.send(proto::Ack {})?;
2535    Ok(())
2536}
2537
2538/// Alter the role for a user in the channel.
2539async fn set_channel_member_role(
2540    request: proto::SetChannelMemberRole,
2541    response: Response<proto::SetChannelMemberRole>,
2542    session: Session,
2543) -> Result<()> {
2544    let db = session.db().await;
2545    let channel_id = ChannelId::from_proto(request.channel_id);
2546    let member_id = UserId::from_proto(request.user_id);
2547    let result = db
2548        .set_channel_member_role(
2549            channel_id,
2550            session.user_id,
2551            member_id,
2552            request.role().into(),
2553        )
2554        .await?;
2555
2556    match result {
2557        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2558            let connection_pool = session.connection_pool().await;
2559            notify_membership_updated(
2560                &connection_pool,
2561                membership_update,
2562                member_id,
2563                &session.peer,
2564            )
2565        }
2566        db::SetMemberRoleResult::InviteUpdated(channel) => {
2567            let update = proto::UpdateChannels {
2568                channel_invitations: vec![channel.to_proto()],
2569                ..Default::default()
2570            };
2571
2572            for connection_id in session
2573                .connection_pool()
2574                .await
2575                .user_connection_ids(member_id)
2576            {
2577                session.peer.send(connection_id, update.clone())?;
2578            }
2579        }
2580    }
2581
2582    response.send(proto::Ack {})?;
2583    Ok(())
2584}
2585
2586/// Change the name of a channel
2587async fn rename_channel(
2588    request: proto::RenameChannel,
2589    response: Response<proto::RenameChannel>,
2590    session: Session,
2591) -> Result<()> {
2592    let db = session.db().await;
2593    let channel_id = ChannelId::from_proto(request.channel_id);
2594    let (channel, channel_members) = db
2595        .rename_channel(channel_id, session.user_id, &request.name)
2596        .await?;
2597
2598    response.send(proto::RenameChannelResponse {
2599        channel: Some(channel.to_proto()),
2600    })?;
2601
2602    let connection_pool = session.connection_pool().await;
2603    for channel_member in channel_members {
2604        if !channel_member.role.can_see_channel(channel.visibility) {
2605            continue;
2606        }
2607        let update = proto::UpdateChannels {
2608            channels: vec![channel.to_proto()],
2609            ..Default::default()
2610        };
2611        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2612            session.peer.send(connection_id, update.clone())?;
2613        }
2614    }
2615
2616    Ok(())
2617}
2618
2619/// Move a channel to a new parent.
2620async fn move_channel(
2621    request: proto::MoveChannel,
2622    response: Response<proto::MoveChannel>,
2623    session: Session,
2624) -> Result<()> {
2625    let channel_id = ChannelId::from_proto(request.channel_id);
2626    let to = ChannelId::from_proto(request.to);
2627
2628    let (channels, channel_members) = session
2629        .db()
2630        .await
2631        .move_channel(channel_id, to, session.user_id)
2632        .await?;
2633
2634    let connection_pool = session.connection_pool().await;
2635    for member in channel_members {
2636        let channels = channels
2637            .iter()
2638            .filter_map(|channel| {
2639                if member.role.can_see_channel(channel.visibility) {
2640                    Some(channel.to_proto())
2641                } else {
2642                    None
2643                }
2644            })
2645            .collect::<Vec<_>>();
2646        if channels.is_empty() {
2647            continue;
2648        }
2649
2650        let update = proto::UpdateChannels {
2651            channels,
2652            ..Default::default()
2653        };
2654
2655        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2656            session.peer.send(connection_id, update.clone())?;
2657        }
2658    }
2659
2660    response.send(Ack {})?;
2661    Ok(())
2662}
2663
2664/// Get the list of channel members
2665async fn get_channel_members(
2666    request: proto::GetChannelMembers,
2667    response: Response<proto::GetChannelMembers>,
2668    session: Session,
2669) -> Result<()> {
2670    let db = session.db().await;
2671    let channel_id = ChannelId::from_proto(request.channel_id);
2672    let members = db
2673        .get_channel_participant_details(channel_id, session.user_id)
2674        .await?;
2675    response.send(proto::GetChannelMembersResponse { members })?;
2676    Ok(())
2677}
2678
2679/// Accept or decline a channel invitation.
2680async fn respond_to_channel_invite(
2681    request: proto::RespondToChannelInvite,
2682    response: Response<proto::RespondToChannelInvite>,
2683    session: Session,
2684) -> Result<()> {
2685    let db = session.db().await;
2686    let channel_id = ChannelId::from_proto(request.channel_id);
2687    let RespondToChannelInvite {
2688        membership_update,
2689        notifications,
2690    } = db
2691        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2692        .await?;
2693
2694    let connection_pool = session.connection_pool().await;
2695    if let Some(membership_update) = membership_update {
2696        notify_membership_updated(
2697            &connection_pool,
2698            membership_update,
2699            session.user_id,
2700            &session.peer,
2701        );
2702    } else {
2703        let update = proto::UpdateChannels {
2704            remove_channel_invitations: vec![channel_id.to_proto()],
2705            ..Default::default()
2706        };
2707
2708        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2709            session.peer.send(connection_id, update.clone())?;
2710        }
2711    };
2712
2713    send_notifications(&*connection_pool, &session.peer, notifications);
2714
2715    response.send(proto::Ack {})?;
2716
2717    Ok(())
2718}
2719
2720/// Join the channels' room
2721async fn join_channel(
2722    request: proto::JoinChannel,
2723    response: Response<proto::JoinChannel>,
2724    session: Session,
2725) -> Result<()> {
2726    let channel_id = ChannelId::from_proto(request.channel_id);
2727    join_channel_internal(channel_id, Box::new(response), session).await
2728}
2729
/// A response handle that can reply with a [`proto::JoinRoomResponse`].
///
/// Implemented for both the `JoinChannel` and the `JoinRoom` response types, so that
/// `join_channel_internal` can reply through either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`; the type-qualified path names it explicitly.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2743
/// Moves the session into the room attached to `channel_id` and replies through `response`.
///
/// Steps, in order:
/// 1. call `leave_room_for_session` for the current session;
/// 2. join the channel's room in the database, which yields the joined room, an optional
///    membership update and the user's role in the channel;
/// 3. if a LiveKit client is configured, mint credentials. `ChannelRole::Guest` gets a guest
///    token with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. If token creation fails, the error goes through `trace_err` and no
///    connection info is included;
/// 4. send the `JoinRoomResponse`, then push the membership update (if any) and the room
///    state to the room's participants;
/// 5. broadcast the channel's participant list to its members and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this block scope drops the `db` and `connection_pool` guards before the pool
    // is acquired again for `channel_updated` below — presumably to avoid re-locking; confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or when minting the token failed (the `?` on
        // `trace_err()` short-circuits this closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2819
2820/// Start editing the channel notes
2821async fn join_channel_buffer(
2822    request: proto::JoinChannelBuffer,
2823    response: Response<proto::JoinChannelBuffer>,
2824    session: Session,
2825) -> Result<()> {
2826    let db = session.db().await;
2827    let channel_id = ChannelId::from_proto(request.channel_id);
2828
2829    let open_response = db
2830        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2831        .await?;
2832
2833    let collaborators = open_response.collaborators.clone();
2834    response.send(open_response)?;
2835
2836    let update = UpdateChannelBufferCollaborators {
2837        channel_id: channel_id.to_proto(),
2838        collaborators: collaborators.clone(),
2839    };
2840    channel_buffer_updated(
2841        session.connection_id,
2842        collaborators
2843            .iter()
2844            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2845        &update,
2846        &session.peer,
2847    );
2848
2849    Ok(())
2850}
2851
2852/// Edit the channel notes
2853async fn update_channel_buffer(
2854    request: proto::UpdateChannelBuffer,
2855    session: Session,
2856) -> Result<()> {
2857    let db = session.db().await;
2858    let channel_id = ChannelId::from_proto(request.channel_id);
2859
2860    let (collaborators, non_collaborators, epoch, version) = db
2861        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2862        .await?;
2863
2864    channel_buffer_updated(
2865        session.connection_id,
2866        collaborators,
2867        &proto::UpdateChannelBuffer {
2868            channel_id: channel_id.to_proto(),
2869            operations: request.operations,
2870        },
2871        &session.peer,
2872    );
2873
2874    let pool = &*session.connection_pool().await;
2875
2876    broadcast(
2877        None,
2878        non_collaborators
2879            .iter()
2880            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2881        |peer_id| {
2882            session.peer.send(
2883                peer_id.into(),
2884                proto::UpdateChannels {
2885                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2886                        channel_id: channel_id.to_proto(),
2887                        epoch: epoch as u64,
2888                        version: version.clone(),
2889                    }],
2890                    ..Default::default()
2891                },
2892            )
2893        },
2894    );
2895
2896    Ok(())
2897}
2898
2899/// Rejoin the channel notes after a connection blip
2900async fn rejoin_channel_buffers(
2901    request: proto::RejoinChannelBuffers,
2902    response: Response<proto::RejoinChannelBuffers>,
2903    session: Session,
2904) -> Result<()> {
2905    let db = session.db().await;
2906    let buffers = db
2907        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2908        .await?;
2909
2910    for rejoined_buffer in &buffers {
2911        let collaborators_to_notify = rejoined_buffer
2912            .buffer
2913            .collaborators
2914            .iter()
2915            .filter_map(|c| Some(c.peer_id?.into()));
2916        channel_buffer_updated(
2917            session.connection_id,
2918            collaborators_to_notify,
2919            &proto::UpdateChannelBufferCollaborators {
2920                channel_id: rejoined_buffer.buffer.channel_id,
2921                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2922            },
2923            &session.peer,
2924        );
2925    }
2926
2927    response.send(proto::RejoinChannelBuffersResponse {
2928        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2929    })?;
2930
2931    Ok(())
2932}
2933
2934/// Stop editing the channel notes
2935async fn leave_channel_buffer(
2936    request: proto::LeaveChannelBuffer,
2937    response: Response<proto::LeaveChannelBuffer>,
2938    session: Session,
2939) -> Result<()> {
2940    let db = session.db().await;
2941    let channel_id = ChannelId::from_proto(request.channel_id);
2942
2943    let left_buffer = db
2944        .leave_channel_buffer(channel_id, session.connection_id)
2945        .await?;
2946
2947    response.send(Ack {})?;
2948
2949    channel_buffer_updated(
2950        session.connection_id,
2951        left_buffer.connections,
2952        &proto::UpdateChannelBufferCollaborators {
2953            channel_id: channel_id.to_proto(),
2954            collaborators: left_buffer.collaborators,
2955        },
2956        &session.peer,
2957    );
2958
2959    Ok(())
2960}
2961
2962fn channel_buffer_updated<T: EnvelopedMessage>(
2963    sender_id: ConnectionId,
2964    collaborators: impl IntoIterator<Item = ConnectionId>,
2965    message: &T,
2966    peer: &Peer,
2967) {
2968    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2969        peer.send(peer_id.into(), message.clone())
2970    });
2971}
2972
2973fn send_notifications(
2974    connection_pool: &ConnectionPool,
2975    peer: &Peer,
2976    notifications: db::NotificationBatch,
2977) {
2978    for (user_id, notification) in notifications {
2979        for connection_id in connection_pool.user_connection_ids(user_id) {
2980            if let Err(error) = peer.send(
2981                connection_id,
2982                proto::AddNotification {
2983                    notification: Some(notification.clone()),
2984                },
2985            ) {
2986                tracing::error!(
2987                    "failed to send notification to {:?} {}",
2988                    connection_id,
2989                    error
2990                );
2991            }
2992        }
2993    }
2994}
2995
2996/// Send a message to the channel
2997async fn send_channel_message(
2998    request: proto::SendChannelMessage,
2999    response: Response<proto::SendChannelMessage>,
3000    session: Session,
3001) -> Result<()> {
3002    // Validate the message body.
3003    let body = request.body.trim().to_string();
3004    if body.len() > MAX_MESSAGE_LEN {
3005        return Err(anyhow!("message is too long"))?;
3006    }
3007    if body.is_empty() {
3008        return Err(anyhow!("message can't be blank"))?;
3009    }
3010
3011    // TODO: adjust mentions if body is trimmed
3012
3013    let timestamp = OffsetDateTime::now_utc();
3014    let nonce = request
3015        .nonce
3016        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3017
3018    let channel_id = ChannelId::from_proto(request.channel_id);
3019    let CreatedChannelMessage {
3020        message_id,
3021        participant_connection_ids,
3022        channel_members,
3023        notifications,
3024    } = session
3025        .db()
3026        .await
3027        .create_channel_message(
3028            channel_id,
3029            session.user_id,
3030            &body,
3031            &request.mentions,
3032            timestamp,
3033            nonce.clone().into(),
3034            match request.reply_to_message_id {
3035                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3036                None => None,
3037            },
3038        )
3039        .await?;
3040    let message = proto::ChannelMessage {
3041        sender_id: session.user_id.to_proto(),
3042        id: message_id.to_proto(),
3043        body,
3044        mentions: request.mentions,
3045        timestamp: timestamp.unix_timestamp() as u64,
3046        nonce: Some(nonce),
3047        reply_to_message_id: request.reply_to_message_id,
3048    };
3049    broadcast(
3050        Some(session.connection_id),
3051        participant_connection_ids,
3052        |connection| {
3053            session.peer.send(
3054                connection,
3055                proto::ChannelMessageSent {
3056                    channel_id: channel_id.to_proto(),
3057                    message: Some(message.clone()),
3058                },
3059            )
3060        },
3061    );
3062    response.send(proto::SendChannelMessageResponse {
3063        message: Some(message),
3064    })?;
3065
3066    let pool = &*session.connection_pool().await;
3067    broadcast(
3068        None,
3069        channel_members
3070            .iter()
3071            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3072        |peer_id| {
3073            session.peer.send(
3074                peer_id.into(),
3075                proto::UpdateChannels {
3076                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3077                        channel_id: channel_id.to_proto(),
3078                        message_id: message_id.to_proto(),
3079                    }],
3080                    ..Default::default()
3081                },
3082            )
3083        },
3084    );
3085    send_notifications(pool, &session.peer, notifications);
3086
3087    Ok(())
3088}
3089
3090/// Delete a channel message
3091async fn remove_channel_message(
3092    request: proto::RemoveChannelMessage,
3093    response: Response<proto::RemoveChannelMessage>,
3094    session: Session,
3095) -> Result<()> {
3096    let channel_id = ChannelId::from_proto(request.channel_id);
3097    let message_id = MessageId::from_proto(request.message_id);
3098    let connection_ids = session
3099        .db()
3100        .await
3101        .remove_channel_message(channel_id, message_id, session.user_id)
3102        .await?;
3103    broadcast(Some(session.connection_id), connection_ids, |connection| {
3104        session.peer.send(connection, request.clone())
3105    });
3106    response.send(proto::Ack {})?;
3107    Ok(())
3108}
3109
3110/// Mark a channel message as read
3111async fn acknowledge_channel_message(
3112    request: proto::AckChannelMessage,
3113    session: Session,
3114) -> Result<()> {
3115    let channel_id = ChannelId::from_proto(request.channel_id);
3116    let message_id = MessageId::from_proto(request.message_id);
3117    let notifications = session
3118        .db()
3119        .await
3120        .observe_channel_message(channel_id, session.user_id, message_id)
3121        .await?;
3122    send_notifications(
3123        &*session.connection_pool().await,
3124        &session.peer,
3125        notifications,
3126    );
3127    Ok(())
3128}
3129
3130/// Mark a buffer version as synced
3131async fn acknowledge_buffer_version(
3132    request: proto::AckBufferOperation,
3133    session: Session,
3134) -> Result<()> {
3135    let buffer_id = BufferId::from_proto(request.buffer_id);
3136    session
3137        .db()
3138        .await
3139        .observe_buffer_version(
3140            buffer_id,
3141            session.user_id,
3142            request.epoch as i32,
3143            &request.version,
3144        )
3145        .await?;
3146    Ok(())
3147}
3148
3149/// Start receiving chat updates for a channel
3150async fn join_channel_chat(
3151    request: proto::JoinChannelChat,
3152    response: Response<proto::JoinChannelChat>,
3153    session: Session,
3154) -> Result<()> {
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156
3157    let db = session.db().await;
3158    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3159        .await?;
3160    let messages = db
3161        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3162        .await?;
3163    response.send(proto::JoinChannelChatResponse {
3164        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3165        messages,
3166    })?;
3167    Ok(())
3168}
3169
3170/// Stop receiving chat updates for a channel
3171async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3172    let channel_id = ChannelId::from_proto(request.channel_id);
3173    session
3174        .db()
3175        .await
3176        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3177        .await?;
3178    Ok(())
3179}
3180
3181/// Retrieve the chat history for a channel
3182async fn get_channel_messages(
3183    request: proto::GetChannelMessages,
3184    response: Response<proto::GetChannelMessages>,
3185    session: Session,
3186) -> Result<()> {
3187    let channel_id = ChannelId::from_proto(request.channel_id);
3188    let messages = session
3189        .db()
3190        .await
3191        .get_channel_messages(
3192            channel_id,
3193            session.user_id,
3194            MESSAGE_COUNT_PER_PAGE,
3195            Some(MessageId::from_proto(request.before_message_id)),
3196        )
3197        .await?;
3198    response.send(proto::GetChannelMessagesResponse {
3199        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3200        messages,
3201    })?;
3202    Ok(())
3203}
3204
3205/// Retrieve specific chat messages
3206async fn get_channel_messages_by_id(
3207    request: proto::GetChannelMessagesById,
3208    response: Response<proto::GetChannelMessagesById>,
3209    session: Session,
3210) -> Result<()> {
3211    let message_ids = request
3212        .message_ids
3213        .iter()
3214        .map(|id| MessageId::from_proto(*id))
3215        .collect::<Vec<_>>();
3216    let messages = session
3217        .db()
3218        .await
3219        .get_channel_messages_by_id(session.user_id, &message_ids)
3220        .await?;
3221    response.send(proto::GetChannelMessagesResponse {
3222        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3223        messages,
3224    })?;
3225    Ok(())
3226}
3227
3228/// Retrieve the current users notifications
3229async fn get_notifications(
3230    request: proto::GetNotifications,
3231    response: Response<proto::GetNotifications>,
3232    session: Session,
3233) -> Result<()> {
3234    let notifications = session
3235        .db()
3236        .await
3237        .get_notifications(
3238            session.user_id,
3239            NOTIFICATION_COUNT_PER_PAGE,
3240            request
3241                .before_id
3242                .map(|id| db::NotificationId::from_proto(id)),
3243        )
3244        .await?;
3245    response.send(proto::GetNotificationsResponse {
3246        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3247        notifications,
3248    })?;
3249    Ok(())
3250}
3251
3252/// Mark notifications as read
3253async fn mark_notification_as_read(
3254    request: proto::MarkNotificationRead,
3255    response: Response<proto::MarkNotificationRead>,
3256    session: Session,
3257) -> Result<()> {
3258    let database = &session.db().await;
3259    let notifications = database
3260        .mark_notification_as_read_by_id(
3261            session.user_id,
3262            NotificationId::from_proto(request.notification_id),
3263        )
3264        .await?;
3265    send_notifications(
3266        &*session.connection_pool().await,
3267        &session.peer,
3268        notifications,
3269    );
3270    response.send(proto::Ack {})?;
3271    Ok(())
3272}
3273
3274/// Get the current users information
3275async fn get_private_user_info(
3276    _request: proto::GetPrivateUserInfo,
3277    response: Response<proto::GetPrivateUserInfo>,
3278    session: Session,
3279) -> Result<()> {
3280    let db = session.db().await;
3281
3282    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3283    let user = db
3284        .get_user_by_id(session.user_id)
3285        .await?
3286        .ok_or_else(|| anyhow!("user not found"))?;
3287    let flags = db.get_user_flags(session.user_id).await?;
3288
3289    response.send(proto::GetPrivateUserInfoResponse {
3290        metrics_id,
3291        staff: user.admin,
3292        flags,
3293    })?;
3294    Ok(())
3295}
3296
3297fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3298    match message {
3299        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3300        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3301        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3302        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3303        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3304            code: frame.code.into(),
3305            reason: frame.reason,
3306        })),
3307    }
3308}
3309
3310fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3311    match message {
3312        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3313        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3314        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3315        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3316        AxumMessage::Close(frame) => {
3317            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3318                code: frame.code.into(),
3319                reason: frame.reason,
3320            }))
3321        }
3322    }
3323}
3324
3325fn notify_membership_updated(
3326    connection_pool: &ConnectionPool,
3327    result: MembershipUpdated,
3328    user_id: UserId,
3329    peer: &Peer,
3330) {
3331    let user_channels_update = proto::UpdateUserChannels {
3332        channel_memberships: result
3333            .new_channels
3334            .channel_memberships
3335            .iter()
3336            .map(|cm| proto::ChannelMembership {
3337                channel_id: cm.channel_id.to_proto(),
3338                role: cm.role.into(),
3339            })
3340            .collect(),
3341        ..Default::default()
3342    };
3343    let mut update = build_channels_update(result.new_channels, vec![]);
3344    update.delete_channels = result
3345        .removed_channels
3346        .into_iter()
3347        .map(|id| id.to_proto())
3348        .collect();
3349    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3350
3351    for connection_id in connection_pool.user_connection_ids(user_id) {
3352        peer.send(connection_id, user_channels_update.clone())
3353            .trace_err();
3354        peer.send(connection_id, update.clone()).trace_err();
3355    }
3356}
3357
3358fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3359    proto::UpdateUserChannels {
3360        channel_memberships: channels
3361            .channel_memberships
3362            .iter()
3363            .map(|m| proto::ChannelMembership {
3364                channel_id: m.channel_id.to_proto(),
3365                role: m.role.into(),
3366            })
3367            .collect(),
3368        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3369        observed_channel_message_id: channels.observed_channel_messages.clone(),
3370        ..Default::default()
3371    }
3372}
3373
3374fn build_channels_update(
3375    channels: ChannelsForUser,
3376    channel_invites: Vec<db::Channel>,
3377) -> proto::UpdateChannels {
3378    let mut update = proto::UpdateChannels::default();
3379
3380    for channel in channels.channels {
3381        update.channels.push(channel.to_proto());
3382    }
3383
3384    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3385    update.latest_channel_message_ids = channels.latest_channel_messages;
3386
3387    for (channel_id, participants) in channels.channel_participants {
3388        update
3389            .channel_participants
3390            .push(proto::ChannelParticipants {
3391                channel_id: channel_id.to_proto(),
3392                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3393            });
3394    }
3395
3396    for channel in channel_invites {
3397        update.channel_invitations.push(channel.to_proto());
3398    }
3399
3400    update
3401}
3402
3403fn build_initial_contacts_update(
3404    contacts: Vec<db::Contact>,
3405    pool: &ConnectionPool,
3406) -> proto::UpdateContacts {
3407    let mut update = proto::UpdateContacts::default();
3408
3409    for contact in contacts {
3410        match contact {
3411            db::Contact::Accepted { user_id, busy } => {
3412                update.contacts.push(contact_for_user(user_id, busy, &pool));
3413            }
3414            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3415            db::Contact::Incoming { user_id } => {
3416                update
3417                    .incoming_requests
3418                    .push(proto::IncomingContactRequest {
3419                        requester_id: user_id.to_proto(),
3420                    })
3421            }
3422        }
3423    }
3424
3425    update
3426}
3427
3428fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3429    proto::Contact {
3430        user_id: user_id.to_proto(),
3431        online: pool.is_user_online(user_id),
3432        busy,
3433    }
3434}
3435
3436fn room_updated(room: &proto::Room, peer: &Peer) {
3437    broadcast(
3438        None,
3439        room.participants
3440            .iter()
3441            .filter_map(|participant| Some(participant.peer_id?.into())),
3442        |peer_id| {
3443            peer.send(
3444                peer_id.into(),
3445                proto::RoomUpdated {
3446                    room: Some(room.clone()),
3447                },
3448            )
3449        },
3450    );
3451}
3452
3453fn channel_updated(
3454    channel_id: ChannelId,
3455    room: &proto::Room,
3456    channel_members: &[UserId],
3457    peer: &Peer,
3458    pool: &ConnectionPool,
3459) {
3460    let participants = room
3461        .participants
3462        .iter()
3463        .map(|p| p.user_id)
3464        .collect::<Vec<_>>();
3465
3466    broadcast(
3467        None,
3468        channel_members
3469            .iter()
3470            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3471        |peer_id| {
3472            peer.send(
3473                peer_id.into(),
3474                proto::UpdateChannels {
3475                    channel_participants: vec![proto::ChannelParticipants {
3476                        channel_id: channel_id.to_proto(),
3477                        participant_user_ids: participants.clone(),
3478                    }],
3479                    ..Default::default()
3480                },
3481            )
3482        },
3483    );
3484}
3485
3486async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3487    let db = session.db().await;
3488
3489    let contacts = db.get_contacts(user_id).await?;
3490    let busy = db.is_user_busy(user_id).await?;
3491
3492    let pool = session.connection_pool().await;
3493    let updated_contact = contact_for_user(user_id, busy, &pool);
3494    for contact in contacts {
3495        if let db::Contact::Accepted {
3496            user_id: contact_user_id,
3497            ..
3498        } = contact
3499        {
3500            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3501                session
3502                    .peer
3503                    .send(
3504                        contact_conn_id,
3505                        proto::UpdateContacts {
3506                            contacts: vec![updated_contact.clone()],
3507                            remove_contacts: Default::default(),
3508                            incoming_requests: Default::default(),
3509                            remove_incoming_requests: Default::default(),
3510                            outgoing_requests: Default::default(),
3511                            remove_outgoing_requests: Default::default(),
3512                        },
3513                    )
3514                    .trace_err();
3515            }
3516        }
3517    }
3518    Ok(())
3519}
3520
3521async fn leave_room_for_session(session: &Session) -> Result<()> {
3522    let mut contacts_to_update = HashSet::default();
3523
3524    let room_id;
3525    let canceled_calls_to_user_ids;
3526    let live_kit_room;
3527    let delete_live_kit_room;
3528    let room;
3529    let channel_members;
3530    let channel_id;
3531
3532    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3533        contacts_to_update.insert(session.user_id);
3534
3535        for project in left_room.left_projects.values() {
3536            project_left(project, session);
3537        }
3538
3539        room_id = RoomId::from_proto(left_room.room.id);
3540        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3541        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3542        delete_live_kit_room = left_room.deleted;
3543        room = mem::take(&mut left_room.room);
3544        channel_members = mem::take(&mut left_room.channel_members);
3545        channel_id = left_room.channel_id;
3546
3547        room_updated(&room, &session.peer);
3548    } else {
3549        return Ok(());
3550    }
3551
3552    if let Some(channel_id) = channel_id {
3553        channel_updated(
3554            channel_id,
3555            &room,
3556            &channel_members,
3557            &session.peer,
3558            &*session.connection_pool().await,
3559        );
3560    }
3561
3562    {
3563        let pool = session.connection_pool().await;
3564        for canceled_user_id in canceled_calls_to_user_ids {
3565            for connection_id in pool.user_connection_ids(canceled_user_id) {
3566                session
3567                    .peer
3568                    .send(
3569                        connection_id,
3570                        proto::CallCanceled {
3571                            room_id: room_id.to_proto(),
3572                        },
3573                    )
3574                    .trace_err();
3575            }
3576            contacts_to_update.insert(canceled_user_id);
3577        }
3578    }
3579
3580    for contact_user_id in contacts_to_update {
3581        update_user_contacts(contact_user_id, &session).await?;
3582    }
3583
3584    if let Some(live_kit) = session.live_kit_client.as_ref() {
3585        live_kit
3586            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3587            .await
3588            .trace_err();
3589
3590        if delete_live_kit_room {
3591            live_kit.delete_room(live_kit_room).await.trace_err();
3592        }
3593    }
3594
3595    Ok(())
3596}
3597
3598async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3599    let left_channel_buffers = session
3600        .db()
3601        .await
3602        .leave_channel_buffers(session.connection_id)
3603        .await?;
3604
3605    for left_buffer in left_channel_buffers {
3606        channel_buffer_updated(
3607            session.connection_id,
3608            left_buffer.connections,
3609            &proto::UpdateChannelBufferCollaborators {
3610                channel_id: left_buffer.channel_id.to_proto(),
3611                collaborators: left_buffer.collaborators,
3612            },
3613            &session.peer,
3614        );
3615    }
3616
3617    Ok(())
3618}
3619
3620fn project_left(project: &db::LeftProject, session: &Session) {
3621    for connection_id in &project.connection_ids {
3622        if project.host_user_id == session.user_id {
3623            session
3624                .peer
3625                .send(
3626                    *connection_id,
3627                    proto::UnshareProject {
3628                        project_id: project.id.to_proto(),
3629                    },
3630                )
3631                .trace_err();
3632        } else {
3633            session
3634                .peer
3635                .send(
3636                    *connection_id,
3637                    proto::RemoveProjectCollaborator {
3638                        project_id: project.id.to_proto(),
3639                        peer_id: Some(session.connection_id.into()),
3640                    },
3641                )
3642                .trace_err();
3643        }
3644    }
3645}
3646
3647pub trait ResultExt {
3648    type Ok;
3649
3650    fn trace_err(self) -> Option<Self::Ok>;
3651}
3652
3653impl<T, E> ResultExt for Result<T, E>
3654where
3655    E: std::fmt::Debug,
3656{
3657    type Ok = T;
3658
3659    fn trace_err(self) -> Option<T> {
3660        match self {
3661            Ok(value) => Some(value),
3662            Err(error) => {
3663                tracing::error!("{:?}", error);
3664                None
3665            }
3666        }
3667    }
3668}