rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::{ConnectionPool, ZedVersion};
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  42        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc, OnceLock,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{field, info_span, instrument, Instrument};
  66use util::SemanticVersion;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
  75type MessageHandler =
  76    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  77
  78struct Response<R> {
  79    peer: Arc<Peer>,
  80    receipt: Receipt<R>,
  81    responded: Arc<AtomicBool>,
  82}
  83
  84impl<R: RequestMessage> Response<R> {
  85    fn send(self, payload: R::Response) -> Result<()> {
  86        self.responded.store(true, SeqCst);
  87        self.peer.respond(self.receipt, payload)?;
  88        Ok(())
  89    }
  90}
  91
  92#[derive(Clone)]
  93struct Session {
  94    user_id: UserId,
  95    connection_id: ConnectionId,
  96    db: Arc<tokio::sync::Mutex<DbHandle>>,
  97    peer: Arc<Peer>,
  98    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  99    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 100    _executor: Executor,
 101}
 102
 103impl Session {
 104    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        let guard = self.db.lock().await;
 108        #[cfg(test)]
 109        tokio::task::yield_now().await;
 110        guard
 111    }
 112
 113    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.connection_pool.lock();
 117        ConnectionPoolGuard {
 118            guard,
 119            _not_send: PhantomData,
 120        }
 121    }
 122}
 123
 124impl fmt::Debug for Session {
 125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 126        f.debug_struct("Session")
 127            .field("user_id", &self.user_id)
 128            .field("connection_id", &self.connection_id)
 129            .finish()
 130    }
 131}
 132
 133struct DbHandle(Arc<Database>);
 134
 135impl Deref for DbHandle {
 136    type Target = Database;
 137
 138    fn deref(&self) -> &Self::Target {
 139        self.0.as_ref()
 140    }
 141}
 142
/// The collaboration RPC server.
///
/// It owns the [`Peer`] that carries client connections, the pool of live connections, and
/// the table of registered message handlers.
pub struct Server {
    /// Identifier of this server instance; replaced by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Transport shared with every [`Session`].
    peer: Arc<Peer>,
    /// Currently connected clients.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` on teardown; each `handle_connection` future subscribes to it.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the [`ConnectionPool`].
///
/// The guard is deliberately `!Send` via the `Rc` marker, so it cannot be held across an
/// `.await` inside a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of server state: the peer plus the locked connection pool.
///
/// With `#[derive(Serialize)]`, field declaration order is the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 164
 165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 166where
 167    S: Serializer,
 168    T: Deref<Target = U>,
 169    U: Serialize,
 170{
 171    Serialize::serialize(value.deref(), serializer)
 172}
 173
 174impl Server {
 175    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 176        let mut server = Self {
 177            id: parking_lot::Mutex::new(id),
 178            peer: Peer::new(id.0 as u32),
 179            app_state,
 180            executor,
 181            connection_pool: Default::default(),
 182            handlers: Default::default(),
 183            teardown: watch::channel(false).0,
 184        };
 185
 186        server
 187            .add_request_handler(ping)
 188            .add_request_handler(create_room)
 189            .add_request_handler(join_room)
 190            .add_request_handler(rejoin_room)
 191            .add_request_handler(leave_room)
 192            .add_request_handler(set_room_participant_role)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 208            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 209            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 210            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 211            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 212            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 213            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 214            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 215            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 216            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 217            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 218            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 219            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 220            .add_request_handler(
 221                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 222            )
 223            .add_request_handler(
 224                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 225            )
 226            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 227            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 228            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 229            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 230            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 231            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 232            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 233            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 234            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 235            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 236            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 237            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 238            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 239            .add_message_handler(create_buffer_for_peer)
 240            .add_request_handler(update_buffer)
 241            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 242            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 243            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 244            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 245            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 246            .add_request_handler(get_users)
 247            .add_request_handler(fuzzy_search_users)
 248            .add_request_handler(request_contact)
 249            .add_request_handler(remove_contact)
 250            .add_request_handler(respond_to_contact_request)
 251            .add_request_handler(create_channel)
 252            .add_request_handler(delete_channel)
 253            .add_request_handler(invite_channel_member)
 254            .add_request_handler(remove_channel_member)
 255            .add_request_handler(set_channel_member_role)
 256            .add_request_handler(set_channel_visibility)
 257            .add_request_handler(rename_channel)
 258            .add_request_handler(join_channel_buffer)
 259            .add_request_handler(leave_channel_buffer)
 260            .add_message_handler(update_channel_buffer)
 261            .add_request_handler(rejoin_channel_buffers)
 262            .add_request_handler(get_channel_members)
 263            .add_request_handler(respond_to_channel_invite)
 264            .add_request_handler(join_channel)
 265            .add_request_handler(join_channel_chat)
 266            .add_message_handler(leave_channel_chat)
 267            .add_request_handler(send_channel_message)
 268            .add_request_handler(remove_channel_message)
 269            .add_request_handler(get_channel_messages)
 270            .add_request_handler(get_channel_messages_by_id)
 271            .add_request_handler(get_notifications)
 272            .add_request_handler(mark_notification_as_read)
 273            .add_request_handler(move_channel)
 274            .add_request_handler(follow)
 275            .add_message_handler(unfollow)
 276            .add_message_handler(update_followers)
 277            .add_request_handler(get_private_user_info)
 278            .add_message_handler(acknowledge_channel_message)
 279            .add_message_handler(acknowledge_buffer_version);
 280
 281        Arc::new(server)
 282    }
 283
 284    pub async fn start(&self) -> Result<()> {
 285        let server_id = *self.id.lock();
 286        let app_state = self.app_state.clone();
 287        let peer = self.peer.clone();
 288        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 289        let pool = self.connection_pool.clone();
 290        let live_kit_client = self.app_state.live_kit_client.clone();
 291
 292        let span = info_span!("start server");
 293        self.executor.spawn_detached(
 294            async move {
 295                tracing::info!("waiting for cleanup timeout");
 296                timeout.await;
 297                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 298                if let Some((room_ids, channel_ids)) = app_state
 299                    .db
 300                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 301                    .await
 302                    .trace_err()
 303                {
 304                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 305                    tracing::info!(
 306                        stale_channel_buffer_count = channel_ids.len(),
 307                        "retrieved stale channel buffers"
 308                    );
 309
 310                    for channel_id in channel_ids {
 311                        if let Some(refreshed_channel_buffer) = app_state
 312                            .db
 313                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 314                            .await
 315                            .trace_err()
 316                        {
 317                            for connection_id in refreshed_channel_buffer.connection_ids {
 318                                peer.send(
 319                                    connection_id,
 320                                    proto::UpdateChannelBufferCollaborators {
 321                                        channel_id: channel_id.to_proto(),
 322                                        collaborators: refreshed_channel_buffer
 323                                            .collaborators
 324                                            .clone(),
 325                                    },
 326                                )
 327                                .trace_err();
 328                            }
 329                        }
 330                    }
 331
 332                    for room_id in room_ids {
 333                        let mut contacts_to_update = HashSet::default();
 334                        let mut canceled_calls_to_user_ids = Vec::new();
 335                        let mut live_kit_room = String::new();
 336                        let mut delete_live_kit_room = false;
 337
 338                        if let Some(mut refreshed_room) = app_state
 339                            .db
 340                            .clear_stale_room_participants(room_id, server_id)
 341                            .await
 342                            .trace_err()
 343                        {
 344                            tracing::info!(
 345                                room_id = room_id.0,
 346                                new_participant_count = refreshed_room.room.participants.len(),
 347                                "refreshed room"
 348                            );
 349                            room_updated(&refreshed_room.room, &peer);
 350                            if let Some(channel_id) = refreshed_room.channel_id {
 351                                channel_updated(
 352                                    channel_id,
 353                                    &refreshed_room.room,
 354                                    &refreshed_room.channel_members,
 355                                    &peer,
 356                                    &pool.lock(),
 357                                );
 358                            }
 359                            contacts_to_update
 360                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 361                            contacts_to_update
 362                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 363                            canceled_calls_to_user_ids =
 364                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 365                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 366                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 367                        }
 368
 369                        {
 370                            let pool = pool.lock();
 371                            for canceled_user_id in canceled_calls_to_user_ids {
 372                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 373                                    peer.send(
 374                                        connection_id,
 375                                        proto::CallCanceled {
 376                                            room_id: room_id.to_proto(),
 377                                        },
 378                                    )
 379                                    .trace_err();
 380                                }
 381                            }
 382                        }
 383
 384                        for user_id in contacts_to_update {
 385                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 386                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 387                            if let Some((busy, contacts)) = busy.zip(contacts) {
 388                                let pool = pool.lock();
 389                                let updated_contact = contact_for_user(user_id, busy, &pool);
 390                                for contact in contacts {
 391                                    if let db::Contact::Accepted {
 392                                        user_id: contact_user_id,
 393                                        ..
 394                                    } = contact
 395                                    {
 396                                        for contact_conn_id in
 397                                            pool.user_connection_ids(contact_user_id)
 398                                        {
 399                                            peer.send(
 400                                                contact_conn_id,
 401                                                proto::UpdateContacts {
 402                                                    contacts: vec![updated_contact.clone()],
 403                                                    remove_contacts: Default::default(),
 404                                                    incoming_requests: Default::default(),
 405                                                    remove_incoming_requests: Default::default(),
 406                                                    outgoing_requests: Default::default(),
 407                                                    remove_outgoing_requests: Default::default(),
 408                                                },
 409                                            )
 410                                            .trace_err();
 411                                        }
 412                                    }
 413                                }
 414                            }
 415                        }
 416
 417                        if let Some(live_kit) = live_kit_client.as_ref() {
 418                            if delete_live_kit_room {
 419                                live_kit.delete_room(live_kit_room).await.trace_err();
 420                            }
 421                        }
 422                    }
 423                }
 424
 425                app_state
 426                    .db
 427                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 428                    .await
 429                    .trace_err();
 430            }
 431            .instrument(span),
 432        );
 433        Ok(())
 434    }
 435
 436    pub fn teardown(&self) {
 437        self.peer.teardown();
 438        self.connection_pool.lock().reset();
 439        let _ = self.teardown.send(true);
 440    }
 441
 442    #[cfg(test)]
 443    pub fn reset(&self, id: ServerId) {
 444        self.teardown();
 445        *self.id.lock() = id;
 446        self.peer.reset(id.0 as u32);
 447        let _ = self.teardown.send(false);
 448    }
 449
 450    #[cfg(test)]
 451    pub fn id(&self) -> ServerId {
 452        *self.id.lock()
 453    }
 454
 455    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 458        Fut: 'static + Send + Future<Output = Result<()>>,
 459        M: EnvelopedMessage,
 460    {
 461        let prev_handler = self.handlers.insert(
 462            TypeId::of::<M>(),
 463            Box::new(move |envelope, session| {
 464                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 465                let span = info_span!(
 466                    "handle message",
 467                    payload_type = envelope.payload_type_name()
 468                );
 469                span.in_scope(|| {
 470                    tracing::info!(
 471                        payload_type = envelope.payload_type_name(),
 472                        "message received"
 473                    );
 474                });
 475                let start_time = Instant::now();
 476                let future = (handler)(*envelope, session);
 477                async move {
 478                    let result = future.await;
 479                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 480                    match result {
 481                        Err(error) => {
 482                            tracing::error!(%error, ?duration_ms, "error handling message")
 483                        }
 484                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 485                    }
 486                }
 487                .instrument(span)
 488                .boxed()
 489            }),
 490        );
 491        if prev_handler.is_some() {
 492            panic!("registered a handler for the same message twice");
 493        }
 494        self
 495    }
 496
 497    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 498    where
 499        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 500        Fut: 'static + Send + Future<Output = Result<()>>,
 501        M: EnvelopedMessage,
 502    {
 503        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 504        self
 505    }
 506
 507    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 510        Fut: Send + Future<Output = Result<()>>,
 511        M: RequestMessage,
 512    {
 513        let handler = Arc::new(handler);
 514        self.add_handler(move |envelope, session| {
 515            let receipt = envelope.receipt();
 516            let handler = handler.clone();
 517            async move {
 518                let peer = session.peer.clone();
 519                let responded = Arc::new(AtomicBool::default());
 520                let response = Response {
 521                    peer: peer.clone(),
 522                    responded: responded.clone(),
 523                    receipt,
 524                };
 525                match (handler)(envelope.payload, response, session).await {
 526                    Ok(()) => {
 527                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 528                            Ok(())
 529                        } else {
 530                            Err(anyhow!("handler did not send a response"))?
 531                        }
 532                    }
 533                    Err(error) => {
 534                        let proto_err = match &error {
 535                            Error::Internal(err) => err.to_proto(),
 536                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 537                        };
 538                        peer.respond_with_error(receipt, proto_err)?;
 539                        Err(error)
 540                    }
 541                }
 542            }
 543        })
 544    }
 545
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// The span and the teardown subscription are created eagerly. Everything else runs
    /// when the returned future is polled:
    /// 1. Fails with "server is tearing down" if teardown was already signaled.
    /// 2. Registers the connection with the peer and sends `Hello`. The connection id is
    ///    delivered to `send_connection_id`, if provided.
    /// 3. On the user's first-ever connection, sends `ShowContacts` and records
    ///    `connected_once`.
    /// 4. Loads contacts, channels and channel invites concurrently. Then it adds the
    ///    connection to the pool and sends the initial contact/channel updates, plus any
    ///    pending incoming call.
    /// 5. Dispatches incoming messages to the registered handlers until teardown, an I/O
    ///    completion or error, or the end of the incoming stream.
    /// 6. Except on teardown, calls `connection_lost` to sign the user out. A failure there is
    ///    logged rather than returned.
    ///
    /// # Errors
    ///
    /// Returns an error if teardown was already signaled, or if any setup step in 2–4 fails
    /// (peer sends, database calls, `update_user_contacts`).
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribe before the future is polled so a later teardown is observed via `changed()`.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                return Err(anyhow!("server is tearing down"))?;
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error for this connection.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            {
                // Add the connection and send the initial state under one pool lock.
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) are in flight per connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so when all permits
                // are held no further messages are pulled off `incoming_rx`. The semaphore is
                // never closed, so `acquire_owned` cannot fail.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O completion, then progress on running
                // foreground handlers, and only then a new message.
                futures::select_biased! {
                    // On teardown, return without calling `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span before `span` is moved into `instrument` below.
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 690
 691    pub async fn invite_code_redeemed(
 692        self: &Arc<Self>,
 693        inviter_id: UserId,
 694        invitee_id: UserId,
 695    ) -> Result<()> {
 696        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 697            if let Some(code) = &user.invite_code {
 698                let pool = self.connection_pool.lock();
 699                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 700                for connection_id in pool.user_connection_ids(inviter_id) {
 701                    self.peer.send(
 702                        connection_id,
 703                        proto::UpdateContacts {
 704                            contacts: vec![invitee_contact.clone()],
 705                            ..Default::default()
 706                        },
 707                    )?;
 708                    self.peer.send(
 709                        connection_id,
 710                        proto::UpdateInviteInfo {
 711                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 712                            count: user.invite_count as u32,
 713                        },
 714                    )?;
 715                }
 716            }
 717        }
 718        Ok(())
 719    }
 720
 721    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 722        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 723            if let Some(invite_code) = &user.invite_code {
 724                let pool = self.connection_pool.lock();
 725                for connection_id in pool.user_connection_ids(user_id) {
 726                    self.peer.send(
 727                        connection_id,
 728                        proto::UpdateInviteInfo {
 729                            url: format!(
 730                                "{}{}",
 731                                self.app_state.config.invite_link_prefix, invite_code
 732                            ),
 733                            count: user.invite_count as u32,
 734                        },
 735                    )?;
 736                }
 737            }
 738        }
 739        Ok(())
 740    }
 741
 742    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 743        ServerSnapshot {
 744            connection_pool: ConnectionPoolGuard {
 745                guard: self.connection_pool.lock(),
 746                _not_send: PhantomData,
 747            },
 748            peer: &self.peer,
 749        }
 750    }
 751}
 752
 753impl<'a> Deref for ConnectionPoolGuard<'a> {
 754    type Target = ConnectionPool;
 755
 756    fn deref(&self) -> &Self::Target {
 757        &self.guard
 758    }
 759}
 760
 761impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 762    fn deref_mut(&mut self) -> &mut Self::Target {
 763        &mut self.guard
 764    }
 765}
 766
 767impl<'a> Drop for ConnectionPoolGuard<'a> {
 768    fn drop(&mut self) {
 769        #[cfg(test)]
 770        self.check_invariants();
 771    }
 772}
 773
 774fn broadcast<F>(
 775    sender_id: Option<ConnectionId>,
 776    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 777    mut f: F,
 778) where
 779    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 780{
 781    for receiver_id in receiver_ids {
 782        if Some(receiver_id) != sender_id {
 783            if let Err(error) = f(receiver_id) {
 784                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 785            }
 786        }
 787    }
 788}
 789
 790pub struct ProtocolVersion(u32);
 791
 792impl Header for ProtocolVersion {
 793    fn name() -> &'static HeaderName {
 794        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 795        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 796    }
 797
 798    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 799    where
 800        Self: Sized,
 801        I: Iterator<Item = &'i axum::http::HeaderValue>,
 802    {
 803        let version = values
 804            .next()
 805            .ok_or_else(axum::headers::Error::invalid)?
 806            .to_str()
 807            .map_err(|_| axum::headers::Error::invalid())?
 808            .parse()
 809            .map_err(|_| axum::headers::Error::invalid())?;
 810        Ok(Self(version))
 811    }
 812
 813    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 814        values.extend([self.0.to_string().parse().unwrap()]);
 815    }
 816}
 817
 818pub struct AppVersionHeader(SemanticVersion);
 819impl Header for AppVersionHeader {
 820    fn name() -> &'static HeaderName {
 821        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 822        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 823    }
 824
 825    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 826    where
 827        Self: Sized,
 828        I: Iterator<Item = &'i axum::http::HeaderValue>,
 829    {
 830        let version = values
 831            .next()
 832            .ok_or_else(axum::headers::Error::invalid)?
 833            .to_str()
 834            .map_err(|_| axum::headers::Error::invalid())?
 835            .parse()
 836            .map_err(|_| axum::headers::Error::invalid())?;
 837        Ok(Self(version))
 838    }
 839
 840    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 841        values.extend([self.0.to_string().parse().unwrap()]);
 842    }
 843}
 844
 845pub fn routes(server: Arc<Server>) -> Router<Body> {
 846    Router::new()
 847        .route("/rpc", get(handle_websocket_request))
 848        .layer(
 849            ServiceBuilder::new()
 850                .layer(Extension(server.app_state.clone()))
 851                .layer(middleware::from_fn(auth::validate_header)),
 852        )
 853        .route("/metrics", get(handle_metrics))
 854        .layer(Extension(server))
 855}
 856
 857pub async fn handle_websocket_request(
 858    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 859    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 860    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 861    Extension(server): Extension<Arc<Server>>,
 862    Extension(user): Extension<User>,
 863    Extension(impersonator): Extension<Impersonator>,
 864    ws: WebSocketUpgrade,
 865) -> axum::response::Response {
 866    if protocol_version != rpc::PROTOCOL_VERSION {
 867        return (
 868            StatusCode::UPGRADE_REQUIRED,
 869            "client must be upgraded".to_string(),
 870        )
 871            .into_response();
 872    }
 873
 874    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 875        return (
 876            StatusCode::UPGRADE_REQUIRED,
 877            "no version header found".to_string(),
 878        )
 879            .into_response();
 880    };
 881
 882    if !version.is_supported() {
 883        return (
 884            StatusCode::UPGRADE_REQUIRED,
 885            "client must be upgraded".to_string(),
 886        )
 887            .into_response();
 888    }
 889
 890    let socket_address = socket_address.to_string();
 891    ws.on_upgrade(move |socket| {
 892        use util::ResultExt;
 893        let socket = socket
 894            .map_ok(to_tungstenite_message)
 895            .err_into()
 896            .with(|message| async move { Ok(to_axum_message(message)) });
 897        let connection = Connection::new(Box::pin(socket));
 898        async move {
 899            server
 900                .handle_connection(
 901                    connection,
 902                    socket_address,
 903                    user,
 904                    version,
 905                    impersonator.0,
 906                    None,
 907                    Executor::Production,
 908                )
 909                .await
 910                .log_err();
 911        }
 912    })
 913}
 914
 915pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 916    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 917    let connections_metric = CONNECTIONS_METRIC
 918        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 919
 920    let connections = server
 921        .connection_pool
 922        .lock()
 923        .connections()
 924        .filter(|connection| !connection.admin)
 925        .count();
 926    connections_metric.set(connections as _);
 927
 928    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 929    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 930        register_int_gauge!(
 931            "shared_projects",
 932            "number of open projects with one or more guests"
 933        )
 934        .unwrap()
 935    });
 936
 937    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 938    shared_projects_metric.set(shared_projects as _);
 939
 940    let encoder = prometheus::TextEncoder::new();
 941    let metric_families = prometheus::gather();
 942    let encoded_metrics = encoder
 943        .encode_to_string(&metric_families)
 944        .map_err(|err| anyhow!("{}", err))?;
 945    Ok(encoded_metrics)
 946}
 947
/// Cleans up after a client connection goes away.
///
/// The following steps run right away:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - record the loss in the database (a failure here is only traced).
///
/// It then waits for `RECONNECT_TIMEOUT`, giving the client a chance to reconnect. If
/// `teardown` changes first, it returns without further cleanup. Otherwise it:
/// - leaves the session's room and channel buffers;
/// - declines any call ringing the user via `decline_call(None, …)`, but only if the user has
///   no other connection online;
/// - refreshes the user's contacts.
///
/// # Errors
/// Fails if the connection cannot be removed from the pool or the contact update fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order listed, so the timeout branch is
    // checked before `teardown` on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time: release everything it held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a ringing call if no other connection of this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 993
 994/// Acknowledges a ping from a client, used to keep the connection alive.
 995async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 996    response.send(proto::Ack {})?;
 997    Ok(())
 998}
 999
1000/// Creates a new room for calling (outside of channels)
1001async fn create_room(
1002    _request: proto::CreateRoom,
1003    response: Response<proto::CreateRoom>,
1004    session: Session,
1005) -> Result<()> {
1006    let live_kit_room = nanoid::nanoid!(30);
1007
1008    let live_kit_connection_info = {
1009        let live_kit_room = live_kit_room.clone();
1010        let live_kit = session.live_kit_client.as_ref();
1011
1012        util::async_maybe!({
1013            let live_kit = live_kit?;
1014
1015            let token = live_kit
1016                .room_token(&live_kit_room, &session.user_id.to_string())
1017                .trace_err()?;
1018
1019            Some(proto::LiveKitConnectionInfo {
1020                server_url: live_kit.url().into(),
1021                token,
1022                can_publish: true,
1023            })
1024        })
1025    }
1026    .await;
1027
1028    let room = session
1029        .db()
1030        .await
1031        .create_room(session.user_id, session.connection_id, &live_kit_room)
1032        .await?;
1033
1034    response.send(proto::CreateRoomResponse {
1035        room: Some(room.clone()),
1036        live_kit_connection_info,
1037    })?;
1038
1039    update_user_contacts(session.user_id, &session).await?;
1040    Ok(())
1041}
1042
1043/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1044async fn join_room(
1045    request: proto::JoinRoom,
1046    response: Response<proto::JoinRoom>,
1047    session: Session,
1048) -> Result<()> {
1049    let room_id = RoomId::from_proto(request.id);
1050
1051    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1052
1053    if let Some(channel_id) = channel_id {
1054        return join_channel_internal(channel_id, Box::new(response), session).await;
1055    }
1056
1057    let joined_room = {
1058        let room = session
1059            .db()
1060            .await
1061            .join_room(room_id, session.user_id, session.connection_id)
1062            .await?;
1063        room_updated(&room.room, &session.peer);
1064        room.into_inner()
1065    };
1066
1067    for connection_id in session
1068        .connection_pool()
1069        .await
1070        .user_connection_ids(session.user_id)
1071    {
1072        session
1073            .peer
1074            .send(
1075                connection_id,
1076                proto::CallCanceled {
1077                    room_id: room_id.to_proto(),
1078                },
1079            )
1080            .trace_err();
1081    }
1082
1083    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1084        if let Some(token) = live_kit
1085            .room_token(
1086                &joined_room.room.live_kit_room,
1087                &session.user_id.to_string(),
1088            )
1089            .trace_err()
1090        {
1091            Some(proto::LiveKitConnectionInfo {
1092                server_url: live_kit.url().into(),
1093                token,
1094                can_publish: true,
1095            })
1096        } else {
1097            None
1098        }
1099    } else {
1100        None
1101    };
1102
1103    response.send(proto::JoinRoomResponse {
1104        room: Some(joined_room.room),
1105        channel_id: None,
1106        live_kit_connection_info,
1107    })?;
1108
1109    update_user_contacts(session.user_id, &session).await?;
1110    Ok(())
1111}
1112
1113/// Rejoin room is used to reconnect to a room after connection errors.
1114async fn rejoin_room(
1115    request: proto::RejoinRoom,
1116    response: Response<proto::RejoinRoom>,
1117    session: Session,
1118) -> Result<()> {
1119    let room;
1120    let channel_id;
1121    let channel_members;
1122    {
1123        let mut rejoined_room = session
1124            .db()
1125            .await
1126            .rejoin_room(request, session.user_id, session.connection_id)
1127            .await?;
1128
1129        response.send(proto::RejoinRoomResponse {
1130            room: Some(rejoined_room.room.clone()),
1131            reshared_projects: rejoined_room
1132                .reshared_projects
1133                .iter()
1134                .map(|project| proto::ResharedProject {
1135                    id: project.id.to_proto(),
1136                    collaborators: project
1137                        .collaborators
1138                        .iter()
1139                        .map(|collaborator| collaborator.to_proto())
1140                        .collect(),
1141                })
1142                .collect(),
1143            rejoined_projects: rejoined_room
1144                .rejoined_projects
1145                .iter()
1146                .map(|rejoined_project| proto::RejoinedProject {
1147                    id: rejoined_project.id.to_proto(),
1148                    worktrees: rejoined_project
1149                        .worktrees
1150                        .iter()
1151                        .map(|worktree| proto::WorktreeMetadata {
1152                            id: worktree.id,
1153                            root_name: worktree.root_name.clone(),
1154                            visible: worktree.visible,
1155                            abs_path: worktree.abs_path.clone(),
1156                        })
1157                        .collect(),
1158                    collaborators: rejoined_project
1159                        .collaborators
1160                        .iter()
1161                        .map(|collaborator| collaborator.to_proto())
1162                        .collect(),
1163                    language_servers: rejoined_project.language_servers.clone(),
1164                })
1165                .collect(),
1166        })?;
1167        room_updated(&rejoined_room.room, &session.peer);
1168
1169        for project in &rejoined_room.reshared_projects {
1170            for collaborator in &project.collaborators {
1171                session
1172                    .peer
1173                    .send(
1174                        collaborator.connection_id,
1175                        proto::UpdateProjectCollaborator {
1176                            project_id: project.id.to_proto(),
1177                            old_peer_id: Some(project.old_connection_id.into()),
1178                            new_peer_id: Some(session.connection_id.into()),
1179                        },
1180                    )
1181                    .trace_err();
1182            }
1183
1184            broadcast(
1185                Some(session.connection_id),
1186                project
1187                    .collaborators
1188                    .iter()
1189                    .map(|collaborator| collaborator.connection_id),
1190                |connection_id| {
1191                    session.peer.forward_send(
1192                        session.connection_id,
1193                        connection_id,
1194                        proto::UpdateProject {
1195                            project_id: project.id.to_proto(),
1196                            worktrees: project.worktrees.clone(),
1197                        },
1198                    )
1199                },
1200            );
1201        }
1202
1203        for project in &rejoined_room.rejoined_projects {
1204            for collaborator in &project.collaborators {
1205                session
1206                    .peer
1207                    .send(
1208                        collaborator.connection_id,
1209                        proto::UpdateProjectCollaborator {
1210                            project_id: project.id.to_proto(),
1211                            old_peer_id: Some(project.old_connection_id.into()),
1212                            new_peer_id: Some(session.connection_id.into()),
1213                        },
1214                    )
1215                    .trace_err();
1216            }
1217        }
1218
1219        for project in &mut rejoined_room.rejoined_projects {
1220            for worktree in mem::take(&mut project.worktrees) {
1221                #[cfg(any(test, feature = "test-support"))]
1222                const MAX_CHUNK_SIZE: usize = 2;
1223                #[cfg(not(any(test, feature = "test-support")))]
1224                const MAX_CHUNK_SIZE: usize = 256;
1225
1226                // Stream this worktree's entries.
1227                let message = proto::UpdateWorktree {
1228                    project_id: project.id.to_proto(),
1229                    worktree_id: worktree.id,
1230                    abs_path: worktree.abs_path.clone(),
1231                    root_name: worktree.root_name,
1232                    updated_entries: worktree.updated_entries,
1233                    removed_entries: worktree.removed_entries,
1234                    scan_id: worktree.scan_id,
1235                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1236                    updated_repositories: worktree.updated_repositories,
1237                    removed_repositories: worktree.removed_repositories,
1238                };
1239                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1240                    session.peer.send(session.connection_id, update.clone())?;
1241                }
1242
1243                // Stream this worktree's diagnostics.
1244                for summary in worktree.diagnostic_summaries {
1245                    session.peer.send(
1246                        session.connection_id,
1247                        proto::UpdateDiagnosticSummary {
1248                            project_id: project.id.to_proto(),
1249                            worktree_id: worktree.id,
1250                            summary: Some(summary),
1251                        },
1252                    )?;
1253                }
1254
1255                for settings_file in worktree.settings_files {
1256                    session.peer.send(
1257                        session.connection_id,
1258                        proto::UpdateWorktreeSettings {
1259                            project_id: project.id.to_proto(),
1260                            worktree_id: worktree.id,
1261                            path: settings_file.path,
1262                            content: Some(settings_file.content),
1263                        },
1264                    )?;
1265                }
1266            }
1267
1268            for language_server in &project.language_servers {
1269                session.peer.send(
1270                    session.connection_id,
1271                    proto::UpdateLanguageServer {
1272                        project_id: project.id.to_proto(),
1273                        language_server_id: language_server.id,
1274                        variant: Some(
1275                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1276                                proto::LspDiskBasedDiagnosticsUpdated {},
1277                            ),
1278                        ),
1279                    },
1280                )?;
1281            }
1282        }
1283
1284        let rejoined_room = rejoined_room.into_inner();
1285
1286        room = rejoined_room.room;
1287        channel_id = rejoined_room.channel_id;
1288        channel_members = rejoined_room.channel_members;
1289    }
1290
1291    if let Some(channel_id) = channel_id {
1292        channel_updated(
1293            channel_id,
1294            &room,
1295            &channel_members,
1296            &session.peer,
1297            &*session.connection_pool().await,
1298        );
1299    }
1300
1301    update_user_contacts(session.user_id, &session).await?;
1302    Ok(())
1303}
1304
1305/// leave room disconnects from the room.
1306async fn leave_room(
1307    _: proto::LeaveRoom,
1308    response: Response<proto::LeaveRoom>,
1309    session: Session,
1310) -> Result<()> {
1311    leave_room_for_session(&session).await?;
1312    response.send(proto::Ack {})?;
1313    Ok(())
1314}
1315
1316/// Updates the permissions of someone else in the room.
1317async fn set_room_participant_role(
1318    request: proto::SetRoomParticipantRole,
1319    response: Response<proto::SetRoomParticipantRole>,
1320    session: Session,
1321) -> Result<()> {
1322    let user_id = UserId::from_proto(request.user_id);
1323    let role = ChannelRole::from(request.role());
1324
1325    if role == ChannelRole::Talker {
1326        let pool = session.connection_pool().await;
1327
1328        for connection in pool.user_connections(user_id) {
1329            if !connection.zed_version.supports_talker_role() {
1330                Err(anyhow!(
1331                    "This user is on zed {} which does not support unmute",
1332                    connection.zed_version
1333                ))?;
1334            }
1335        }
1336    }
1337
1338    let (live_kit_room, can_publish) = {
1339        let room = session
1340            .db()
1341            .await
1342            .set_room_participant_role(
1343                session.user_id,
1344                RoomId::from_proto(request.room_id),
1345                user_id,
1346                role,
1347            )
1348            .await?;
1349
1350        let live_kit_room = room.live_kit_room.clone();
1351        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1352        room_updated(&room, &session.peer);
1353        (live_kit_room, can_publish)
1354    };
1355
1356    if let Some(live_kit) = session.live_kit_client.as_ref() {
1357        live_kit
1358            .update_participant(
1359                live_kit_room.clone(),
1360                request.user_id.to_string(),
1361                live_kit_server::proto::ParticipantPermission {
1362                    can_subscribe: true,
1363                    can_publish,
1364                    can_publish_data: can_publish,
1365                    hidden: false,
1366                    recorder: false,
1367                },
1368            )
1369            .await
1370            .trace_err();
1371    }
1372
1373    response.send(proto::Ack {})?;
1374    Ok(())
1375}
1376
1377/// Call someone else into the current room
1378async fn call(
1379    request: proto::Call,
1380    response: Response<proto::Call>,
1381    session: Session,
1382) -> Result<()> {
1383    let room_id = RoomId::from_proto(request.room_id);
1384    let calling_user_id = session.user_id;
1385    let calling_connection_id = session.connection_id;
1386    let called_user_id = UserId::from_proto(request.called_user_id);
1387    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1388    if !session
1389        .db()
1390        .await
1391        .has_contact(calling_user_id, called_user_id)
1392        .await?
1393    {
1394        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1395    }
1396
1397    let incoming_call = {
1398        let (room, incoming_call) = &mut *session
1399            .db()
1400            .await
1401            .call(
1402                room_id,
1403                calling_user_id,
1404                calling_connection_id,
1405                called_user_id,
1406                initial_project_id,
1407            )
1408            .await?;
1409        room_updated(&room, &session.peer);
1410        mem::take(incoming_call)
1411    };
1412    update_user_contacts(called_user_id, &session).await?;
1413
1414    let mut calls = session
1415        .connection_pool()
1416        .await
1417        .user_connection_ids(called_user_id)
1418        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1419        .collect::<FuturesUnordered<_>>();
1420
1421    while let Some(call_response) = calls.next().await {
1422        match call_response.as_ref() {
1423            Ok(_) => {
1424                response.send(proto::Ack {})?;
1425                return Ok(());
1426            }
1427            Err(_) => {
1428                call_response.trace_err();
1429            }
1430        }
1431    }
1432
1433    {
1434        let room = session
1435            .db()
1436            .await
1437            .call_failed(room_id, called_user_id)
1438            .await?;
1439        room_updated(&room, &session.peer);
1440    }
1441    update_user_contacts(called_user_id, &session).await?;
1442
1443    Err(anyhow!("failed to ring user"))?
1444}
1445
1446/// Cancel an outgoing call.
1447async fn cancel_call(
1448    request: proto::CancelCall,
1449    response: Response<proto::CancelCall>,
1450    session: Session,
1451) -> Result<()> {
1452    let called_user_id = UserId::from_proto(request.called_user_id);
1453    let room_id = RoomId::from_proto(request.room_id);
1454    {
1455        let room = session
1456            .db()
1457            .await
1458            .cancel_call(room_id, session.connection_id, called_user_id)
1459            .await?;
1460        room_updated(&room, &session.peer);
1461    }
1462
1463    for connection_id in session
1464        .connection_pool()
1465        .await
1466        .user_connection_ids(called_user_id)
1467    {
1468        session
1469            .peer
1470            .send(
1471                connection_id,
1472                proto::CallCanceled {
1473                    room_id: room_id.to_proto(),
1474                },
1475            )
1476            .trace_err();
1477    }
1478    response.send(proto::Ack {})?;
1479
1480    update_user_contacts(called_user_id, &session).await?;
1481    Ok(())
1482}
1483
1484/// Decline an incoming call.
1485async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1486    let room_id = RoomId::from_proto(message.room_id);
1487    {
1488        let room = session
1489            .db()
1490            .await
1491            .decline_call(Some(room_id), session.user_id)
1492            .await?
1493            .ok_or_else(|| anyhow!("failed to decline call"))?;
1494        room_updated(&room, &session.peer);
1495    }
1496
1497    for connection_id in session
1498        .connection_pool()
1499        .await
1500        .user_connection_ids(session.user_id)
1501    {
1502        session
1503            .peer
1504            .send(
1505                connection_id,
1506                proto::CallCanceled {
1507                    room_id: room_id.to_proto(),
1508                },
1509            )
1510            .trace_err();
1511    }
1512    update_user_contacts(session.user_id, &session).await?;
1513    Ok(())
1514}
1515
1516/// Updates other participants in the room with your current location.
1517async fn update_participant_location(
1518    request: proto::UpdateParticipantLocation,
1519    response: Response<proto::UpdateParticipantLocation>,
1520    session: Session,
1521) -> Result<()> {
1522    let room_id = RoomId::from_proto(request.room_id);
1523    let location = request
1524        .location
1525        .ok_or_else(|| anyhow!("invalid location"))?;
1526
1527    let db = session.db().await;
1528    let room = db
1529        .update_room_participant_location(room_id, session.connection_id, location)
1530        .await?;
1531
1532    room_updated(&room, &session.peer);
1533    response.send(proto::Ack {})?;
1534    Ok(())
1535}
1536
1537/// Share a project into the room.
1538async fn share_project(
1539    request: proto::ShareProject,
1540    response: Response<proto::ShareProject>,
1541    session: Session,
1542) -> Result<()> {
1543    let (project_id, room) = &*session
1544        .db()
1545        .await
1546        .share_project(
1547            RoomId::from_proto(request.room_id),
1548            session.connection_id,
1549            &request.worktrees,
1550        )
1551        .await?;
1552    response.send(proto::ShareProjectResponse {
1553        project_id: project_id.to_proto(),
1554    })?;
1555    room_updated(&room, &session.peer);
1556
1557    Ok(())
1558}
1559
1560/// Unshare a project from the room.
1561async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1562    let project_id = ProjectId::from_proto(message.project_id);
1563
1564    let (room, guest_connection_ids) = &*session
1565        .db()
1566        .await
1567        .unshare_project(project_id, session.connection_id)
1568        .await?;
1569
1570    broadcast(
1571        Some(session.connection_id),
1572        guest_connection_ids.iter().copied(),
1573        |conn_id| session.peer.send(conn_id, message.clone()),
1574    );
1575    room_updated(&room, &session.peer);
1576
1577    Ok(())
1578}
1579
1580/// Join someone elses shared project.
1581async fn join_project(
1582    request: proto::JoinProject,
1583    response: Response<proto::JoinProject>,
1584    session: Session,
1585) -> Result<()> {
1586    let project_id = ProjectId::from_proto(request.project_id);
1587    let guest_user_id = session.user_id;
1588
1589    tracing::info!(%project_id, "join project");
1590
1591    let (project, replica_id) = &mut *session
1592        .db()
1593        .await
1594        .join_project(project_id, session.connection_id)
1595        .await?;
1596
1597    let collaborators = project
1598        .collaborators
1599        .iter()
1600        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1601        .map(|collaborator| collaborator.to_proto())
1602        .collect::<Vec<_>>();
1603
1604    let worktrees = project
1605        .worktrees
1606        .iter()
1607        .map(|(id, worktree)| proto::WorktreeMetadata {
1608            id: *id,
1609            root_name: worktree.root_name.clone(),
1610            visible: worktree.visible,
1611            abs_path: worktree.abs_path.clone(),
1612        })
1613        .collect::<Vec<_>>();
1614
1615    for collaborator in &collaborators {
1616        session
1617            .peer
1618            .send(
1619                collaborator.peer_id.unwrap().into(),
1620                proto::AddProjectCollaborator {
1621                    project_id: project_id.to_proto(),
1622                    collaborator: Some(proto::Collaborator {
1623                        peer_id: Some(session.connection_id.into()),
1624                        replica_id: replica_id.0 as u32,
1625                        user_id: guest_user_id.to_proto(),
1626                    }),
1627                },
1628            )
1629            .trace_err();
1630    }
1631
1632    // First, we send the metadata associated with each worktree.
1633    response.send(proto::JoinProjectResponse {
1634        worktrees: worktrees.clone(),
1635        replica_id: replica_id.0 as u32,
1636        collaborators: collaborators.clone(),
1637        language_servers: project.language_servers.clone(),
1638    })?;
1639
1640    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1641        #[cfg(any(test, feature = "test-support"))]
1642        const MAX_CHUNK_SIZE: usize = 2;
1643        #[cfg(not(any(test, feature = "test-support")))]
1644        const MAX_CHUNK_SIZE: usize = 256;
1645
1646        // Stream this worktree's entries.
1647        let message = proto::UpdateWorktree {
1648            project_id: project_id.to_proto(),
1649            worktree_id,
1650            abs_path: worktree.abs_path.clone(),
1651            root_name: worktree.root_name,
1652            updated_entries: worktree.entries,
1653            removed_entries: Default::default(),
1654            scan_id: worktree.scan_id,
1655            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1656            updated_repositories: worktree.repository_entries.into_values().collect(),
1657            removed_repositories: Default::default(),
1658        };
1659        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1660            session.peer.send(session.connection_id, update.clone())?;
1661        }
1662
1663        // Stream this worktree's diagnostics.
1664        for summary in worktree.diagnostic_summaries {
1665            session.peer.send(
1666                session.connection_id,
1667                proto::UpdateDiagnosticSummary {
1668                    project_id: project_id.to_proto(),
1669                    worktree_id: worktree.id,
1670                    summary: Some(summary),
1671                },
1672            )?;
1673        }
1674
1675        for settings_file in worktree.settings_files {
1676            session.peer.send(
1677                session.connection_id,
1678                proto::UpdateWorktreeSettings {
1679                    project_id: project_id.to_proto(),
1680                    worktree_id: worktree.id,
1681                    path: settings_file.path,
1682                    content: Some(settings_file.content),
1683                },
1684            )?;
1685        }
1686    }
1687
1688    for language_server in &project.language_servers {
1689        session.peer.send(
1690            session.connection_id,
1691            proto::UpdateLanguageServer {
1692                project_id: project_id.to_proto(),
1693                language_server_id: language_server.id,
1694                variant: Some(
1695                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1696                        proto::LspDiskBasedDiagnosticsUpdated {},
1697                    ),
1698                ),
1699            },
1700        )?;
1701    }
1702
1703    Ok(())
1704}
1705
1706/// Leave someone elses shared project.
1707async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1708    let sender_id = session.connection_id;
1709    let project_id = ProjectId::from_proto(request.project_id);
1710
1711    let (room, project) = &*session
1712        .db()
1713        .await
1714        .leave_project(project_id, sender_id)
1715        .await?;
1716    tracing::info!(
1717        %project_id,
1718        host_user_id = %project.host_user_id,
1719        host_connection_id = ?project.host_connection_id,
1720        "leave project"
1721    );
1722
1723    project_left(&project, &session);
1724    room_updated(&room, &session.peer);
1725
1726    Ok(())
1727}
1728
1729/// Updates other participants with changes to the project
1730async fn update_project(
1731    request: proto::UpdateProject,
1732    response: Response<proto::UpdateProject>,
1733    session: Session,
1734) -> Result<()> {
1735    let project_id = ProjectId::from_proto(request.project_id);
1736    let (room, guest_connection_ids) = &*session
1737        .db()
1738        .await
1739        .update_project(project_id, session.connection_id, &request.worktrees)
1740        .await?;
1741    broadcast(
1742        Some(session.connection_id),
1743        guest_connection_ids.iter().copied(),
1744        |connection_id| {
1745            session
1746                .peer
1747                .forward_send(session.connection_id, connection_id, request.clone())
1748        },
1749    );
1750    room_updated(&room, &session.peer);
1751    response.send(proto::Ack {})?;
1752
1753    Ok(())
1754}
1755
1756/// Updates other participants with changes to the worktree
1757async fn update_worktree(
1758    request: proto::UpdateWorktree,
1759    response: Response<proto::UpdateWorktree>,
1760    session: Session,
1761) -> Result<()> {
1762    let guest_connection_ids = session
1763        .db()
1764        .await
1765        .update_worktree(&request, session.connection_id)
1766        .await?;
1767
1768    broadcast(
1769        Some(session.connection_id),
1770        guest_connection_ids.iter().copied(),
1771        |connection_id| {
1772            session
1773                .peer
1774                .forward_send(session.connection_id, connection_id, request.clone())
1775        },
1776    );
1777    response.send(proto::Ack {})?;
1778    Ok(())
1779}
1780
1781/// Updates other participants with changes to the diagnostics
1782async fn update_diagnostic_summary(
1783    message: proto::UpdateDiagnosticSummary,
1784    session: Session,
1785) -> Result<()> {
1786    let guest_connection_ids = session
1787        .db()
1788        .await
1789        .update_diagnostic_summary(&message, session.connection_id)
1790        .await?;
1791
1792    broadcast(
1793        Some(session.connection_id),
1794        guest_connection_ids.iter().copied(),
1795        |connection_id| {
1796            session
1797                .peer
1798                .forward_send(session.connection_id, connection_id, message.clone())
1799        },
1800    );
1801
1802    Ok(())
1803}
1804
1805/// Updates other participants with changes to the worktree settings
1806async fn update_worktree_settings(
1807    message: proto::UpdateWorktreeSettings,
1808    session: Session,
1809) -> Result<()> {
1810    let guest_connection_ids = session
1811        .db()
1812        .await
1813        .update_worktree_settings(&message, session.connection_id)
1814        .await?;
1815
1816    broadcast(
1817        Some(session.connection_id),
1818        guest_connection_ids.iter().copied(),
1819        |connection_id| {
1820            session
1821                .peer
1822                .forward_send(session.connection_id, connection_id, message.clone())
1823        },
1824    );
1825
1826    Ok(())
1827}
1828
1829/// Notify other participants that a  language server has started.
1830async fn start_language_server(
1831    request: proto::StartLanguageServer,
1832    session: Session,
1833) -> Result<()> {
1834    let guest_connection_ids = session
1835        .db()
1836        .await
1837        .start_language_server(&request, session.connection_id)
1838        .await?;
1839
1840    broadcast(
1841        Some(session.connection_id),
1842        guest_connection_ids.iter().copied(),
1843        |connection_id| {
1844            session
1845                .peer
1846                .forward_send(session.connection_id, connection_id, request.clone())
1847        },
1848    );
1849    Ok(())
1850}
1851
1852/// Notify other participants that a language server has changed.
1853async fn update_language_server(
1854    request: proto::UpdateLanguageServer,
1855    session: Session,
1856) -> Result<()> {
1857    let project_id = ProjectId::from_proto(request.project_id);
1858    let project_connection_ids = session
1859        .db()
1860        .await
1861        .project_connection_ids(project_id, session.connection_id)
1862        .await?;
1863    broadcast(
1864        Some(session.connection_id),
1865        project_connection_ids.iter().copied(),
1866        |connection_id| {
1867            session
1868                .peer
1869                .forward_send(session.connection_id, connection_id, request.clone())
1870        },
1871    );
1872    Ok(())
1873}
1874
1875/// forward a project request to the host. These requests should be read only
1876/// as guests are allowed to send them.
1877async fn forward_read_only_project_request<T>(
1878    request: T,
1879    response: Response<T>,
1880    session: Session,
1881) -> Result<()>
1882where
1883    T: EntityMessage + RequestMessage,
1884{
1885    let project_id = ProjectId::from_proto(request.remote_entity_id());
1886    let host_connection_id = session
1887        .db()
1888        .await
1889        .host_for_read_only_project_request(project_id, session.connection_id)
1890        .await?;
1891    let payload = session
1892        .peer
1893        .forward_request(session.connection_id, host_connection_id, request)
1894        .await?;
1895    response.send(payload)?;
1896    Ok(())
1897}
1898
1899/// forward a project request to the host. These requests are disallowed
1900/// for guests.
1901async fn forward_mutating_project_request<T>(
1902    request: T,
1903    response: Response<T>,
1904    session: Session,
1905) -> Result<()>
1906where
1907    T: EntityMessage + RequestMessage,
1908{
1909    let project_id = ProjectId::from_proto(request.remote_entity_id());
1910    let host_connection_id = session
1911        .db()
1912        .await
1913        .host_for_mutating_project_request(project_id, session.connection_id)
1914        .await?;
1915    let payload = session
1916        .peer
1917        .forward_request(session.connection_id, host_connection_id, request)
1918        .await?;
1919    response.send(payload)?;
1920    Ok(())
1921}
1922
1923/// Notify other participants that a new buffer has been created
1924async fn create_buffer_for_peer(
1925    request: proto::CreateBufferForPeer,
1926    session: Session,
1927) -> Result<()> {
1928    session
1929        .db()
1930        .await
1931        .check_user_is_project_host(
1932            ProjectId::from_proto(request.project_id),
1933            session.connection_id,
1934        )
1935        .await?;
1936    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1937    session
1938        .peer
1939        .forward_send(session.connection_id, peer_id.into(), request)?;
1940    Ok(())
1941}
1942
1943/// Notify other participants that a buffer has been updated. This is
1944/// allowed for guests as long as the update is limited to selections.
1945async fn update_buffer(
1946    request: proto::UpdateBuffer,
1947    response: Response<proto::UpdateBuffer>,
1948    session: Session,
1949) -> Result<()> {
1950    let project_id = ProjectId::from_proto(request.project_id);
1951    let mut guest_connection_ids;
1952    let mut host_connection_id = None;
1953
1954    let mut requires_write_permission = false;
1955
1956    for op in request.operations.iter() {
1957        match op.variant {
1958            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1959            Some(_) => requires_write_permission = true,
1960        }
1961    }
1962
1963    {
1964        let collaborators = session
1965            .db()
1966            .await
1967            .project_collaborators_for_buffer_update(
1968                project_id,
1969                session.connection_id,
1970                requires_write_permission,
1971            )
1972            .await?;
1973        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1974        for collaborator in collaborators.iter() {
1975            if collaborator.is_host {
1976                host_connection_id = Some(collaborator.connection_id);
1977            } else {
1978                guest_connection_ids.push(collaborator.connection_id);
1979            }
1980        }
1981    }
1982    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1983
1984    broadcast(
1985        Some(session.connection_id),
1986        guest_connection_ids,
1987        |connection_id| {
1988            session
1989                .peer
1990                .forward_send(session.connection_id, connection_id, request.clone())
1991        },
1992    );
1993    if host_connection_id != session.connection_id {
1994        session
1995            .peer
1996            .forward_request(session.connection_id, host_connection_id, request.clone())
1997            .await?;
1998    }
1999
2000    response.send(proto::Ack {})?;
2001    Ok(())
2002}
2003
2004/// Notify other participants that a project has been updated.
2005async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2006    request: T,
2007    session: Session,
2008) -> Result<()> {
2009    let project_id = ProjectId::from_proto(request.remote_entity_id());
2010    let project_connection_ids = session
2011        .db()
2012        .await
2013        .project_connection_ids(project_id, session.connection_id)
2014        .await?;
2015
2016    broadcast(
2017        Some(session.connection_id),
2018        project_connection_ids.iter().copied(),
2019        |connection_id| {
2020            session
2021                .peer
2022                .forward_send(session.connection_id, connection_id, request.clone())
2023        },
2024    );
2025    Ok(())
2026}
2027
2028/// Start following another user in a call.
2029async fn follow(
2030    request: proto::Follow,
2031    response: Response<proto::Follow>,
2032    session: Session,
2033) -> Result<()> {
2034    let room_id = RoomId::from_proto(request.room_id);
2035    let project_id = request.project_id.map(ProjectId::from_proto);
2036    let leader_id = request
2037        .leader_id
2038        .ok_or_else(|| anyhow!("invalid leader id"))?
2039        .into();
2040    let follower_id = session.connection_id;
2041
2042    session
2043        .db()
2044        .await
2045        .check_room_participants(room_id, leader_id, session.connection_id)
2046        .await?;
2047
2048    let response_payload = session
2049        .peer
2050        .forward_request(session.connection_id, leader_id, request)
2051        .await?;
2052    response.send(response_payload)?;
2053
2054    if let Some(project_id) = project_id {
2055        let room = session
2056            .db()
2057            .await
2058            .follow(room_id, project_id, leader_id, follower_id)
2059            .await?;
2060        room_updated(&room, &session.peer);
2061    }
2062
2063    Ok(())
2064}
2065
2066/// Stop following another user in a call.
2067async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2068    let room_id = RoomId::from_proto(request.room_id);
2069    let project_id = request.project_id.map(ProjectId::from_proto);
2070    let leader_id = request
2071        .leader_id
2072        .ok_or_else(|| anyhow!("invalid leader id"))?
2073        .into();
2074    let follower_id = session.connection_id;
2075
2076    session
2077        .db()
2078        .await
2079        .check_room_participants(room_id, leader_id, session.connection_id)
2080        .await?;
2081
2082    session
2083        .peer
2084        .forward_send(session.connection_id, leader_id, request)?;
2085
2086    if let Some(project_id) = project_id {
2087        let room = session
2088            .db()
2089            .await
2090            .unfollow(room_id, project_id, leader_id, follower_id)
2091            .await?;
2092        room_updated(&room, &session.peer);
2093    }
2094
2095    Ok(())
2096}
2097
2098/// Notify everyone following you of your current location.
2099async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2100    let room_id = RoomId::from_proto(request.room_id);
2101    let database = session.db.lock().await;
2102
2103    let connection_ids = if let Some(project_id) = request.project_id {
2104        let project_id = ProjectId::from_proto(project_id);
2105        database
2106            .project_connection_ids(project_id, session.connection_id)
2107            .await?
2108    } else {
2109        database
2110            .room_connection_ids(room_id, session.connection_id)
2111            .await?
2112    };
2113
2114    // For now, don't send view update messages back to that view's current leader.
2115    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2116        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2117        _ => None,
2118    });
2119
2120    for connection_id in connection_ids.iter().cloned() {
2121        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2122            session
2123                .peer
2124                .forward_send(session.connection_id, connection_id, request.clone())?;
2125        }
2126    }
2127    Ok(())
2128}
2129
2130/// Get public data about users.
2131async fn get_users(
2132    request: proto::GetUsers,
2133    response: Response<proto::GetUsers>,
2134    session: Session,
2135) -> Result<()> {
2136    let user_ids = request
2137        .user_ids
2138        .into_iter()
2139        .map(UserId::from_proto)
2140        .collect();
2141    let users = session
2142        .db()
2143        .await
2144        .get_users_by_ids(user_ids)
2145        .await?
2146        .into_iter()
2147        .map(|user| proto::User {
2148            id: user.id.to_proto(),
2149            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2150            github_login: user.github_login,
2151        })
2152        .collect();
2153    response.send(proto::UsersResponse { users })?;
2154    Ok(())
2155}
2156
2157/// Search for users (to invite) buy Github login
2158async fn fuzzy_search_users(
2159    request: proto::FuzzySearchUsers,
2160    response: Response<proto::FuzzySearchUsers>,
2161    session: Session,
2162) -> Result<()> {
2163    let query = request.query;
2164    let users = match query.len() {
2165        0 => vec![],
2166        1 | 2 => session
2167            .db()
2168            .await
2169            .get_user_by_github_login(&query)
2170            .await?
2171            .into_iter()
2172            .collect(),
2173        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2174    };
2175    let users = users
2176        .into_iter()
2177        .filter(|user| user.id != session.user_id)
2178        .map(|user| proto::User {
2179            id: user.id.to_proto(),
2180            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2181            github_login: user.github_login,
2182        })
2183        .collect();
2184    response.send(proto::UsersResponse { users })?;
2185    Ok(())
2186}
2187
2188/// Send a contact request to another user.
2189async fn request_contact(
2190    request: proto::RequestContact,
2191    response: Response<proto::RequestContact>,
2192    session: Session,
2193) -> Result<()> {
2194    let requester_id = session.user_id;
2195    let responder_id = UserId::from_proto(request.responder_id);
2196    if requester_id == responder_id {
2197        return Err(anyhow!("cannot add yourself as a contact"))?;
2198    }
2199
2200    let notifications = session
2201        .db()
2202        .await
2203        .send_contact_request(requester_id, responder_id)
2204        .await?;
2205
2206    // Update outgoing contact requests of requester
2207    let mut update = proto::UpdateContacts::default();
2208    update.outgoing_requests.push(responder_id.to_proto());
2209    for connection_id in session
2210        .connection_pool()
2211        .await
2212        .user_connection_ids(requester_id)
2213    {
2214        session.peer.send(connection_id, update.clone())?;
2215    }
2216
2217    // Update incoming contact requests of responder
2218    let mut update = proto::UpdateContacts::default();
2219    update
2220        .incoming_requests
2221        .push(proto::IncomingContactRequest {
2222            requester_id: requester_id.to_proto(),
2223        });
2224    let connection_pool = session.connection_pool().await;
2225    for connection_id in connection_pool.user_connection_ids(responder_id) {
2226        session.peer.send(connection_id, update.clone())?;
2227    }
2228
2229    send_notifications(&connection_pool, &session.peer, notifications);
2230
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235/// Accept or decline a contact request
2236async fn respond_to_contact_request(
2237    request: proto::RespondToContactRequest,
2238    response: Response<proto::RespondToContactRequest>,
2239    session: Session,
2240) -> Result<()> {
2241    let responder_id = session.user_id;
2242    let requester_id = UserId::from_proto(request.requester_id);
2243    let db = session.db().await;
2244    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2245        db.dismiss_contact_notification(responder_id, requester_id)
2246            .await?;
2247    } else {
2248        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2249
2250        let notifications = db
2251            .respond_to_contact_request(responder_id, requester_id, accept)
2252            .await?;
2253        let requester_busy = db.is_user_busy(requester_id).await?;
2254        let responder_busy = db.is_user_busy(responder_id).await?;
2255
2256        let pool = session.connection_pool().await;
2257        // Update responder with new contact
2258        let mut update = proto::UpdateContacts::default();
2259        if accept {
2260            update
2261                .contacts
2262                .push(contact_for_user(requester_id, requester_busy, &pool));
2263        }
2264        update
2265            .remove_incoming_requests
2266            .push(requester_id.to_proto());
2267        for connection_id in pool.user_connection_ids(responder_id) {
2268            session.peer.send(connection_id, update.clone())?;
2269        }
2270
2271        // Update requester with new contact
2272        let mut update = proto::UpdateContacts::default();
2273        if accept {
2274            update
2275                .contacts
2276                .push(contact_for_user(responder_id, responder_busy, &pool));
2277        }
2278        update
2279            .remove_outgoing_requests
2280            .push(responder_id.to_proto());
2281
2282        for connection_id in pool.user_connection_ids(requester_id) {
2283            session.peer.send(connection_id, update.clone())?;
2284        }
2285
2286        send_notifications(&pool, &session.peer, notifications);
2287    }
2288
2289    response.send(proto::Ack {})?;
2290    Ok(())
2291}
2292
2293/// Remove a contact.
2294async fn remove_contact(
2295    request: proto::RemoveContact,
2296    response: Response<proto::RemoveContact>,
2297    session: Session,
2298) -> Result<()> {
2299    let requester_id = session.user_id;
2300    let responder_id = UserId::from_proto(request.user_id);
2301    let db = session.db().await;
2302    let (contact_accepted, deleted_notification_id) =
2303        db.remove_contact(requester_id, responder_id).await?;
2304
2305    let pool = session.connection_pool().await;
2306    // Update outgoing contact requests of requester
2307    let mut update = proto::UpdateContacts::default();
2308    if contact_accepted {
2309        update.remove_contacts.push(responder_id.to_proto());
2310    } else {
2311        update
2312            .remove_outgoing_requests
2313            .push(responder_id.to_proto());
2314    }
2315    for connection_id in pool.user_connection_ids(requester_id) {
2316        session.peer.send(connection_id, update.clone())?;
2317    }
2318
2319    // Update incoming contact requests of responder
2320    let mut update = proto::UpdateContacts::default();
2321    if contact_accepted {
2322        update.remove_contacts.push(requester_id.to_proto());
2323    } else {
2324        update
2325            .remove_incoming_requests
2326            .push(requester_id.to_proto());
2327    }
2328    for connection_id in pool.user_connection_ids(responder_id) {
2329        session.peer.send(connection_id, update.clone())?;
2330        if let Some(notification_id) = deleted_notification_id {
2331            session.peer.send(
2332                connection_id,
2333                proto::DeleteNotification {
2334                    notification_id: notification_id.to_proto(),
2335                },
2336            )?;
2337        }
2338    }
2339
2340    response.send(proto::Ack {})?;
2341    Ok(())
2342}
2343
2344/// Creates a new channel.
2345async fn create_channel(
2346    request: proto::CreateChannel,
2347    response: Response<proto::CreateChannel>,
2348    session: Session,
2349) -> Result<()> {
2350    let db = session.db().await;
2351
2352    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2353    let (channel, owner, channel_members) = db
2354        .create_channel(&request.name, parent_id, session.user_id)
2355        .await?;
2356
2357    response.send(proto::CreateChannelResponse {
2358        channel: Some(channel.to_proto()),
2359        parent_id: request.parent_id,
2360    })?;
2361
2362    let connection_pool = session.connection_pool().await;
2363    if let Some(owner) = owner {
2364        let update = proto::UpdateUserChannels {
2365            channel_memberships: vec![proto::ChannelMembership {
2366                channel_id: owner.channel_id.to_proto(),
2367                role: owner.role.into(),
2368            }],
2369            ..Default::default()
2370        };
2371        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2372            session.peer.send(connection_id, update.clone())?;
2373        }
2374    }
2375
2376    for channel_member in channel_members {
2377        if !channel_member.role.can_see_channel(channel.visibility) {
2378            continue;
2379        }
2380
2381        let update = proto::UpdateChannels {
2382            channels: vec![channel.to_proto()],
2383            ..Default::default()
2384        };
2385        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2386            session.peer.send(connection_id, update.clone())?;
2387        }
2388    }
2389
2390    Ok(())
2391}
2392
2393/// Delete a channel
2394async fn delete_channel(
2395    request: proto::DeleteChannel,
2396    response: Response<proto::DeleteChannel>,
2397    session: Session,
2398) -> Result<()> {
2399    let db = session.db().await;
2400
2401    let channel_id = request.channel_id;
2402    let (removed_channels, member_ids) = db
2403        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2404        .await?;
2405    response.send(proto::Ack {})?;
2406
2407    // Notify members of removed channels
2408    let mut update = proto::UpdateChannels::default();
2409    update
2410        .delete_channels
2411        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2412
2413    let connection_pool = session.connection_pool().await;
2414    for member_id in member_ids {
2415        for connection_id in connection_pool.user_connection_ids(member_id) {
2416            session.peer.send(connection_id, update.clone())?;
2417        }
2418    }
2419
2420    Ok(())
2421}
2422
2423/// Invite someone to join a channel.
2424async fn invite_channel_member(
2425    request: proto::InviteChannelMember,
2426    response: Response<proto::InviteChannelMember>,
2427    session: Session,
2428) -> Result<()> {
2429    let db = session.db().await;
2430    let channel_id = ChannelId::from_proto(request.channel_id);
2431    let invitee_id = UserId::from_proto(request.user_id);
2432    let InviteMemberResult {
2433        channel,
2434        notifications,
2435    } = db
2436        .invite_channel_member(
2437            channel_id,
2438            invitee_id,
2439            session.user_id,
2440            request.role().into(),
2441        )
2442        .await?;
2443
2444    let update = proto::UpdateChannels {
2445        channel_invitations: vec![channel.to_proto()],
2446        ..Default::default()
2447    };
2448
2449    let connection_pool = session.connection_pool().await;
2450    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2451        session.peer.send(connection_id, update.clone())?;
2452    }
2453
2454    send_notifications(&connection_pool, &session.peer, notifications);
2455
2456    response.send(proto::Ack {})?;
2457    Ok(())
2458}
2459
2460/// remove someone from a channel
2461async fn remove_channel_member(
2462    request: proto::RemoveChannelMember,
2463    response: Response<proto::RemoveChannelMember>,
2464    session: Session,
2465) -> Result<()> {
2466    let db = session.db().await;
2467    let channel_id = ChannelId::from_proto(request.channel_id);
2468    let member_id = UserId::from_proto(request.user_id);
2469
2470    let RemoveChannelMemberResult {
2471        membership_update,
2472        notification_id,
2473    } = db
2474        .remove_channel_member(channel_id, member_id, session.user_id)
2475        .await?;
2476
2477    let connection_pool = &session.connection_pool().await;
2478    notify_membership_updated(
2479        &connection_pool,
2480        membership_update,
2481        member_id,
2482        &session.peer,
2483    );
2484    for connection_id in connection_pool.user_connection_ids(member_id) {
2485        if let Some(notification_id) = notification_id {
2486            session
2487                .peer
2488                .send(
2489                    connection_id,
2490                    proto::DeleteNotification {
2491                        notification_id: notification_id.to_proto(),
2492                    },
2493                )
2494                .trace_err();
2495        }
2496    }
2497
2498    response.send(proto::Ack {})?;
2499    Ok(())
2500}
2501
2502/// Toggle the channel between public and private.
2503/// Care is taken to maintain the invariant that public channels only descend from public channels,
2504/// (though members-only channels can appear at any point in the hierarchy).
2505async fn set_channel_visibility(
2506    request: proto::SetChannelVisibility,
2507    response: Response<proto::SetChannelVisibility>,
2508    session: Session,
2509) -> Result<()> {
2510    let db = session.db().await;
2511    let channel_id = ChannelId::from_proto(request.channel_id);
2512    let visibility = request.visibility().into();
2513
2514    let (channel, channel_members) = db
2515        .set_channel_visibility(channel_id, visibility, session.user_id)
2516        .await?;
2517
2518    let connection_pool = session.connection_pool().await;
2519    for member in channel_members {
2520        let update = if member.role.can_see_channel(channel.visibility) {
2521            proto::UpdateChannels {
2522                channels: vec![channel.to_proto()],
2523                ..Default::default()
2524            }
2525        } else {
2526            proto::UpdateChannels {
2527                delete_channels: vec![channel.id.to_proto()],
2528                ..Default::default()
2529            }
2530        };
2531
2532        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2533            session.peer.send(connection_id, update.clone())?;
2534        }
2535    }
2536
2537    response.send(proto::Ack {})?;
2538    Ok(())
2539}
2540
2541/// Alter the role for a user in the channel.
2542async fn set_channel_member_role(
2543    request: proto::SetChannelMemberRole,
2544    response: Response<proto::SetChannelMemberRole>,
2545    session: Session,
2546) -> Result<()> {
2547    let db = session.db().await;
2548    let channel_id = ChannelId::from_proto(request.channel_id);
2549    let member_id = UserId::from_proto(request.user_id);
2550    let result = db
2551        .set_channel_member_role(
2552            channel_id,
2553            session.user_id,
2554            member_id,
2555            request.role().into(),
2556        )
2557        .await?;
2558
2559    match result {
2560        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2561            let connection_pool = session.connection_pool().await;
2562            notify_membership_updated(
2563                &connection_pool,
2564                membership_update,
2565                member_id,
2566                &session.peer,
2567            )
2568        }
2569        db::SetMemberRoleResult::InviteUpdated(channel) => {
2570            let update = proto::UpdateChannels {
2571                channel_invitations: vec![channel.to_proto()],
2572                ..Default::default()
2573            };
2574
2575            for connection_id in session
2576                .connection_pool()
2577                .await
2578                .user_connection_ids(member_id)
2579            {
2580                session.peer.send(connection_id, update.clone())?;
2581            }
2582        }
2583    }
2584
2585    response.send(proto::Ack {})?;
2586    Ok(())
2587}
2588
2589/// Change the name of a channel
2590async fn rename_channel(
2591    request: proto::RenameChannel,
2592    response: Response<proto::RenameChannel>,
2593    session: Session,
2594) -> Result<()> {
2595    let db = session.db().await;
2596    let channel_id = ChannelId::from_proto(request.channel_id);
2597    let (channel, channel_members) = db
2598        .rename_channel(channel_id, session.user_id, &request.name)
2599        .await?;
2600
2601    response.send(proto::RenameChannelResponse {
2602        channel: Some(channel.to_proto()),
2603    })?;
2604
2605    let connection_pool = session.connection_pool().await;
2606    for channel_member in channel_members {
2607        if !channel_member.role.can_see_channel(channel.visibility) {
2608            continue;
2609        }
2610        let update = proto::UpdateChannels {
2611            channels: vec![channel.to_proto()],
2612            ..Default::default()
2613        };
2614        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2615            session.peer.send(connection_id, update.clone())?;
2616        }
2617    }
2618
2619    Ok(())
2620}
2621
2622/// Move a channel to a new parent.
2623async fn move_channel(
2624    request: proto::MoveChannel,
2625    response: Response<proto::MoveChannel>,
2626    session: Session,
2627) -> Result<()> {
2628    let channel_id = ChannelId::from_proto(request.channel_id);
2629    let to = ChannelId::from_proto(request.to);
2630
2631    let (channels, channel_members) = session
2632        .db()
2633        .await
2634        .move_channel(channel_id, to, session.user_id)
2635        .await?;
2636
2637    let connection_pool = session.connection_pool().await;
2638    for member in channel_members {
2639        let channels = channels
2640            .iter()
2641            .filter_map(|channel| {
2642                if member.role.can_see_channel(channel.visibility) {
2643                    Some(channel.to_proto())
2644                } else {
2645                    None
2646                }
2647            })
2648            .collect::<Vec<_>>();
2649        if channels.is_empty() {
2650            continue;
2651        }
2652
2653        let update = proto::UpdateChannels {
2654            channels,
2655            ..Default::default()
2656        };
2657
2658        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2659            session.peer.send(connection_id, update.clone())?;
2660        }
2661    }
2662
2663    response.send(Ack {})?;
2664    Ok(())
2665}
2666
2667/// Get the list of channel members
2668async fn get_channel_members(
2669    request: proto::GetChannelMembers,
2670    response: Response<proto::GetChannelMembers>,
2671    session: Session,
2672) -> Result<()> {
2673    let db = session.db().await;
2674    let channel_id = ChannelId::from_proto(request.channel_id);
2675    let members = db
2676        .get_channel_participant_details(channel_id, session.user_id)
2677        .await?;
2678    response.send(proto::GetChannelMembersResponse { members })?;
2679    Ok(())
2680}
2681
2682/// Accept or decline a channel invitation.
2683async fn respond_to_channel_invite(
2684    request: proto::RespondToChannelInvite,
2685    response: Response<proto::RespondToChannelInvite>,
2686    session: Session,
2687) -> Result<()> {
2688    let db = session.db().await;
2689    let channel_id = ChannelId::from_proto(request.channel_id);
2690    let RespondToChannelInvite {
2691        membership_update,
2692        notifications,
2693    } = db
2694        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2695        .await?;
2696
2697    let connection_pool = session.connection_pool().await;
2698    if let Some(membership_update) = membership_update {
2699        notify_membership_updated(
2700            &connection_pool,
2701            membership_update,
2702            session.user_id,
2703            &session.peer,
2704        );
2705    } else {
2706        let update = proto::UpdateChannels {
2707            remove_channel_invitations: vec![channel_id.to_proto()],
2708            ..Default::default()
2709        };
2710
2711        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2712            session.peer.send(connection_id, update.clone())?;
2713        }
2714    };
2715
2716    send_notifications(&connection_pool, &session.peer, notifications);
2717
2718    response.send(proto::Ack {})?;
2719
2720    Ok(())
2721}
2722
2723/// Join the channels' room
2724async fn join_channel(
2725    request: proto::JoinChannel,
2726    response: Response<proto::JoinChannel>,
2727    session: Session,
2728) -> Result<()> {
2729    let channel_id = ChannelId::from_proto(request.channel_id);
2730    join_channel_internal(channel_id, Box::new(response), session).await
2731}
2732
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either kind of
/// request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request by delegating to `Response::send` for that type.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request by delegating to `Response::send` for that type.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2746
/// Shared implementation behind channel joins: moves the session into the channel's room
/// and notifies everyone affected.
///
/// Steps, in order:
/// 1. Leave any room the session is currently in.
/// 2. Join the channel's room in the database. This also returns the user's role and,
///    if their membership changed, the membership update.
/// 3. Reply through `response` with the room, the channel id, and LiveKit connection
///    details. The LiveKit details are included only when a LiveKit client is configured
///    and a token could be generated.
/// 4. Push the membership update (if any) to the user and broadcast the room state to the
///    room's participants.
/// 5. Broadcast the channel's participant list to its members, then call
///    `update_user_contacts` for the user.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The db and connection-pool guards acquired inside this block are dropped when it
    // ends, before `channel_updated` takes the connection pool again below.
    let joined_room = {
        // NOTE(review): this runs before `session.db()` is locked on the next line.
        // Presumably `leave_room_for_session` needs the db handle itself; confirm before
        // reordering.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // LiveKit credentials are produced only when a LiveKit client is configured.
        // Guests get a guest token and `can_publish = false`. Every other role gets a room
        // token and `can_publish = true`. A token error is logged by `trace_err` and yields
        // no connection info instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Broadcast the new room state to the room's participants.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell the channel's members who is now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2822
2823/// Start editing the channel notes
2824async fn join_channel_buffer(
2825    request: proto::JoinChannelBuffer,
2826    response: Response<proto::JoinChannelBuffer>,
2827    session: Session,
2828) -> Result<()> {
2829    let db = session.db().await;
2830    let channel_id = ChannelId::from_proto(request.channel_id);
2831
2832    let open_response = db
2833        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2834        .await?;
2835
2836    let collaborators = open_response.collaborators.clone();
2837    response.send(open_response)?;
2838
2839    let update = UpdateChannelBufferCollaborators {
2840        channel_id: channel_id.to_proto(),
2841        collaborators: collaborators.clone(),
2842    };
2843    channel_buffer_updated(
2844        session.connection_id,
2845        collaborators
2846            .iter()
2847            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2848        &update,
2849        &session.peer,
2850    );
2851
2852    Ok(())
2853}
2854
2855/// Edit the channel notes
2856async fn update_channel_buffer(
2857    request: proto::UpdateChannelBuffer,
2858    session: Session,
2859) -> Result<()> {
2860    let db = session.db().await;
2861    let channel_id = ChannelId::from_proto(request.channel_id);
2862
2863    let (collaborators, non_collaborators, epoch, version) = db
2864        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2865        .await?;
2866
2867    channel_buffer_updated(
2868        session.connection_id,
2869        collaborators,
2870        &proto::UpdateChannelBuffer {
2871            channel_id: channel_id.to_proto(),
2872            operations: request.operations,
2873        },
2874        &session.peer,
2875    );
2876
2877    let pool = &*session.connection_pool().await;
2878
2879    broadcast(
2880        None,
2881        non_collaborators
2882            .iter()
2883            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2884        |peer_id| {
2885            session.peer.send(
2886                peer_id,
2887                proto::UpdateChannels {
2888                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2889                        channel_id: channel_id.to_proto(),
2890                        epoch: epoch as u64,
2891                        version: version.clone(),
2892                    }],
2893                    ..Default::default()
2894                },
2895            )
2896        },
2897    );
2898
2899    Ok(())
2900}
2901
2902/// Rejoin the channel notes after a connection blip
2903async fn rejoin_channel_buffers(
2904    request: proto::RejoinChannelBuffers,
2905    response: Response<proto::RejoinChannelBuffers>,
2906    session: Session,
2907) -> Result<()> {
2908    let db = session.db().await;
2909    let buffers = db
2910        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2911        .await?;
2912
2913    for rejoined_buffer in &buffers {
2914        let collaborators_to_notify = rejoined_buffer
2915            .buffer
2916            .collaborators
2917            .iter()
2918            .filter_map(|c| Some(c.peer_id?.into()));
2919        channel_buffer_updated(
2920            session.connection_id,
2921            collaborators_to_notify,
2922            &proto::UpdateChannelBufferCollaborators {
2923                channel_id: rejoined_buffer.buffer.channel_id,
2924                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2925            },
2926            &session.peer,
2927        );
2928    }
2929
2930    response.send(proto::RejoinChannelBuffersResponse {
2931        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2932    })?;
2933
2934    Ok(())
2935}
2936
2937/// Stop editing the channel notes
2938async fn leave_channel_buffer(
2939    request: proto::LeaveChannelBuffer,
2940    response: Response<proto::LeaveChannelBuffer>,
2941    session: Session,
2942) -> Result<()> {
2943    let db = session.db().await;
2944    let channel_id = ChannelId::from_proto(request.channel_id);
2945
2946    let left_buffer = db
2947        .leave_channel_buffer(channel_id, session.connection_id)
2948        .await?;
2949
2950    response.send(Ack {})?;
2951
2952    channel_buffer_updated(
2953        session.connection_id,
2954        left_buffer.connections,
2955        &proto::UpdateChannelBufferCollaborators {
2956            channel_id: channel_id.to_proto(),
2957            collaborators: left_buffer.collaborators,
2958        },
2959        &session.peer,
2960    );
2961
2962    Ok(())
2963}
2964
2965fn channel_buffer_updated<T: EnvelopedMessage>(
2966    sender_id: ConnectionId,
2967    collaborators: impl IntoIterator<Item = ConnectionId>,
2968    message: &T,
2969    peer: &Peer,
2970) {
2971    broadcast(Some(sender_id), collaborators, |peer_id| {
2972        peer.send(peer_id, message.clone())
2973    });
2974}
2975
2976fn send_notifications(
2977    connection_pool: &ConnectionPool,
2978    peer: &Peer,
2979    notifications: db::NotificationBatch,
2980) {
2981    for (user_id, notification) in notifications {
2982        for connection_id in connection_pool.user_connection_ids(user_id) {
2983            if let Err(error) = peer.send(
2984                connection_id,
2985                proto::AddNotification {
2986                    notification: Some(notification.clone()),
2987                },
2988            ) {
2989                tracing::error!(
2990                    "failed to send notification to {:?} {}",
2991                    connection_id,
2992                    error
2993                );
2994            }
2995        }
2996    }
2997}
2998
2999/// Send a message to the channel
3000async fn send_channel_message(
3001    request: proto::SendChannelMessage,
3002    response: Response<proto::SendChannelMessage>,
3003    session: Session,
3004) -> Result<()> {
3005    // Validate the message body.
3006    let body = request.body.trim().to_string();
3007    if body.len() > MAX_MESSAGE_LEN {
3008        return Err(anyhow!("message is too long"))?;
3009    }
3010    if body.is_empty() {
3011        return Err(anyhow!("message can't be blank"))?;
3012    }
3013
3014    // TODO: adjust mentions if body is trimmed
3015
3016    let timestamp = OffsetDateTime::now_utc();
3017    let nonce = request
3018        .nonce
3019        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3020
3021    let channel_id = ChannelId::from_proto(request.channel_id);
3022    let CreatedChannelMessage {
3023        message_id,
3024        participant_connection_ids,
3025        channel_members,
3026        notifications,
3027    } = session
3028        .db()
3029        .await
3030        .create_channel_message(
3031            channel_id,
3032            session.user_id,
3033            &body,
3034            &request.mentions,
3035            timestamp,
3036            nonce.clone().into(),
3037            match request.reply_to_message_id {
3038                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3039                None => None,
3040            },
3041        )
3042        .await?;
3043    let message = proto::ChannelMessage {
3044        sender_id: session.user_id.to_proto(),
3045        id: message_id.to_proto(),
3046        body,
3047        mentions: request.mentions,
3048        timestamp: timestamp.unix_timestamp() as u64,
3049        nonce: Some(nonce),
3050        reply_to_message_id: request.reply_to_message_id,
3051    };
3052    broadcast(
3053        Some(session.connection_id),
3054        participant_connection_ids,
3055        |connection| {
3056            session.peer.send(
3057                connection,
3058                proto::ChannelMessageSent {
3059                    channel_id: channel_id.to_proto(),
3060                    message: Some(message.clone()),
3061                },
3062            )
3063        },
3064    );
3065    response.send(proto::SendChannelMessageResponse {
3066        message: Some(message),
3067    })?;
3068
3069    let pool = &*session.connection_pool().await;
3070    broadcast(
3071        None,
3072        channel_members
3073            .iter()
3074            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3075        |peer_id| {
3076            session.peer.send(
3077                peer_id,
3078                proto::UpdateChannels {
3079                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3080                        channel_id: channel_id.to_proto(),
3081                        message_id: message_id.to_proto(),
3082                    }],
3083                    ..Default::default()
3084                },
3085            )
3086        },
3087    );
3088    send_notifications(pool, &session.peer, notifications);
3089
3090    Ok(())
3091}
3092
3093/// Delete a channel message
3094async fn remove_channel_message(
3095    request: proto::RemoveChannelMessage,
3096    response: Response<proto::RemoveChannelMessage>,
3097    session: Session,
3098) -> Result<()> {
3099    let channel_id = ChannelId::from_proto(request.channel_id);
3100    let message_id = MessageId::from_proto(request.message_id);
3101    let connection_ids = session
3102        .db()
3103        .await
3104        .remove_channel_message(channel_id, message_id, session.user_id)
3105        .await?;
3106    broadcast(Some(session.connection_id), connection_ids, |connection| {
3107        session.peer.send(connection, request.clone())
3108    });
3109    response.send(proto::Ack {})?;
3110    Ok(())
3111}
3112
3113/// Mark a channel message as read
3114async fn acknowledge_channel_message(
3115    request: proto::AckChannelMessage,
3116    session: Session,
3117) -> Result<()> {
3118    let channel_id = ChannelId::from_proto(request.channel_id);
3119    let message_id = MessageId::from_proto(request.message_id);
3120    let notifications = session
3121        .db()
3122        .await
3123        .observe_channel_message(channel_id, session.user_id, message_id)
3124        .await?;
3125    send_notifications(
3126        &*session.connection_pool().await,
3127        &session.peer,
3128        notifications,
3129    );
3130    Ok(())
3131}
3132
3133/// Mark a buffer version as synced
3134async fn acknowledge_buffer_version(
3135    request: proto::AckBufferOperation,
3136    session: Session,
3137) -> Result<()> {
3138    let buffer_id = BufferId::from_proto(request.buffer_id);
3139    session
3140        .db()
3141        .await
3142        .observe_buffer_version(
3143            buffer_id,
3144            session.user_id,
3145            request.epoch as i32,
3146            &request.version,
3147        )
3148        .await?;
3149    Ok(())
3150}
3151
3152/// Start receiving chat updates for a channel
3153async fn join_channel_chat(
3154    request: proto::JoinChannelChat,
3155    response: Response<proto::JoinChannelChat>,
3156    session: Session,
3157) -> Result<()> {
3158    let channel_id = ChannelId::from_proto(request.channel_id);
3159
3160    let db = session.db().await;
3161    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3162        .await?;
3163    let messages = db
3164        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3165        .await?;
3166    response.send(proto::JoinChannelChatResponse {
3167        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3168        messages,
3169    })?;
3170    Ok(())
3171}
3172
3173/// Stop receiving chat updates for a channel
3174async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3175    let channel_id = ChannelId::from_proto(request.channel_id);
3176    session
3177        .db()
3178        .await
3179        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3180        .await?;
3181    Ok(())
3182}
3183
3184/// Retrieve the chat history for a channel
3185async fn get_channel_messages(
3186    request: proto::GetChannelMessages,
3187    response: Response<proto::GetChannelMessages>,
3188    session: Session,
3189) -> Result<()> {
3190    let channel_id = ChannelId::from_proto(request.channel_id);
3191    let messages = session
3192        .db()
3193        .await
3194        .get_channel_messages(
3195            channel_id,
3196            session.user_id,
3197            MESSAGE_COUNT_PER_PAGE,
3198            Some(MessageId::from_proto(request.before_message_id)),
3199        )
3200        .await?;
3201    response.send(proto::GetChannelMessagesResponse {
3202        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3203        messages,
3204    })?;
3205    Ok(())
3206}
3207
3208/// Retrieve specific chat messages
3209async fn get_channel_messages_by_id(
3210    request: proto::GetChannelMessagesById,
3211    response: Response<proto::GetChannelMessagesById>,
3212    session: Session,
3213) -> Result<()> {
3214    let message_ids = request
3215        .message_ids
3216        .iter()
3217        .map(|id| MessageId::from_proto(*id))
3218        .collect::<Vec<_>>();
3219    let messages = session
3220        .db()
3221        .await
3222        .get_channel_messages_by_id(session.user_id, &message_ids)
3223        .await?;
3224    response.send(proto::GetChannelMessagesResponse {
3225        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3226        messages,
3227    })?;
3228    Ok(())
3229}
3230
3231/// Retrieve the current users notifications
3232async fn get_notifications(
3233    request: proto::GetNotifications,
3234    response: Response<proto::GetNotifications>,
3235    session: Session,
3236) -> Result<()> {
3237    let notifications = session
3238        .db()
3239        .await
3240        .get_notifications(
3241            session.user_id,
3242            NOTIFICATION_COUNT_PER_PAGE,
3243            request
3244                .before_id
3245                .map(|id| db::NotificationId::from_proto(id)),
3246        )
3247        .await?;
3248    response.send(proto::GetNotificationsResponse {
3249        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3250        notifications,
3251    })?;
3252    Ok(())
3253}
3254
3255/// Mark notifications as read
3256async fn mark_notification_as_read(
3257    request: proto::MarkNotificationRead,
3258    response: Response<proto::MarkNotificationRead>,
3259    session: Session,
3260) -> Result<()> {
3261    let database = &session.db().await;
3262    let notifications = database
3263        .mark_notification_as_read_by_id(
3264            session.user_id,
3265            NotificationId::from_proto(request.notification_id),
3266        )
3267        .await?;
3268    send_notifications(
3269        &*session.connection_pool().await,
3270        &session.peer,
3271        notifications,
3272    );
3273    response.send(proto::Ack {})?;
3274    Ok(())
3275}
3276
3277/// Get the current users information
3278async fn get_private_user_info(
3279    _request: proto::GetPrivateUserInfo,
3280    response: Response<proto::GetPrivateUserInfo>,
3281    session: Session,
3282) -> Result<()> {
3283    let db = session.db().await;
3284
3285    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3286    let user = db
3287        .get_user_by_id(session.user_id)
3288        .await?
3289        .ok_or_else(|| anyhow!("user not found"))?;
3290    let flags = db.get_user_flags(session.user_id).await?;
3291
3292    response.send(proto::GetPrivateUserInfoResponse {
3293        metrics_id,
3294        staff: user.admin,
3295        flags,
3296    })?;
3297    Ok(())
3298}
3299
3300fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3301    match message {
3302        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3303        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3304        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3305        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3306        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3307            code: frame.code.into(),
3308            reason: frame.reason,
3309        })),
3310    }
3311}
3312
3313fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3314    match message {
3315        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3316        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3317        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3318        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3319        AxumMessage::Close(frame) => {
3320            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3321                code: frame.code.into(),
3322                reason: frame.reason,
3323            }))
3324        }
3325    }
3326}
3327
3328fn notify_membership_updated(
3329    connection_pool: &ConnectionPool,
3330    result: MembershipUpdated,
3331    user_id: UserId,
3332    peer: &Peer,
3333) {
3334    let user_channels_update = proto::UpdateUserChannels {
3335        channel_memberships: result
3336            .new_channels
3337            .channel_memberships
3338            .iter()
3339            .map(|cm| proto::ChannelMembership {
3340                channel_id: cm.channel_id.to_proto(),
3341                role: cm.role.into(),
3342            })
3343            .collect(),
3344        ..Default::default()
3345    };
3346    let mut update = build_channels_update(result.new_channels, vec![]);
3347    update.delete_channels = result
3348        .removed_channels
3349        .into_iter()
3350        .map(|id| id.to_proto())
3351        .collect();
3352    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3353
3354    for connection_id in connection_pool.user_connection_ids(user_id) {
3355        peer.send(connection_id, user_channels_update.clone())
3356            .trace_err();
3357        peer.send(connection_id, update.clone()).trace_err();
3358    }
3359}
3360
3361fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3362    proto::UpdateUserChannels {
3363        channel_memberships: channels
3364            .channel_memberships
3365            .iter()
3366            .map(|m| proto::ChannelMembership {
3367                channel_id: m.channel_id.to_proto(),
3368                role: m.role.into(),
3369            })
3370            .collect(),
3371        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3372        observed_channel_message_id: channels.observed_channel_messages.clone(),
3373        ..Default::default()
3374    }
3375}
3376
3377fn build_channels_update(
3378    channels: ChannelsForUser,
3379    channel_invites: Vec<db::Channel>,
3380) -> proto::UpdateChannels {
3381    let mut update = proto::UpdateChannels::default();
3382
3383    for channel in channels.channels {
3384        update.channels.push(channel.to_proto());
3385    }
3386
3387    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3388    update.latest_channel_message_ids = channels.latest_channel_messages;
3389
3390    for (channel_id, participants) in channels.channel_participants {
3391        update
3392            .channel_participants
3393            .push(proto::ChannelParticipants {
3394                channel_id: channel_id.to_proto(),
3395                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3396            });
3397    }
3398
3399    for channel in channel_invites {
3400        update.channel_invitations.push(channel.to_proto());
3401    }
3402    for project in channels.hosted_projects {
3403        update.hosted_projects.push(project);
3404    }
3405
3406    update
3407}
3408
3409fn build_initial_contacts_update(
3410    contacts: Vec<db::Contact>,
3411    pool: &ConnectionPool,
3412) -> proto::UpdateContacts {
3413    let mut update = proto::UpdateContacts::default();
3414
3415    for contact in contacts {
3416        match contact {
3417            db::Contact::Accepted { user_id, busy } => {
3418                update.contacts.push(contact_for_user(user_id, busy, &pool));
3419            }
3420            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3421            db::Contact::Incoming { user_id } => {
3422                update
3423                    .incoming_requests
3424                    .push(proto::IncomingContactRequest {
3425                        requester_id: user_id.to_proto(),
3426                    })
3427            }
3428        }
3429    }
3430
3431    update
3432}
3433
3434fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3435    proto::Contact {
3436        user_id: user_id.to_proto(),
3437        online: pool.is_user_online(user_id),
3438        busy,
3439    }
3440}
3441
3442fn room_updated(room: &proto::Room, peer: &Peer) {
3443    broadcast(
3444        None,
3445        room.participants
3446            .iter()
3447            .filter_map(|participant| Some(participant.peer_id?.into())),
3448        |peer_id| {
3449            peer.send(
3450                peer_id,
3451                proto::RoomUpdated {
3452                    room: Some(room.clone()),
3453                },
3454            )
3455        },
3456    );
3457}
3458
3459fn channel_updated(
3460    channel_id: ChannelId,
3461    room: &proto::Room,
3462    channel_members: &[UserId],
3463    peer: &Peer,
3464    pool: &ConnectionPool,
3465) {
3466    let participants = room
3467        .participants
3468        .iter()
3469        .map(|p| p.user_id)
3470        .collect::<Vec<_>>();
3471
3472    broadcast(
3473        None,
3474        channel_members
3475            .iter()
3476            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3477        |peer_id| {
3478            peer.send(
3479                peer_id,
3480                proto::UpdateChannels {
3481                    channel_participants: vec![proto::ChannelParticipants {
3482                        channel_id: channel_id.to_proto(),
3483                        participant_user_ids: participants.clone(),
3484                    }],
3485                    ..Default::default()
3486                },
3487            )
3488        },
3489    );
3490}
3491
3492async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3493    let db = session.db().await;
3494
3495    let contacts = db.get_contacts(user_id).await?;
3496    let busy = db.is_user_busy(user_id).await?;
3497
3498    let pool = session.connection_pool().await;
3499    let updated_contact = contact_for_user(user_id, busy, &pool);
3500    for contact in contacts {
3501        if let db::Contact::Accepted {
3502            user_id: contact_user_id,
3503            ..
3504        } = contact
3505        {
3506            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3507                session
3508                    .peer
3509                    .send(
3510                        contact_conn_id,
3511                        proto::UpdateContacts {
3512                            contacts: vec![updated_contact.clone()],
3513                            remove_contacts: Default::default(),
3514                            incoming_requests: Default::default(),
3515                            remove_incoming_requests: Default::default(),
3516                            outgoing_requests: Default::default(),
3517                            remove_outgoing_requests: Default::default(),
3518                        },
3519                    )
3520                    .trace_err();
3521            }
3522        }
3523    }
3524    Ok(())
3525}
3526
3527async fn leave_room_for_session(session: &Session) -> Result<()> {
3528    let mut contacts_to_update = HashSet::default();
3529
3530    let room_id;
3531    let canceled_calls_to_user_ids;
3532    let live_kit_room;
3533    let delete_live_kit_room;
3534    let room;
3535    let channel_members;
3536    let channel_id;
3537
3538    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3539        contacts_to_update.insert(session.user_id);
3540
3541        for project in left_room.left_projects.values() {
3542            project_left(project, session);
3543        }
3544
3545        room_id = RoomId::from_proto(left_room.room.id);
3546        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3547        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3548        delete_live_kit_room = left_room.deleted;
3549        room = mem::take(&mut left_room.room);
3550        channel_members = mem::take(&mut left_room.channel_members);
3551        channel_id = left_room.channel_id;
3552
3553        room_updated(&room, &session.peer);
3554    } else {
3555        return Ok(());
3556    }
3557
3558    if let Some(channel_id) = channel_id {
3559        channel_updated(
3560            channel_id,
3561            &room,
3562            &channel_members,
3563            &session.peer,
3564            &*session.connection_pool().await,
3565        );
3566    }
3567
3568    {
3569        let pool = session.connection_pool().await;
3570        for canceled_user_id in canceled_calls_to_user_ids {
3571            for connection_id in pool.user_connection_ids(canceled_user_id) {
3572                session
3573                    .peer
3574                    .send(
3575                        connection_id,
3576                        proto::CallCanceled {
3577                            room_id: room_id.to_proto(),
3578                        },
3579                    )
3580                    .trace_err();
3581            }
3582            contacts_to_update.insert(canceled_user_id);
3583        }
3584    }
3585
3586    for contact_user_id in contacts_to_update {
3587        update_user_contacts(contact_user_id, &session).await?;
3588    }
3589
3590    if let Some(live_kit) = session.live_kit_client.as_ref() {
3591        live_kit
3592            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3593            .await
3594            .trace_err();
3595
3596        if delete_live_kit_room {
3597            live_kit.delete_room(live_kit_room).await.trace_err();
3598        }
3599    }
3600
3601    Ok(())
3602}
3603
3604async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3605    let left_channel_buffers = session
3606        .db()
3607        .await
3608        .leave_channel_buffers(session.connection_id)
3609        .await?;
3610
3611    for left_buffer in left_channel_buffers {
3612        channel_buffer_updated(
3613            session.connection_id,
3614            left_buffer.connections,
3615            &proto::UpdateChannelBufferCollaborators {
3616                channel_id: left_buffer.channel_id.to_proto(),
3617                collaborators: left_buffer.collaborators,
3618            },
3619            &session.peer,
3620        );
3621    }
3622
3623    Ok(())
3624}
3625
3626fn project_left(project: &db::LeftProject, session: &Session) {
3627    for connection_id in &project.connection_ids {
3628        if project.host_user_id == session.user_id {
3629            session
3630                .peer
3631                .send(
3632                    *connection_id,
3633                    proto::UnshareProject {
3634                        project_id: project.id.to_proto(),
3635                    },
3636                )
3637                .trace_err();
3638        } else {
3639            session
3640                .peer
3641                .send(
3642                    *connection_id,
3643                    proto::RemoveProjectCollaborator {
3644                        project_id: project.id.to_proto(),
3645                        peer_id: Some(session.connection_id.into()),
3646                    },
3647                )
3648                .trace_err();
3649        }
3650    }
3651}
3652
3653pub trait ResultExt {
3654    type Ok;
3655
3656    fn trace_err(self) -> Option<Self::Ok>;
3657}
3658
3659impl<T, E> ResultExt for Result<T, E>
3660where
3661    E: std::fmt::Debug,
3662{
3663    type Ok = T;
3664
3665    fn trace_err(self) -> Option<T> {
3666        match self {
3667            Ok(value) => Some(value),
3668            Err(error) => {
3669                tracing::error!("{:?}", error);
3670                None
3671            }
3672        }
3673    }
3674}