rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage,
   7        Database, MessageId, NotificationId, ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
  75lazy_static! {
  76    static ref METRIC_CONNECTIONS: IntGauge =
  77        register_int_gauge!("connections", "number of connections").unwrap();
  78    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  79        "shared_projects",
  80        "number of open projects with one or more guests"
  81    )
  82    .unwrap();
  83}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
  88struct Response<R> {
  89    peer: Arc<Peer>,
  90    receipt: Receipt<R>,
  91    responded: Arc<AtomicBool>,
  92}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
 102#[derive(Clone)]
 103struct Session {
 104    user_id: UserId,
 105    connection_id: ConnectionId,
 106    db: Arc<tokio::sync::Mutex<DbHandle>>,
 107    peer: Arc<Peer>,
 108    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 109    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 110    executor: Executor,
 111}
 112
 113impl Session {
 114    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.db.lock().await;
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        guard
 121    }
 122
 123    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 124        #[cfg(test)]
 125        tokio::task::yield_now().await;
 126        let guard = self.connection_pool.lock();
 127        ConnectionPoolGuard {
 128            guard,
 129            _not_send: PhantomData,
 130        }
 131    }
 132}
 133
 134impl fmt::Debug for Session {
 135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 136        f.debug_struct("Session")
 137            .field("user_id", &self.user_id)
 138            .field("connection_id", &self.connection_id)
 139            .finish()
 140    }
 141}
 142
 143struct DbHandle(Arc<Database>);
 144
 145impl Deref for DbHandle {
 146    type Target = Database;
 147
 148    fn deref(&self) -> &Self::Target {
 149        self.0.as_ref()
 150    }
 151}
 152
/// The collaboration RPC server: owns the RPC peer, the shared connection pool, and the
/// table of registered message handlers.
pub struct Server {
    /// Server id; behind a mutex so the test-only `reset` can assign a new one.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send messages to connected clients.
    peer: Arc<Peer>,
    /// Connections currently registered with this server.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// One type-erased handler per message type, keyed by the payload's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; every `handle_connection` loop subscribes to it.
    teardown: watch::Sender<()>,
}
 162
/// Lock guard over the shared [`ConnectionPool`], handed out by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`;
    // presumably so the compiler rejects holding it across an `.await` in a future
    // that has to be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 167
/// Serializable view of the server's live state: the peer plus the connection pool.
///
/// The snapshot holds the connection pool's lock guard, so the pool stays locked until
/// the snapshot is dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself); the `Deref`
    // impl for `ConnectionPoolGuard` is not visible in this part of the file.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 174
 175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 176where
 177    S: Serializer,
 178    T: Deref<Target = U>,
 179    U: Serialize,
 180{
 181    Serialize::serialize(value.deref(), serializer)
 182}
 183
 184impl Server {
 185    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 186        let mut server = Self {
 187            id: parking_lot::Mutex::new(id),
 188            peer: Peer::new(id.0 as u32),
 189            app_state,
 190            executor,
 191            connection_pool: Default::default(),
 192            handlers: Default::default(),
 193            teardown: watch::channel(()).0,
 194        };
 195
 196        server
 197            .add_request_handler(ping)
 198            .add_request_handler(create_room)
 199            .add_request_handler(join_room)
 200            .add_request_handler(rejoin_room)
 201            .add_request_handler(leave_room)
 202            .add_request_handler(call)
 203            .add_request_handler(cancel_call)
 204            .add_message_handler(decline_call)
 205            .add_request_handler(update_participant_location)
 206            .add_request_handler(share_project)
 207            .add_message_handler(unshare_project)
 208            .add_request_handler(join_project)
 209            .add_message_handler(leave_project)
 210            .add_request_handler(update_project)
 211            .add_request_handler(update_worktree)
 212            .add_message_handler(start_language_server)
 213            .add_message_handler(update_language_server)
 214            .add_message_handler(update_diagnostic_summary)
 215            .add_message_handler(update_worktree_settings)
 216            .add_message_handler(refresh_inlay_hints)
 217            .add_request_handler(forward_project_request::<proto::GetHover>)
 218            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 220            .add_request_handler(forward_project_request::<proto::GetReferences>)
 221            .add_request_handler(forward_project_request::<proto::SearchProject>)
 222            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 223            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 226            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 227            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 228            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 229            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 230            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 231            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 232            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 233            .add_request_handler(forward_project_request::<proto::PerformRename>)
 234            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 235            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 236            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 237            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 242            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 243            .add_request_handler(forward_project_request::<proto::InlayHints>)
 244            .add_message_handler(create_buffer_for_peer)
 245            .add_request_handler(update_buffer)
 246            .add_message_handler(update_buffer_file)
 247            .add_message_handler(buffer_reloaded)
 248            .add_message_handler(buffer_saved)
 249            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 250            .add_request_handler(get_users)
 251            .add_request_handler(fuzzy_search_users)
 252            .add_request_handler(request_contact)
 253            .add_request_handler(remove_contact)
 254            .add_request_handler(respond_to_contact_request)
 255            .add_request_handler(create_channel)
 256            .add_request_handler(delete_channel)
 257            .add_request_handler(invite_channel_member)
 258            .add_request_handler(remove_channel_member)
 259            .add_request_handler(set_channel_member_role)
 260            .add_request_handler(set_channel_visibility)
 261            .add_request_handler(rename_channel)
 262            .add_request_handler(join_channel_buffer)
 263            .add_request_handler(leave_channel_buffer)
 264            .add_message_handler(update_channel_buffer)
 265            .add_request_handler(rejoin_channel_buffers)
 266            .add_request_handler(get_channel_members)
 267            .add_request_handler(respond_to_channel_invite)
 268            .add_request_handler(join_channel)
 269            .add_request_handler(join_channel_chat)
 270            .add_message_handler(leave_channel_chat)
 271            .add_request_handler(send_channel_message)
 272            .add_request_handler(remove_channel_message)
 273            .add_request_handler(get_channel_messages)
 274            .add_request_handler(get_channel_messages_by_id)
 275            .add_request_handler(get_notifications)
 276            .add_request_handler(mark_notification_as_read)
 277            .add_request_handler(link_channel)
 278            .add_request_handler(unlink_channel)
 279            .add_request_handler(move_channel)
 280            .add_request_handler(follow)
 281            .add_message_handler(unfollow)
 282            .add_message_handler(update_followers)
 283            .add_message_handler(update_diff_base)
 284            .add_request_handler(get_private_user_info)
 285            .add_message_handler(acknowledge_channel_message)
 286            .add_message_handler(acknowledge_buffer_version);
 287
 288        Arc::new(server)
 289    }
 290
 291    pub async fn start(&self) -> Result<()> {
 292        let server_id = *self.id.lock();
 293        let app_state = self.app_state.clone();
 294        let peer = self.peer.clone();
 295        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 296        let pool = self.connection_pool.clone();
 297        let live_kit_client = self.app_state.live_kit_client.clone();
 298
 299        let span = info_span!("start server");
 300        self.executor.spawn_detached(
 301            async move {
 302                tracing::info!("waiting for cleanup timeout");
 303                timeout.await;
 304                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 305                if let Some((room_ids, channel_ids)) = app_state
 306                    .db
 307                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 308                    .await
 309                    .trace_err()
 310                {
 311                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 312                    tracing::info!(
 313                        stale_channel_buffer_count = channel_ids.len(),
 314                        "retrieved stale channel buffers"
 315                    );
 316
 317                    for channel_id in channel_ids {
 318                        if let Some(refreshed_channel_buffer) = app_state
 319                            .db
 320                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 321                            .await
 322                            .trace_err()
 323                        {
 324                            for connection_id in refreshed_channel_buffer.connection_ids {
 325                                peer.send(
 326                                    connection_id,
 327                                    proto::UpdateChannelBufferCollaborators {
 328                                        channel_id: channel_id.to_proto(),
 329                                        collaborators: refreshed_channel_buffer
 330                                            .collaborators
 331                                            .clone(),
 332                                    },
 333                                )
 334                                .trace_err();
 335                            }
 336                        }
 337                    }
 338
 339                    for room_id in room_ids {
 340                        let mut contacts_to_update = HashSet::default();
 341                        let mut canceled_calls_to_user_ids = Vec::new();
 342                        let mut live_kit_room = String::new();
 343                        let mut delete_live_kit_room = false;
 344
 345                        if let Some(mut refreshed_room) = app_state
 346                            .db
 347                            .clear_stale_room_participants(room_id, server_id)
 348                            .await
 349                            .trace_err()
 350                        {
 351                            tracing::info!(
 352                                room_id = room_id.0,
 353                                new_participant_count = refreshed_room.room.participants.len(),
 354                                "refreshed room"
 355                            );
 356                            room_updated(&refreshed_room.room, &peer);
 357                            if let Some(channel_id) = refreshed_room.channel_id {
 358                                channel_updated(
 359                                    channel_id,
 360                                    &refreshed_room.room,
 361                                    &refreshed_room.channel_members,
 362                                    &peer,
 363                                    &*pool.lock(),
 364                                );
 365                            }
 366                            contacts_to_update
 367                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 368                            contacts_to_update
 369                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 370                            canceled_calls_to_user_ids =
 371                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 372                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 373                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 374                        }
 375
 376                        {
 377                            let pool = pool.lock();
 378                            for canceled_user_id in canceled_calls_to_user_ids {
 379                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 380                                    peer.send(
 381                                        connection_id,
 382                                        proto::CallCanceled {
 383                                            room_id: room_id.to_proto(),
 384                                        },
 385                                    )
 386                                    .trace_err();
 387                                }
 388                            }
 389                        }
 390
 391                        for user_id in contacts_to_update {
 392                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 393                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 394                            if let Some((busy, contacts)) = busy.zip(contacts) {
 395                                let pool = pool.lock();
 396                                let updated_contact = contact_for_user(user_id, busy, &pool);
 397                                for contact in contacts {
 398                                    if let db::Contact::Accepted {
 399                                        user_id: contact_user_id,
 400                                        ..
 401                                    } = contact
 402                                    {
 403                                        for contact_conn_id in
 404                                            pool.user_connection_ids(contact_user_id)
 405                                        {
 406                                            peer.send(
 407                                                contact_conn_id,
 408                                                proto::UpdateContacts {
 409                                                    contacts: vec![updated_contact.clone()],
 410                                                    remove_contacts: Default::default(),
 411                                                    incoming_requests: Default::default(),
 412                                                    remove_incoming_requests: Default::default(),
 413                                                    outgoing_requests: Default::default(),
 414                                                    remove_outgoing_requests: Default::default(),
 415                                                },
 416                                            )
 417                                            .trace_err();
 418                                        }
 419                                    }
 420                                }
 421                            }
 422                        }
 423
 424                        if let Some(live_kit) = live_kit_client.as_ref() {
 425                            if delete_live_kit_room {
 426                                live_kit.delete_room(live_kit_room).await.trace_err();
 427                            }
 428                        }
 429                    }
 430                }
 431
 432                app_state
 433                    .db
 434                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 435                    .await
 436                    .trace_err();
 437            }
 438            .instrument(span),
 439        );
 440        Ok(())
 441    }
 442
 443    pub fn teardown(&self) {
 444        self.peer.teardown();
 445        self.connection_pool.lock().reset();
 446        let _ = self.teardown.send(());
 447    }
 448
 449    #[cfg(test)]
 450    pub fn reset(&self, id: ServerId) {
 451        self.teardown();
 452        *self.id.lock() = id;
 453        self.peer.reset(id.0 as u32);
 454    }
 455
 456    #[cfg(test)]
 457    pub fn id(&self) -> ServerId {
 458        *self.id.lock()
 459    }
 460
 461    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 462    where
 463        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 464        Fut: 'static + Send + Future<Output = Result<()>>,
 465        M: EnvelopedMessage,
 466    {
 467        let prev_handler = self.handlers.insert(
 468            TypeId::of::<M>(),
 469            Box::new(move |envelope, session| {
 470                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 471                let span = info_span!(
 472                    "handle message",
 473                    payload_type = envelope.payload_type_name()
 474                );
 475                span.in_scope(|| {
 476                    tracing::info!(
 477                        payload_type = envelope.payload_type_name(),
 478                        "message received"
 479                    );
 480                });
 481                let start_time = Instant::now();
 482                let future = (handler)(*envelope, session);
 483                async move {
 484                    let result = future.await;
 485                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 486                    match result {
 487                        Err(error) => {
 488                            tracing::error!(%error, ?duration_ms, "error handling message")
 489                        }
 490                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 491                    }
 492                }
 493                .instrument(span)
 494                .boxed()
 495            }),
 496        );
 497        if prev_handler.is_some() {
 498            panic!("registered a handler for the same message twice");
 499        }
 500        self
 501    }
 502
 503    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 504    where
 505        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 506        Fut: 'static + Send + Future<Output = Result<()>>,
 507        M: EnvelopedMessage,
 508    {
 509        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 510        self
 511    }
 512
 513    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 514    where
 515        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 516        Fut: Send + Future<Output = Result<()>>,
 517        M: RequestMessage,
 518    {
 519        let handler = Arc::new(handler);
 520        self.add_handler(move |envelope, session| {
 521            let receipt = envelope.receipt();
 522            let handler = handler.clone();
 523            async move {
 524                let peer = session.peer.clone();
 525                let responded = Arc::new(AtomicBool::default());
 526                let response = Response {
 527                    peer: peer.clone(),
 528                    responded: responded.clone(),
 529                    receipt,
 530                };
 531                match (handler)(envelope.payload, response, session).await {
 532                    Ok(()) => {
 533                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 534                            Ok(())
 535                        } else {
 536                            Err(anyhow!("handler did not send a response"))?
 537                        }
 538                    }
 539                    Err(error) => {
 540                        peer.respond_with_error(
 541                            receipt,
 542                            proto::Error {
 543                                message: error.to_string(),
 544                            },
 545                        )?;
 546                        Err(error)
 547                    }
 548                }
 549            }
 550        })
 551    }
 552
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future (instrumented with a "handle connection" span):
    /// 1. registers `connection` with the peer, sends `Hello`, and reports the new
    ///    `ConnectionId` through `send_connection_id` if one was supplied;
    /// 2. if the user has never connected before, sends `ShowContacts` and records
    ///    (via `set_user_connected_once`) that the user has now connected;
    /// 3. adds the connection to the pool and sends the initial contacts and channels
    ///    updates, followed by any pending incoming call;
    /// 4. builds the connection's [`Session`] and calls `update_user_contacts`;
    /// 5. dispatches incoming messages to registered handlers until the I/O future
    ///    finishes, the incoming stream ends, or the server is torn down;
    /// 6. unless the server was torn down, calls `connection_lost` to sign the user out.
    ///
    /// # Errors
    ///
    /// Returns an error if any step before the message loop fails (a send or a database
    /// call). Errors from the I/O future and from `connection_lost` are only logged.
    ///
    /// NOTE(review): the early `?` returns in steps 1–4 happen after the connection was
    /// added to the peer (and, from step 3 on, to the pool) but skip `connection_lost`,
    /// so the connection presumably stays registered until something else removes it —
    /// confirm whether cleanup is needed on those paths.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future starts so a later `Server::teardown` is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three queries run concurrently; the first failure aborts the connection.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // The pool stays locked while the initial updates are built and sent; the
            // scope ends before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // Each session gets its own mutex around the shared database handle.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: a permit is taken
            // before the next message is read and released when that message's handler
            // finishes (or immediately, if the message has no handler).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, then I/O
                // completion, then in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    // Server teardown: return without `connection_lost`; `Server::teardown`
                    // resets the whole connection pool itself.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    // Drive in-flight foreground handlers; completions need no follow-up.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            // Entered only while the handler future is created; the future
                            // itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit lives as long as the handler runs.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Abandon foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 687
 688    pub async fn invite_code_redeemed(
 689        self: &Arc<Self>,
 690        inviter_id: UserId,
 691        invitee_id: UserId,
 692    ) -> Result<()> {
 693        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 694            if let Some(code) = &user.invite_code {
 695                let pool = self.connection_pool.lock();
 696                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 697                for connection_id in pool.user_connection_ids(inviter_id) {
 698                    self.peer.send(
 699                        connection_id,
 700                        proto::UpdateContacts {
 701                            contacts: vec![invitee_contact.clone()],
 702                            ..Default::default()
 703                        },
 704                    )?;
 705                    self.peer.send(
 706                        connection_id,
 707                        proto::UpdateInviteInfo {
 708                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 709                            count: user.invite_count as u32,
 710                        },
 711                    )?;
 712                }
 713            }
 714        }
 715        Ok(())
 716    }
 717
 718    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 719        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 720            if let Some(invite_code) = &user.invite_code {
 721                let pool = self.connection_pool.lock();
 722                for connection_id in pool.user_connection_ids(user_id) {
 723                    self.peer.send(
 724                        connection_id,
 725                        proto::UpdateInviteInfo {
 726                            url: format!(
 727                                "{}{}",
 728                                self.app_state.config.invite_link_prefix, invite_code
 729                            ),
 730                            count: user.invite_count as u32,
 731                        },
 732                    )?;
 733                }
 734            }
 735        }
 736        Ok(())
 737    }
 738
 739    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 740        ServerSnapshot {
 741            connection_pool: ConnectionPoolGuard {
 742                guard: self.connection_pool.lock(),
 743                _not_send: PhantomData,
 744            },
 745            peer: &self.peer,
 746        }
 747    }
 748}
 749
 750impl<'a> Deref for ConnectionPoolGuard<'a> {
 751    type Target = ConnectionPool;
 752
 753    fn deref(&self) -> &Self::Target {
 754        &*self.guard
 755    }
 756}
 757
 758impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 759    fn deref_mut(&mut self) -> &mut Self::Target {
 760        &mut *self.guard
 761    }
 762}
 763
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds only, runs `check_invariants` every time the guard is dropped.
    ///
    /// `Drop::drop` runs before the `guard` field itself is dropped, so the check happens
    /// while the pool is still locked. In non-test builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 770
 771fn broadcast<F>(
 772    sender_id: Option<ConnectionId>,
 773    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 774    mut f: F,
 775) where
 776    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 777{
 778    for receiver_id in receiver_ids {
 779        if Some(receiver_id) != sender_id {
 780            if let Err(error) = f(receiver_id) {
 781                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 782            }
 783        }
 784    }
 785}
 786
lazy_static! {
    /// Name of the request header from which `ProtocolVersion` is decoded
    /// (`x-zed-protocol-version`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 790
 791pub struct ProtocolVersion(u32);
 792
 793impl Header for ProtocolVersion {
 794    fn name() -> &'static HeaderName {
 795        &ZED_PROTOCOL_VERSION
 796    }
 797
 798    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 799    where
 800        Self: Sized,
 801        I: Iterator<Item = &'i axum::http::HeaderValue>,
 802    {
 803        let version = values
 804            .next()
 805            .ok_or_else(axum::headers::Error::invalid)?
 806            .to_str()
 807            .map_err(|_| axum::headers::Error::invalid())?
 808            .parse()
 809            .map_err(|_| axum::headers::Error::invalid())?;
 810        Ok(Self(version))
 811    }
 812
 813    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 814        values.extend([self.0.to_string().parse().unwrap()]);
 815    }
 816}
 817
 818pub fn routes(server: Arc<Server>) -> Router<Body> {
 819    Router::new()
 820        .route("/rpc", get(handle_websocket_request))
 821        .layer(
 822            ServiceBuilder::new()
 823                .layer(Extension(server.app_state.clone()))
 824                .layer(middleware::from_fn(auth::validate_header)),
 825        )
 826        .route("/metrics", get(handle_metrics))
 827        .layer(Extension(server))
 828}
 829
 830pub async fn handle_websocket_request(
 831    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 832    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 833    Extension(server): Extension<Arc<Server>>,
 834    Extension(user): Extension<User>,
 835    ws: WebSocketUpgrade,
 836) -> axum::response::Response {
 837    if protocol_version != rpc::PROTOCOL_VERSION {
 838        return (
 839            StatusCode::UPGRADE_REQUIRED,
 840            "client must be upgraded".to_string(),
 841        )
 842            .into_response();
 843    }
 844    let socket_address = socket_address.to_string();
 845    ws.on_upgrade(move |socket| {
 846        use util::ResultExt;
 847        let socket = socket
 848            .map_ok(to_tungstenite_message)
 849            .err_into()
 850            .with(|message| async move { Ok(to_axum_message(message)) });
 851        let connection = Connection::new(Box::pin(socket));
 852        async move {
 853            server
 854                .handle_connection(connection, socket_address, user, None, Executor::Production)
 855                .await
 856                .log_err();
 857        }
 858    })
 859}
 860
 861pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 862    let connections = server
 863        .connection_pool
 864        .lock()
 865        .connections()
 866        .filter(|connection| !connection.admin)
 867        .count();
 868
 869    METRIC_CONNECTIONS.set(connections as _);
 870
 871    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 872    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 873
 874    let encoder = prometheus::TextEncoder::new();
 875    let metric_families = prometheus::gather();
 876    let encoded_metrics = encoder
 877        .encode_to_string(&metric_families)
 878        .map_err(|err| anyhow!("{}", err))?;
 879    Ok(encoded_metrics)
 880}
 881
/// Cleans up after a client connection drops.
///
/// First, it disconnects the peer, removes the connection from the in-memory pool and
/// records the loss in the database. Then it either:
/// - waits `RECONNECT_TIMEOUT` (presumably a grace period for reconnecting) and, once it
///   elapses, leaves the user's room and channel buffers, declines pending calls if the
///   user has no other live connections, and refreshes their contacts; or
/// - returns without that cleanup if the `teardown` watch fires first.
///
/// Instrumented with `tracing`: errors are recorded on the span (`err`) and `executor`
/// is excluded from the span fields.
///
/// # Errors
/// Fails if removing the connection from the pool fails (nothing after that runs) or if
/// updating contacts fails. The other cleanup steps only log their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged but does not abort the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the arms in order, so the timeout arm is checked first
    // whenever both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when the user has no other live connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Server teardown: skip the per-user cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 927
 928async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 929    response.send(proto::Ack {})?;
 930    Ok(())
 931}
 932
 933async fn create_room(
 934    _request: proto::CreateRoom,
 935    response: Response<proto::CreateRoom>,
 936    session: Session,
 937) -> Result<()> {
 938    let live_kit_room = nanoid::nanoid!(30);
 939
 940    let live_kit_connection_info = {
 941        let live_kit_room = live_kit_room.clone();
 942        let live_kit = session.live_kit_client.as_ref();
 943
 944        util::async_iife!({
 945            let live_kit = live_kit?;
 946
 947            let token = live_kit
 948                .room_token(&live_kit_room, &session.user_id.to_string())
 949                .trace_err()?;
 950
 951            Some(proto::LiveKitConnectionInfo {
 952                server_url: live_kit.url().into(),
 953                token,
 954            })
 955        })
 956    }
 957    .await;
 958
 959    let room = session
 960        .db()
 961        .await
 962        .create_room(
 963            session.user_id,
 964            session.connection_id,
 965            &live_kit_room,
 966            RELEASE_CHANNEL_NAME.as_str(),
 967        )
 968        .await?;
 969
 970    response.send(proto::CreateRoomResponse {
 971        room: Some(room.clone()),
 972        live_kit_connection_info,
 973    })?;
 974
 975    update_user_contacts(session.user_id, &session).await?;
 976    Ok(())
 977}
 978
 979async fn join_room(
 980    request: proto::JoinRoom,
 981    response: Response<proto::JoinRoom>,
 982    session: Session,
 983) -> Result<()> {
 984    let room_id = RoomId::from_proto(request.id);
 985
 986    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 987
 988    if let Some(channel_id) = channel_id {
 989        return join_channel_internal(channel_id, Box::new(response), session).await;
 990    }
 991
 992    let joined_room = {
 993        let room = session
 994            .db()
 995            .await
 996            .join_room(
 997                room_id,
 998                session.user_id,
 999                session.connection_id,
1000                RELEASE_CHANNEL_NAME.as_str(),
1001            )
1002            .await?;
1003        room_updated(&room.room, &session.peer);
1004        room.into_inner()
1005    };
1006
1007    for connection_id in session
1008        .connection_pool()
1009        .await
1010        .user_connection_ids(session.user_id)
1011    {
1012        session
1013            .peer
1014            .send(
1015                connection_id,
1016                proto::CallCanceled {
1017                    room_id: room_id.to_proto(),
1018                },
1019            )
1020            .trace_err();
1021    }
1022
1023    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1024        if let Some(token) = live_kit
1025            .room_token(
1026                &joined_room.room.live_kit_room,
1027                &session.user_id.to_string(),
1028            )
1029            .trace_err()
1030        {
1031            Some(proto::LiveKitConnectionInfo {
1032                server_url: live_kit.url().into(),
1033                token,
1034            })
1035        } else {
1036            None
1037        }
1038    } else {
1039        None
1040    };
1041
1042    response.send(proto::JoinRoomResponse {
1043        room: Some(joined_room.room),
1044        channel_id: None,
1045        live_kit_connection_info,
1046    })?;
1047
1048    update_user_contacts(session.user_id, &session).await?;
1049    Ok(())
1050}
1051
1052async fn rejoin_room(
1053    request: proto::RejoinRoom,
1054    response: Response<proto::RejoinRoom>,
1055    session: Session,
1056) -> Result<()> {
1057    let room;
1058    let channel_id;
1059    let channel_members;
1060    {
1061        let mut rejoined_room = session
1062            .db()
1063            .await
1064            .rejoin_room(request, session.user_id, session.connection_id)
1065            .await?;
1066
1067        response.send(proto::RejoinRoomResponse {
1068            room: Some(rejoined_room.room.clone()),
1069            reshared_projects: rejoined_room
1070                .reshared_projects
1071                .iter()
1072                .map(|project| proto::ResharedProject {
1073                    id: project.id.to_proto(),
1074                    collaborators: project
1075                        .collaborators
1076                        .iter()
1077                        .map(|collaborator| collaborator.to_proto())
1078                        .collect(),
1079                })
1080                .collect(),
1081            rejoined_projects: rejoined_room
1082                .rejoined_projects
1083                .iter()
1084                .map(|rejoined_project| proto::RejoinedProject {
1085                    id: rejoined_project.id.to_proto(),
1086                    worktrees: rejoined_project
1087                        .worktrees
1088                        .iter()
1089                        .map(|worktree| proto::WorktreeMetadata {
1090                            id: worktree.id,
1091                            root_name: worktree.root_name.clone(),
1092                            visible: worktree.visible,
1093                            abs_path: worktree.abs_path.clone(),
1094                        })
1095                        .collect(),
1096                    collaborators: rejoined_project
1097                        .collaborators
1098                        .iter()
1099                        .map(|collaborator| collaborator.to_proto())
1100                        .collect(),
1101                    language_servers: rejoined_project.language_servers.clone(),
1102                })
1103                .collect(),
1104        })?;
1105        room_updated(&rejoined_room.room, &session.peer);
1106
1107        for project in &rejoined_room.reshared_projects {
1108            for collaborator in &project.collaborators {
1109                session
1110                    .peer
1111                    .send(
1112                        collaborator.connection_id,
1113                        proto::UpdateProjectCollaborator {
1114                            project_id: project.id.to_proto(),
1115                            old_peer_id: Some(project.old_connection_id.into()),
1116                            new_peer_id: Some(session.connection_id.into()),
1117                        },
1118                    )
1119                    .trace_err();
1120            }
1121
1122            broadcast(
1123                Some(session.connection_id),
1124                project
1125                    .collaborators
1126                    .iter()
1127                    .map(|collaborator| collaborator.connection_id),
1128                |connection_id| {
1129                    session.peer.forward_send(
1130                        session.connection_id,
1131                        connection_id,
1132                        proto::UpdateProject {
1133                            project_id: project.id.to_proto(),
1134                            worktrees: project.worktrees.clone(),
1135                        },
1136                    )
1137                },
1138            );
1139        }
1140
1141        for project in &rejoined_room.rejoined_projects {
1142            for collaborator in &project.collaborators {
1143                session
1144                    .peer
1145                    .send(
1146                        collaborator.connection_id,
1147                        proto::UpdateProjectCollaborator {
1148                            project_id: project.id.to_proto(),
1149                            old_peer_id: Some(project.old_connection_id.into()),
1150                            new_peer_id: Some(session.connection_id.into()),
1151                        },
1152                    )
1153                    .trace_err();
1154            }
1155        }
1156
1157        for project in &mut rejoined_room.rejoined_projects {
1158            for worktree in mem::take(&mut project.worktrees) {
1159                #[cfg(any(test, feature = "test-support"))]
1160                const MAX_CHUNK_SIZE: usize = 2;
1161                #[cfg(not(any(test, feature = "test-support")))]
1162                const MAX_CHUNK_SIZE: usize = 256;
1163
1164                // Stream this worktree's entries.
1165                let message = proto::UpdateWorktree {
1166                    project_id: project.id.to_proto(),
1167                    worktree_id: worktree.id,
1168                    abs_path: worktree.abs_path.clone(),
1169                    root_name: worktree.root_name,
1170                    updated_entries: worktree.updated_entries,
1171                    removed_entries: worktree.removed_entries,
1172                    scan_id: worktree.scan_id,
1173                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1174                    updated_repositories: worktree.updated_repositories,
1175                    removed_repositories: worktree.removed_repositories,
1176                };
1177                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1178                    session.peer.send(session.connection_id, update.clone())?;
1179                }
1180
1181                // Stream this worktree's diagnostics.
1182                for summary in worktree.diagnostic_summaries {
1183                    session.peer.send(
1184                        session.connection_id,
1185                        proto::UpdateDiagnosticSummary {
1186                            project_id: project.id.to_proto(),
1187                            worktree_id: worktree.id,
1188                            summary: Some(summary),
1189                        },
1190                    )?;
1191                }
1192
1193                for settings_file in worktree.settings_files {
1194                    session.peer.send(
1195                        session.connection_id,
1196                        proto::UpdateWorktreeSettings {
1197                            project_id: project.id.to_proto(),
1198                            worktree_id: worktree.id,
1199                            path: settings_file.path,
1200                            content: Some(settings_file.content),
1201                        },
1202                    )?;
1203                }
1204            }
1205
1206            for language_server in &project.language_servers {
1207                session.peer.send(
1208                    session.connection_id,
1209                    proto::UpdateLanguageServer {
1210                        project_id: project.id.to_proto(),
1211                        language_server_id: language_server.id,
1212                        variant: Some(
1213                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1214                                proto::LspDiskBasedDiagnosticsUpdated {},
1215                            ),
1216                        ),
1217                    },
1218                )?;
1219            }
1220        }
1221
1222        let rejoined_room = rejoined_room.into_inner();
1223
1224        room = rejoined_room.room;
1225        channel_id = rejoined_room.channel_id;
1226        channel_members = rejoined_room.channel_members;
1227    }
1228
1229    if let Some(channel_id) = channel_id {
1230        channel_updated(
1231            channel_id,
1232            &room,
1233            &channel_members,
1234            &session.peer,
1235            &*session.connection_pool().await,
1236        );
1237    }
1238
1239    update_user_contacts(session.user_id, &session).await?;
1240    Ok(())
1241}
1242
1243async fn leave_room(
1244    _: proto::LeaveRoom,
1245    response: Response<proto::LeaveRoom>,
1246    session: Session,
1247) -> Result<()> {
1248    leave_room_for_session(&session).await?;
1249    response.send(proto::Ack {})?;
1250    Ok(())
1251}
1252
1253async fn call(
1254    request: proto::Call,
1255    response: Response<proto::Call>,
1256    session: Session,
1257) -> Result<()> {
1258    let room_id = RoomId::from_proto(request.room_id);
1259    let calling_user_id = session.user_id;
1260    let calling_connection_id = session.connection_id;
1261    let called_user_id = UserId::from_proto(request.called_user_id);
1262    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1263    if !session
1264        .db()
1265        .await
1266        .has_contact(calling_user_id, called_user_id)
1267        .await?
1268    {
1269        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1270    }
1271
1272    let incoming_call = {
1273        let (room, incoming_call) = &mut *session
1274            .db()
1275            .await
1276            .call(
1277                room_id,
1278                calling_user_id,
1279                calling_connection_id,
1280                called_user_id,
1281                initial_project_id,
1282            )
1283            .await?;
1284        room_updated(&room, &session.peer);
1285        mem::take(incoming_call)
1286    };
1287    update_user_contacts(called_user_id, &session).await?;
1288
1289    let mut calls = session
1290        .connection_pool()
1291        .await
1292        .user_connection_ids(called_user_id)
1293        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1294        .collect::<FuturesUnordered<_>>();
1295
1296    while let Some(call_response) = calls.next().await {
1297        match call_response.as_ref() {
1298            Ok(_) => {
1299                response.send(proto::Ack {})?;
1300                return Ok(());
1301            }
1302            Err(_) => {
1303                call_response.trace_err();
1304            }
1305        }
1306    }
1307
1308    {
1309        let room = session
1310            .db()
1311            .await
1312            .call_failed(room_id, called_user_id)
1313            .await?;
1314        room_updated(&room, &session.peer);
1315    }
1316    update_user_contacts(called_user_id, &session).await?;
1317
1318    Err(anyhow!("failed to ring user"))?
1319}
1320
1321async fn cancel_call(
1322    request: proto::CancelCall,
1323    response: Response<proto::CancelCall>,
1324    session: Session,
1325) -> Result<()> {
1326    let called_user_id = UserId::from_proto(request.called_user_id);
1327    let room_id = RoomId::from_proto(request.room_id);
1328    {
1329        let room = session
1330            .db()
1331            .await
1332            .cancel_call(room_id, session.connection_id, called_user_id)
1333            .await?;
1334        room_updated(&room, &session.peer);
1335    }
1336
1337    for connection_id in session
1338        .connection_pool()
1339        .await
1340        .user_connection_ids(called_user_id)
1341    {
1342        session
1343            .peer
1344            .send(
1345                connection_id,
1346                proto::CallCanceled {
1347                    room_id: room_id.to_proto(),
1348                },
1349            )
1350            .trace_err();
1351    }
1352    response.send(proto::Ack {})?;
1353
1354    update_user_contacts(called_user_id, &session).await?;
1355    Ok(())
1356}
1357
1358async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1359    let room_id = RoomId::from_proto(message.room_id);
1360    {
1361        let room = session
1362            .db()
1363            .await
1364            .decline_call(Some(room_id), session.user_id)
1365            .await?
1366            .ok_or_else(|| anyhow!("failed to decline call"))?;
1367        room_updated(&room, &session.peer);
1368    }
1369
1370    for connection_id in session
1371        .connection_pool()
1372        .await
1373        .user_connection_ids(session.user_id)
1374    {
1375        session
1376            .peer
1377            .send(
1378                connection_id,
1379                proto::CallCanceled {
1380                    room_id: room_id.to_proto(),
1381                },
1382            )
1383            .trace_err();
1384    }
1385    update_user_contacts(session.user_id, &session).await?;
1386    Ok(())
1387}
1388
1389async fn update_participant_location(
1390    request: proto::UpdateParticipantLocation,
1391    response: Response<proto::UpdateParticipantLocation>,
1392    session: Session,
1393) -> Result<()> {
1394    let room_id = RoomId::from_proto(request.room_id);
1395    let location = request
1396        .location
1397        .ok_or_else(|| anyhow!("invalid location"))?;
1398
1399    let db = session.db().await;
1400    let room = db
1401        .update_room_participant_location(room_id, session.connection_id, location)
1402        .await?;
1403
1404    room_updated(&room, &session.peer);
1405    response.send(proto::Ack {})?;
1406    Ok(())
1407}
1408
1409async fn share_project(
1410    request: proto::ShareProject,
1411    response: Response<proto::ShareProject>,
1412    session: Session,
1413) -> Result<()> {
1414    let (project_id, room) = &*session
1415        .db()
1416        .await
1417        .share_project(
1418            RoomId::from_proto(request.room_id),
1419            session.connection_id,
1420            &request.worktrees,
1421        )
1422        .await?;
1423    response.send(proto::ShareProjectResponse {
1424        project_id: project_id.to_proto(),
1425    })?;
1426    room_updated(&room, &session.peer);
1427
1428    Ok(())
1429}
1430
1431async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1432    let project_id = ProjectId::from_proto(message.project_id);
1433
1434    let (room, guest_connection_ids) = &*session
1435        .db()
1436        .await
1437        .unshare_project(project_id, session.connection_id)
1438        .await?;
1439
1440    broadcast(
1441        Some(session.connection_id),
1442        guest_connection_ids.iter().copied(),
1443        |conn_id| session.peer.send(conn_id, message.clone()),
1444    );
1445    room_updated(&room, &session.peer);
1446
1447    Ok(())
1448}
1449
1450async fn join_project(
1451    request: proto::JoinProject,
1452    response: Response<proto::JoinProject>,
1453    session: Session,
1454) -> Result<()> {
1455    let project_id = ProjectId::from_proto(request.project_id);
1456    let guest_user_id = session.user_id;
1457
1458    tracing::info!(%project_id, "join project");
1459
1460    let (project, replica_id) = &mut *session
1461        .db()
1462        .await
1463        .join_project(project_id, session.connection_id)
1464        .await?;
1465
1466    let collaborators = project
1467        .collaborators
1468        .iter()
1469        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1470        .map(|collaborator| collaborator.to_proto())
1471        .collect::<Vec<_>>();
1472
1473    let worktrees = project
1474        .worktrees
1475        .iter()
1476        .map(|(id, worktree)| proto::WorktreeMetadata {
1477            id: *id,
1478            root_name: worktree.root_name.clone(),
1479            visible: worktree.visible,
1480            abs_path: worktree.abs_path.clone(),
1481        })
1482        .collect::<Vec<_>>();
1483
1484    for collaborator in &collaborators {
1485        session
1486            .peer
1487            .send(
1488                collaborator.peer_id.unwrap().into(),
1489                proto::AddProjectCollaborator {
1490                    project_id: project_id.to_proto(),
1491                    collaborator: Some(proto::Collaborator {
1492                        peer_id: Some(session.connection_id.into()),
1493                        replica_id: replica_id.0 as u32,
1494                        user_id: guest_user_id.to_proto(),
1495                    }),
1496                },
1497            )
1498            .trace_err();
1499    }
1500
1501    // First, we send the metadata associated with each worktree.
1502    response.send(proto::JoinProjectResponse {
1503        worktrees: worktrees.clone(),
1504        replica_id: replica_id.0 as u32,
1505        collaborators: collaborators.clone(),
1506        language_servers: project.language_servers.clone(),
1507    })?;
1508
1509    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1510        #[cfg(any(test, feature = "test-support"))]
1511        const MAX_CHUNK_SIZE: usize = 2;
1512        #[cfg(not(any(test, feature = "test-support")))]
1513        const MAX_CHUNK_SIZE: usize = 256;
1514
1515        // Stream this worktree's entries.
1516        let message = proto::UpdateWorktree {
1517            project_id: project_id.to_proto(),
1518            worktree_id,
1519            abs_path: worktree.abs_path.clone(),
1520            root_name: worktree.root_name,
1521            updated_entries: worktree.entries,
1522            removed_entries: Default::default(),
1523            scan_id: worktree.scan_id,
1524            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1525            updated_repositories: worktree.repository_entries.into_values().collect(),
1526            removed_repositories: Default::default(),
1527        };
1528        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1529            session.peer.send(session.connection_id, update.clone())?;
1530        }
1531
1532        // Stream this worktree's diagnostics.
1533        for summary in worktree.diagnostic_summaries {
1534            session.peer.send(
1535                session.connection_id,
1536                proto::UpdateDiagnosticSummary {
1537                    project_id: project_id.to_proto(),
1538                    worktree_id: worktree.id,
1539                    summary: Some(summary),
1540                },
1541            )?;
1542        }
1543
1544        for settings_file in worktree.settings_files {
1545            session.peer.send(
1546                session.connection_id,
1547                proto::UpdateWorktreeSettings {
1548                    project_id: project_id.to_proto(),
1549                    worktree_id: worktree.id,
1550                    path: settings_file.path,
1551                    content: Some(settings_file.content),
1552                },
1553            )?;
1554        }
1555    }
1556
1557    for language_server in &project.language_servers {
1558        session.peer.send(
1559            session.connection_id,
1560            proto::UpdateLanguageServer {
1561                project_id: project_id.to_proto(),
1562                language_server_id: language_server.id,
1563                variant: Some(
1564                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1565                        proto::LspDiskBasedDiagnosticsUpdated {},
1566                    ),
1567                ),
1568            },
1569        )?;
1570    }
1571
1572    Ok(())
1573}
1574
1575async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1576    let sender_id = session.connection_id;
1577    let project_id = ProjectId::from_proto(request.project_id);
1578
1579    let (room, project) = &*session
1580        .db()
1581        .await
1582        .leave_project(project_id, sender_id)
1583        .await?;
1584    tracing::info!(
1585        %project_id,
1586        host_user_id = %project.host_user_id,
1587        host_connection_id = %project.host_connection_id,
1588        "leave project"
1589    );
1590
1591    project_left(&project, &session);
1592    room_updated(&room, &session.peer);
1593
1594    Ok(())
1595}
1596
1597async fn update_project(
1598    request: proto::UpdateProject,
1599    response: Response<proto::UpdateProject>,
1600    session: Session,
1601) -> Result<()> {
1602    let project_id = ProjectId::from_proto(request.project_id);
1603    let (room, guest_connection_ids) = &*session
1604        .db()
1605        .await
1606        .update_project(project_id, session.connection_id, &request.worktrees)
1607        .await?;
1608    broadcast(
1609        Some(session.connection_id),
1610        guest_connection_ids.iter().copied(),
1611        |connection_id| {
1612            session
1613                .peer
1614                .forward_send(session.connection_id, connection_id, request.clone())
1615        },
1616    );
1617    room_updated(&room, &session.peer);
1618    response.send(proto::Ack {})?;
1619
1620    Ok(())
1621}
1622
1623async fn update_worktree(
1624    request: proto::UpdateWorktree,
1625    response: Response<proto::UpdateWorktree>,
1626    session: Session,
1627) -> Result<()> {
1628    let guest_connection_ids = session
1629        .db()
1630        .await
1631        .update_worktree(&request, session.connection_id)
1632        .await?;
1633
1634    broadcast(
1635        Some(session.connection_id),
1636        guest_connection_ids.iter().copied(),
1637        |connection_id| {
1638            session
1639                .peer
1640                .forward_send(session.connection_id, connection_id, request.clone())
1641        },
1642    );
1643    response.send(proto::Ack {})?;
1644    Ok(())
1645}
1646
1647async fn update_diagnostic_summary(
1648    message: proto::UpdateDiagnosticSummary,
1649    session: Session,
1650) -> Result<()> {
1651    let guest_connection_ids = session
1652        .db()
1653        .await
1654        .update_diagnostic_summary(&message, session.connection_id)
1655        .await?;
1656
1657    broadcast(
1658        Some(session.connection_id),
1659        guest_connection_ids.iter().copied(),
1660        |connection_id| {
1661            session
1662                .peer
1663                .forward_send(session.connection_id, connection_id, message.clone())
1664        },
1665    );
1666
1667    Ok(())
1668}
1669
1670async fn update_worktree_settings(
1671    message: proto::UpdateWorktreeSettings,
1672    session: Session,
1673) -> Result<()> {
1674    let guest_connection_ids = session
1675        .db()
1676        .await
1677        .update_worktree_settings(&message, session.connection_id)
1678        .await?;
1679
1680    broadcast(
1681        Some(session.connection_id),
1682        guest_connection_ids.iter().copied(),
1683        |connection_id| {
1684            session
1685                .peer
1686                .forward_send(session.connection_id, connection_id, message.clone())
1687        },
1688    );
1689
1690    Ok(())
1691}
1692
1693async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1694    broadcast_project_message(request.project_id, request, session).await
1695}
1696
1697async fn start_language_server(
1698    request: proto::StartLanguageServer,
1699    session: Session,
1700) -> Result<()> {
1701    let guest_connection_ids = session
1702        .db()
1703        .await
1704        .start_language_server(&request, session.connection_id)
1705        .await?;
1706
1707    broadcast(
1708        Some(session.connection_id),
1709        guest_connection_ids.iter().copied(),
1710        |connection_id| {
1711            session
1712                .peer
1713                .forward_send(session.connection_id, connection_id, request.clone())
1714        },
1715    );
1716    Ok(())
1717}
1718
1719async fn update_language_server(
1720    request: proto::UpdateLanguageServer,
1721    session: Session,
1722) -> Result<()> {
1723    session.executor.record_backtrace();
1724    let project_id = ProjectId::from_proto(request.project_id);
1725    let project_connection_ids = session
1726        .db()
1727        .await
1728        .project_connection_ids(project_id, session.connection_id)
1729        .await?;
1730    broadcast(
1731        Some(session.connection_id),
1732        project_connection_ids.iter().copied(),
1733        |connection_id| {
1734            session
1735                .peer
1736                .forward_send(session.connection_id, connection_id, request.clone())
1737        },
1738    );
1739    Ok(())
1740}
1741
1742async fn forward_project_request<T>(
1743    request: T,
1744    response: Response<T>,
1745    session: Session,
1746) -> Result<()>
1747where
1748    T: EntityMessage + RequestMessage,
1749{
1750    session.executor.record_backtrace();
1751    let project_id = ProjectId::from_proto(request.remote_entity_id());
1752    let host_connection_id = {
1753        let collaborators = session
1754            .db()
1755            .await
1756            .project_collaborators(project_id, session.connection_id)
1757            .await?;
1758        collaborators
1759            .iter()
1760            .find(|collaborator| collaborator.is_host)
1761            .ok_or_else(|| anyhow!("host not found"))?
1762            .connection_id
1763    };
1764
1765    let payload = session
1766        .peer
1767        .forward_request(session.connection_id, host_connection_id, request)
1768        .await?;
1769
1770    response.send(payload)?;
1771    Ok(())
1772}
1773
1774async fn create_buffer_for_peer(
1775    request: proto::CreateBufferForPeer,
1776    session: Session,
1777) -> Result<()> {
1778    session.executor.record_backtrace();
1779    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1780    session
1781        .peer
1782        .forward_send(session.connection_id, peer_id.into(), request)?;
1783    Ok(())
1784}
1785
1786async fn update_buffer(
1787    request: proto::UpdateBuffer,
1788    response: Response<proto::UpdateBuffer>,
1789    session: Session,
1790) -> Result<()> {
1791    session.executor.record_backtrace();
1792    let project_id = ProjectId::from_proto(request.project_id);
1793    let mut guest_connection_ids;
1794    let mut host_connection_id = None;
1795    {
1796        let collaborators = session
1797            .db()
1798            .await
1799            .project_collaborators(project_id, session.connection_id)
1800            .await?;
1801        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1802        for collaborator in collaborators.iter() {
1803            if collaborator.is_host {
1804                host_connection_id = Some(collaborator.connection_id);
1805            } else {
1806                guest_connection_ids.push(collaborator.connection_id);
1807            }
1808        }
1809    }
1810    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1811
1812    session.executor.record_backtrace();
1813    broadcast(
1814        Some(session.connection_id),
1815        guest_connection_ids,
1816        |connection_id| {
1817            session
1818                .peer
1819                .forward_send(session.connection_id, connection_id, request.clone())
1820        },
1821    );
1822    if host_connection_id != session.connection_id {
1823        session
1824            .peer
1825            .forward_request(session.connection_id, host_connection_id, request.clone())
1826            .await?;
1827    }
1828
1829    response.send(proto::Ack {})?;
1830    Ok(())
1831}
1832
1833async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1834    let project_id = ProjectId::from_proto(request.project_id);
1835    let project_connection_ids = session
1836        .db()
1837        .await
1838        .project_connection_ids(project_id, session.connection_id)
1839        .await?;
1840
1841    broadcast(
1842        Some(session.connection_id),
1843        project_connection_ids.iter().copied(),
1844        |connection_id| {
1845            session
1846                .peer
1847                .forward_send(session.connection_id, connection_id, request.clone())
1848        },
1849    );
1850    Ok(())
1851}
1852
1853async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1854    let project_id = ProjectId::from_proto(request.project_id);
1855    let project_connection_ids = session
1856        .db()
1857        .await
1858        .project_connection_ids(project_id, session.connection_id)
1859        .await?;
1860    broadcast(
1861        Some(session.connection_id),
1862        project_connection_ids.iter().copied(),
1863        |connection_id| {
1864            session
1865                .peer
1866                .forward_send(session.connection_id, connection_id, request.clone())
1867        },
1868    );
1869    Ok(())
1870}
1871
1872async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1873    broadcast_project_message(request.project_id, request, session).await
1874}
1875
1876async fn broadcast_project_message<T: EnvelopedMessage>(
1877    project_id: u64,
1878    request: T,
1879    session: Session,
1880) -> Result<()> {
1881    let project_id = ProjectId::from_proto(project_id);
1882    let project_connection_ids = session
1883        .db()
1884        .await
1885        .project_connection_ids(project_id, session.connection_id)
1886        .await?;
1887    broadcast(
1888        Some(session.connection_id),
1889        project_connection_ids.iter().copied(),
1890        |connection_id| {
1891            session
1892                .peer
1893                .forward_send(session.connection_id, connection_id, request.clone())
1894        },
1895    );
1896    Ok(())
1897}
1898
1899async fn follow(
1900    request: proto::Follow,
1901    response: Response<proto::Follow>,
1902    session: Session,
1903) -> Result<()> {
1904    let room_id = RoomId::from_proto(request.room_id);
1905    let project_id = request.project_id.map(ProjectId::from_proto);
1906    let leader_id = request
1907        .leader_id
1908        .ok_or_else(|| anyhow!("invalid leader id"))?
1909        .into();
1910    let follower_id = session.connection_id;
1911
1912    session
1913        .db()
1914        .await
1915        .check_room_participants(room_id, leader_id, session.connection_id)
1916        .await?;
1917
1918    let response_payload = session
1919        .peer
1920        .forward_request(session.connection_id, leader_id, request)
1921        .await?;
1922    response.send(response_payload)?;
1923
1924    if let Some(project_id) = project_id {
1925        let room = session
1926            .db()
1927            .await
1928            .follow(room_id, project_id, leader_id, follower_id)
1929            .await?;
1930        room_updated(&room, &session.peer);
1931    }
1932
1933    Ok(())
1934}
1935
1936async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1937    let room_id = RoomId::from_proto(request.room_id);
1938    let project_id = request.project_id.map(ProjectId::from_proto);
1939    let leader_id = request
1940        .leader_id
1941        .ok_or_else(|| anyhow!("invalid leader id"))?
1942        .into();
1943    let follower_id = session.connection_id;
1944
1945    session
1946        .db()
1947        .await
1948        .check_room_participants(room_id, leader_id, session.connection_id)
1949        .await?;
1950
1951    session
1952        .peer
1953        .forward_send(session.connection_id, leader_id, request)?;
1954
1955    if let Some(project_id) = project_id {
1956        let room = session
1957            .db()
1958            .await
1959            .unfollow(room_id, project_id, leader_id, follower_id)
1960            .await?;
1961        room_updated(&room, &session.peer);
1962    }
1963
1964    Ok(())
1965}
1966
1967async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1968    let room_id = RoomId::from_proto(request.room_id);
1969    let database = session.db.lock().await;
1970
1971    let connection_ids = if let Some(project_id) = request.project_id {
1972        let project_id = ProjectId::from_proto(project_id);
1973        database
1974            .project_connection_ids(project_id, session.connection_id)
1975            .await?
1976    } else {
1977        database
1978            .room_connection_ids(room_id, session.connection_id)
1979            .await?
1980    };
1981
1982    // For now, don't send view update messages back to that view's current leader.
1983    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1984        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1985        _ => None,
1986    });
1987
1988    for follower_peer_id in request.follower_ids.iter().copied() {
1989        let follower_connection_id = follower_peer_id.into();
1990        if Some(follower_peer_id) != connection_id_to_omit
1991            && connection_ids.contains(&follower_connection_id)
1992        {
1993            session.peer.forward_send(
1994                session.connection_id,
1995                follower_connection_id,
1996                request.clone(),
1997            )?;
1998        }
1999    }
2000    Ok(())
2001}
2002
2003async fn get_users(
2004    request: proto::GetUsers,
2005    response: Response<proto::GetUsers>,
2006    session: Session,
2007) -> Result<()> {
2008    let user_ids = request
2009        .user_ids
2010        .into_iter()
2011        .map(UserId::from_proto)
2012        .collect();
2013    let users = session
2014        .db()
2015        .await
2016        .get_users_by_ids(user_ids)
2017        .await?
2018        .into_iter()
2019        .map(|user| proto::User {
2020            id: user.id.to_proto(),
2021            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2022            github_login: user.github_login,
2023        })
2024        .collect();
2025    response.send(proto::UsersResponse { users })?;
2026    Ok(())
2027}
2028
2029async fn fuzzy_search_users(
2030    request: proto::FuzzySearchUsers,
2031    response: Response<proto::FuzzySearchUsers>,
2032    session: Session,
2033) -> Result<()> {
2034    let query = request.query;
2035    let users = match query.len() {
2036        0 => vec![],
2037        1 | 2 => session
2038            .db()
2039            .await
2040            .get_user_by_github_login(&query)
2041            .await?
2042            .into_iter()
2043            .collect(),
2044        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2045    };
2046    let users = users
2047        .into_iter()
2048        .filter(|user| user.id != session.user_id)
2049        .map(|user| proto::User {
2050            id: user.id.to_proto(),
2051            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2052            github_login: user.github_login,
2053        })
2054        .collect();
2055    response.send(proto::UsersResponse { users })?;
2056    Ok(())
2057}
2058
2059async fn request_contact(
2060    request: proto::RequestContact,
2061    response: Response<proto::RequestContact>,
2062    session: Session,
2063) -> Result<()> {
2064    let requester_id = session.user_id;
2065    let responder_id = UserId::from_proto(request.responder_id);
2066    if requester_id == responder_id {
2067        return Err(anyhow!("cannot add yourself as a contact"))?;
2068    }
2069
2070    let notifications = session
2071        .db()
2072        .await
2073        .send_contact_request(requester_id, responder_id)
2074        .await?;
2075
2076    // Update outgoing contact requests of requester
2077    let mut update = proto::UpdateContacts::default();
2078    update.outgoing_requests.push(responder_id.to_proto());
2079    for connection_id in session
2080        .connection_pool()
2081        .await
2082        .user_connection_ids(requester_id)
2083    {
2084        session.peer.send(connection_id, update.clone())?;
2085    }
2086
2087    // Update incoming contact requests of responder
2088    let mut update = proto::UpdateContacts::default();
2089    update
2090        .incoming_requests
2091        .push(proto::IncomingContactRequest {
2092            requester_id: requester_id.to_proto(),
2093        });
2094    let connection_pool = session.connection_pool().await;
2095    for connection_id in connection_pool.user_connection_ids(responder_id) {
2096        session.peer.send(connection_id, update.clone())?;
2097    }
2098
2099    send_notifications(&*connection_pool, &session.peer, notifications);
2100
2101    response.send(proto::Ack {})?;
2102    Ok(())
2103}
2104
2105async fn respond_to_contact_request(
2106    request: proto::RespondToContactRequest,
2107    response: Response<proto::RespondToContactRequest>,
2108    session: Session,
2109) -> Result<()> {
2110    let responder_id = session.user_id;
2111    let requester_id = UserId::from_proto(request.requester_id);
2112    let db = session.db().await;
2113    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2114        db.dismiss_contact_notification(responder_id, requester_id)
2115            .await?;
2116    } else {
2117        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2118
2119        let notifications = db
2120            .respond_to_contact_request(responder_id, requester_id, accept)
2121            .await?;
2122        let requester_busy = db.is_user_busy(requester_id).await?;
2123        let responder_busy = db.is_user_busy(responder_id).await?;
2124
2125        let pool = session.connection_pool().await;
2126        // Update responder with new contact
2127        let mut update = proto::UpdateContacts::default();
2128        if accept {
2129            update
2130                .contacts
2131                .push(contact_for_user(requester_id, requester_busy, &pool));
2132        }
2133        update
2134            .remove_incoming_requests
2135            .push(requester_id.to_proto());
2136        for connection_id in pool.user_connection_ids(responder_id) {
2137            session.peer.send(connection_id, update.clone())?;
2138        }
2139
2140        // Update requester with new contact
2141        let mut update = proto::UpdateContacts::default();
2142        if accept {
2143            update
2144                .contacts
2145                .push(contact_for_user(responder_id, responder_busy, &pool));
2146        }
2147        update
2148            .remove_outgoing_requests
2149            .push(responder_id.to_proto());
2150
2151        for connection_id in pool.user_connection_ids(requester_id) {
2152            session.peer.send(connection_id, update.clone())?;
2153        }
2154
2155        send_notifications(&*pool, &session.peer, notifications);
2156    }
2157
2158    response.send(proto::Ack {})?;
2159    Ok(())
2160}
2161
2162async fn remove_contact(
2163    request: proto::RemoveContact,
2164    response: Response<proto::RemoveContact>,
2165    session: Session,
2166) -> Result<()> {
2167    let requester_id = session.user_id;
2168    let responder_id = UserId::from_proto(request.user_id);
2169    let db = session.db().await;
2170    let (contact_accepted, deleted_notification_id) =
2171        db.remove_contact(requester_id, responder_id).await?;
2172
2173    let pool = session.connection_pool().await;
2174    // Update outgoing contact requests of requester
2175    let mut update = proto::UpdateContacts::default();
2176    if contact_accepted {
2177        update.remove_contacts.push(responder_id.to_proto());
2178    } else {
2179        update
2180            .remove_outgoing_requests
2181            .push(responder_id.to_proto());
2182    }
2183    for connection_id in pool.user_connection_ids(requester_id) {
2184        session.peer.send(connection_id, update.clone())?;
2185    }
2186
2187    // Update incoming contact requests of responder
2188    let mut update = proto::UpdateContacts::default();
2189    if contact_accepted {
2190        update.remove_contacts.push(requester_id.to_proto());
2191    } else {
2192        update
2193            .remove_incoming_requests
2194            .push(requester_id.to_proto());
2195    }
2196    for connection_id in pool.user_connection_ids(responder_id) {
2197        session.peer.send(connection_id, update.clone())?;
2198        if let Some(notification_id) = deleted_notification_id {
2199            session.peer.send(
2200                connection_id,
2201                proto::DeleteNotification {
2202                    notification_id: notification_id.to_proto(),
2203                },
2204            )?;
2205        }
2206    }
2207
2208    response.send(proto::Ack {})?;
2209    Ok(())
2210}
2211
2212async fn create_channel(
2213    request: proto::CreateChannel,
2214    response: Response<proto::CreateChannel>,
2215    session: Session,
2216) -> Result<()> {
2217    let db = session.db().await;
2218
2219    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2220    let id = db
2221        .create_channel(&request.name, parent_id, session.user_id)
2222        .await?;
2223
2224    let channel = proto::Channel {
2225        id: id.to_proto(),
2226        name: request.name,
2227        visibility: proto::ChannelVisibility::Members as i32,
2228    };
2229
2230    response.send(proto::CreateChannelResponse {
2231        channel: Some(channel.clone()),
2232        parent_id: request.parent_id,
2233    })?;
2234
2235    let Some(parent_id) = parent_id else {
2236        return Ok(());
2237    };
2238
2239    let update = proto::UpdateChannels {
2240        channels: vec![channel],
2241        insert_edge: vec![ChannelEdge {
2242            parent_id: parent_id.to_proto(),
2243            channel_id: id.to_proto(),
2244        }],
2245        ..Default::default()
2246    };
2247
2248    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2249
2250    let connection_pool = session.connection_pool().await;
2251    for user_id in user_ids_to_notify {
2252        for connection_id in connection_pool.user_connection_ids(user_id) {
2253            if user_id == session.user_id {
2254                continue;
2255            }
2256            session.peer.send(connection_id, update.clone())?;
2257        }
2258    }
2259
2260    Ok(())
2261}
2262
2263async fn delete_channel(
2264    request: proto::DeleteChannel,
2265    response: Response<proto::DeleteChannel>,
2266    session: Session,
2267) -> Result<()> {
2268    let db = session.db().await;
2269
2270    let channel_id = request.channel_id;
2271    let (removed_channels, member_ids) = db
2272        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2273        .await?;
2274    response.send(proto::Ack {})?;
2275
2276    // Notify members of removed channels
2277    let mut update = proto::UpdateChannels::default();
2278    update
2279        .delete_channels
2280        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2281
2282    let connection_pool = session.connection_pool().await;
2283    for member_id in member_ids {
2284        for connection_id in connection_pool.user_connection_ids(member_id) {
2285            session.peer.send(connection_id, update.clone())?;
2286        }
2287    }
2288
2289    Ok(())
2290}
2291
2292async fn invite_channel_member(
2293    request: proto::InviteChannelMember,
2294    response: Response<proto::InviteChannelMember>,
2295    session: Session,
2296) -> Result<()> {
2297    let db = session.db().await;
2298    let channel_id = ChannelId::from_proto(request.channel_id);
2299    let invitee_id = UserId::from_proto(request.user_id);
2300    let notifications = db
2301        .invite_channel_member(
2302            channel_id,
2303            invitee_id,
2304            session.user_id,
2305            request.role().into(),
2306        )
2307        .await?;
2308
2309    let channel = db.get_channel(channel_id, session.user_id).await?;
2310
2311    let mut update = proto::UpdateChannels::default();
2312    update.channel_invitations.push(proto::Channel {
2313        id: channel.id.to_proto(),
2314        visibility: channel.visibility.into(),
2315        name: channel.name,
2316    });
2317
2318    let pool = session.connection_pool().await;
2319    for connection_id in pool.user_connection_ids(invitee_id) {
2320        session.peer.send(connection_id, update.clone())?;
2321    }
2322
2323    send_notifications(&*pool, &session.peer, notifications);
2324
2325    response.send(proto::Ack {})?;
2326    Ok(())
2327}
2328
2329async fn remove_channel_member(
2330    request: proto::RemoveChannelMember,
2331    response: Response<proto::RemoveChannelMember>,
2332    session: Session,
2333) -> Result<()> {
2334    let db = session.db().await;
2335    let channel_id = ChannelId::from_proto(request.channel_id);
2336    let member_id = UserId::from_proto(request.user_id);
2337
2338    let removed_notification_id = db
2339        .remove_channel_member(channel_id, member_id, session.user_id)
2340        .await?;
2341
2342    let mut update = proto::UpdateChannels::default();
2343    update.delete_channels.push(channel_id.to_proto());
2344
2345    for connection_id in session
2346        .connection_pool()
2347        .await
2348        .user_connection_ids(member_id)
2349    {
2350        session.peer.send(connection_id, update.clone()).trace_err();
2351        if let Some(notification_id) = removed_notification_id {
2352            session
2353                .peer
2354                .send(
2355                    connection_id,
2356                    proto::DeleteNotification {
2357                        notification_id: notification_id.to_proto(),
2358                    },
2359                )
2360                .trace_err();
2361        }
2362    }
2363
2364    response.send(proto::Ack {})?;
2365    Ok(())
2366}
2367
2368async fn set_channel_visibility(
2369    request: proto::SetChannelVisibility,
2370    response: Response<proto::SetChannelVisibility>,
2371    session: Session,
2372) -> Result<()> {
2373    let db = session.db().await;
2374    let channel_id = ChannelId::from_proto(request.channel_id);
2375    let visibility = request.visibility().into();
2376
2377    let channel = db
2378        .set_channel_visibility(channel_id, visibility, session.user_id)
2379        .await?;
2380
2381    let mut update = proto::UpdateChannels::default();
2382    update.channels.push(proto::Channel {
2383        id: channel.id.to_proto(),
2384        name: channel.name,
2385        visibility: channel.visibility.into(),
2386    });
2387
2388    let member_ids = db.get_channel_members(channel_id).await?;
2389
2390    let connection_pool = session.connection_pool().await;
2391    for member_id in member_ids {
2392        for connection_id in connection_pool.user_connection_ids(member_id) {
2393            session.peer.send(connection_id, update.clone())?;
2394        }
2395    }
2396
2397    response.send(proto::Ack {})?;
2398    Ok(())
2399}
2400
2401async fn set_channel_member_role(
2402    request: proto::SetChannelMemberRole,
2403    response: Response<proto::SetChannelMemberRole>,
2404    session: Session,
2405) -> Result<()> {
2406    let db = session.db().await;
2407    let channel_id = ChannelId::from_proto(request.channel_id);
2408    let member_id = UserId::from_proto(request.user_id);
2409    let channel_member = db
2410        .set_channel_member_role(
2411            channel_id,
2412            session.user_id,
2413            member_id,
2414            request.role().into(),
2415        )
2416        .await?;
2417
2418    let channel = db.get_channel(channel_id, session.user_id).await?;
2419
2420    let mut update = proto::UpdateChannels::default();
2421    if channel_member.accepted {
2422        update.channel_permissions.push(proto::ChannelPermission {
2423            channel_id: channel.id.to_proto(),
2424            role: request.role,
2425        });
2426    }
2427
2428    for connection_id in session
2429        .connection_pool()
2430        .await
2431        .user_connection_ids(member_id)
2432    {
2433        session.peer.send(connection_id, update.clone())?;
2434    }
2435
2436    response.send(proto::Ack {})?;
2437    Ok(())
2438}
2439
2440async fn rename_channel(
2441    request: proto::RenameChannel,
2442    response: Response<proto::RenameChannel>,
2443    session: Session,
2444) -> Result<()> {
2445    let db = session.db().await;
2446    let channel_id = ChannelId::from_proto(request.channel_id);
2447    let channel = db
2448        .rename_channel(channel_id, session.user_id, &request.name)
2449        .await?;
2450
2451    let channel = proto::Channel {
2452        id: channel.id.to_proto(),
2453        name: channel.name,
2454        visibility: channel.visibility.into(),
2455    };
2456    response.send(proto::RenameChannelResponse {
2457        channel: Some(channel.clone()),
2458    })?;
2459    let mut update = proto::UpdateChannels::default();
2460    update.channels.push(channel);
2461
2462    let member_ids = db.get_channel_members(channel_id).await?;
2463
2464    let connection_pool = session.connection_pool().await;
2465    for member_id in member_ids {
2466        for connection_id in connection_pool.user_connection_ids(member_id) {
2467            session.peer.send(connection_id, update.clone())?;
2468        }
2469    }
2470
2471    Ok(())
2472}
2473
2474async fn link_channel(
2475    request: proto::LinkChannel,
2476    response: Response<proto::LinkChannel>,
2477    session: Session,
2478) -> Result<()> {
2479    let db = session.db().await;
2480    let channel_id = ChannelId::from_proto(request.channel_id);
2481    let to = ChannelId::from_proto(request.to);
2482    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2483
2484    let members = db.get_channel_members(to).await?;
2485    let connection_pool = session.connection_pool().await;
2486    let update = proto::UpdateChannels {
2487        channels: channels_to_send
2488            .channels
2489            .into_iter()
2490            .map(|channel| proto::Channel {
2491                id: channel.id.to_proto(),
2492                visibility: channel.visibility.into(),
2493                name: channel.name,
2494            })
2495            .collect(),
2496        insert_edge: channels_to_send.edges,
2497        ..Default::default()
2498    };
2499    for member_id in members {
2500        for connection_id in connection_pool.user_connection_ids(member_id) {
2501            session.peer.send(connection_id, update.clone())?;
2502        }
2503    }
2504
2505    response.send(Ack {})?;
2506
2507    Ok(())
2508}
2509
2510async fn unlink_channel(
2511    request: proto::UnlinkChannel,
2512    response: Response<proto::UnlinkChannel>,
2513    session: Session,
2514) -> Result<()> {
2515    let db = session.db().await;
2516    let channel_id = ChannelId::from_proto(request.channel_id);
2517    let from = ChannelId::from_proto(request.from);
2518
2519    db.unlink_channel(session.user_id, channel_id, from).await?;
2520
2521    let members = db.get_channel_members(from).await?;
2522
2523    let update = proto::UpdateChannels {
2524        delete_edge: vec![proto::ChannelEdge {
2525            channel_id: channel_id.to_proto(),
2526            parent_id: from.to_proto(),
2527        }],
2528        ..Default::default()
2529    };
2530    let connection_pool = session.connection_pool().await;
2531    for member_id in members {
2532        for connection_id in connection_pool.user_connection_ids(member_id) {
2533            session.peer.send(connection_id, update.clone())?;
2534        }
2535    }
2536
2537    response.send(Ack {})?;
2538
2539    Ok(())
2540}
2541
2542async fn move_channel(
2543    request: proto::MoveChannel,
2544    response: Response<proto::MoveChannel>,
2545    session: Session,
2546) -> Result<()> {
2547    let db = session.db().await;
2548    let channel_id = ChannelId::from_proto(request.channel_id);
2549    let from_parent = ChannelId::from_proto(request.from);
2550    let to = ChannelId::from_proto(request.to);
2551
2552    let channels_to_send = db
2553        .move_channel(session.user_id, channel_id, from_parent, to)
2554        .await?;
2555
2556    if channels_to_send.is_empty() {
2557        response.send(Ack {})?;
2558        return Ok(());
2559    }
2560
2561    let members_from = db.get_channel_members(from_parent).await?;
2562    let members_to = db.get_channel_members(to).await?;
2563
2564    let update = proto::UpdateChannels {
2565        delete_edge: vec![proto::ChannelEdge {
2566            channel_id: channel_id.to_proto(),
2567            parent_id: from_parent.to_proto(),
2568        }],
2569        ..Default::default()
2570    };
2571    let connection_pool = session.connection_pool().await;
2572    for member_id in members_from {
2573        for connection_id in connection_pool.user_connection_ids(member_id) {
2574            session.peer.send(connection_id, update.clone())?;
2575        }
2576    }
2577
2578    let update = proto::UpdateChannels {
2579        channels: channels_to_send
2580            .channels
2581            .into_iter()
2582            .map(|channel| proto::Channel {
2583                id: channel.id.to_proto(),
2584                visibility: channel.visibility.into(),
2585                name: channel.name,
2586            })
2587            .collect(),
2588        insert_edge: channels_to_send.edges,
2589        ..Default::default()
2590    };
2591    for member_id in members_to {
2592        for connection_id in connection_pool.user_connection_ids(member_id) {
2593            session.peer.send(connection_id, update.clone())?;
2594        }
2595    }
2596
2597    response.send(Ack {})?;
2598
2599    Ok(())
2600}
2601
2602async fn get_channel_members(
2603    request: proto::GetChannelMembers,
2604    response: Response<proto::GetChannelMembers>,
2605    session: Session,
2606) -> Result<()> {
2607    let db = session.db().await;
2608    let channel_id = ChannelId::from_proto(request.channel_id);
2609    let members = db
2610        .get_channel_participant_details(channel_id, session.user_id)
2611        .await?;
2612    response.send(proto::GetChannelMembersResponse { members })?;
2613    Ok(())
2614}
2615
2616async fn respond_to_channel_invite(
2617    request: proto::RespondToChannelInvite,
2618    response: Response<proto::RespondToChannelInvite>,
2619    session: Session,
2620) -> Result<()> {
2621    let db = session.db().await;
2622    let channel_id = ChannelId::from_proto(request.channel_id);
2623    let notifications = db
2624        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2625        .await?;
2626
2627    if request.accept {
2628        channel_membership_updated(db, channel_id, &session).await?;
2629    } else {
2630        let mut update = proto::UpdateChannels::default();
2631        update
2632            .remove_channel_invitations
2633            .push(channel_id.to_proto());
2634        session.peer.send(session.connection_id, update)?;
2635    }
2636
2637    send_notifications(
2638        &*session.connection_pool().await,
2639        &session.peer,
2640        notifications,
2641    );
2642    response.send(proto::Ack {})?;
2643
2644    Ok(())
2645}
2646
2647async fn channel_membership_updated(
2648    db: tokio::sync::MutexGuard<'_, DbHandle>,
2649    channel_id: ChannelId,
2650    session: &Session,
2651) -> Result<(), crate::Error> {
2652    let mut update = proto::UpdateChannels::default();
2653    update
2654        .remove_channel_invitations
2655        .push(channel_id.to_proto());
2656
2657    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2658    update.channels.extend(
2659        result
2660            .channels
2661            .channels
2662            .into_iter()
2663            .map(|channel| proto::Channel {
2664                id: channel.id.to_proto(),
2665                visibility: channel.visibility.into(),
2666                name: channel.name,
2667            }),
2668    );
2669    update.unseen_channel_messages = result.channel_messages;
2670    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2671    update.insert_edge = result.channels.edges;
2672    update
2673        .channel_participants
2674        .extend(
2675            result
2676                .channel_participants
2677                .into_iter()
2678                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2679                    channel_id: channel_id.to_proto(),
2680                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2681                }),
2682        );
2683    update
2684        .channel_permissions
2685        .extend(
2686            result
2687                .channels_with_admin_privileges
2688                .into_iter()
2689                .map(|channel_id| proto::ChannelPermission {
2690                    channel_id: channel_id.to_proto(),
2691                    role: proto::ChannelRole::Admin.into(),
2692                }),
2693        );
2694    session.peer.send(session.connection_id, update)?;
2695    Ok(())
2696}
2697
2698async fn join_channel(
2699    request: proto::JoinChannel,
2700    response: Response<proto::JoinChannel>,
2701    session: Session,
2702) -> Result<()> {
2703    let channel_id = ChannelId::from_proto(request.channel_id);
2704    join_channel_internal(channel_id, Box::new(response), session).await
2705}
2706
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// `JoinChannel` requests are answered with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call. This presumably resolves to `Response`'s own
        // (inherent) `send` defined elsewhere rather than recursing into this
        // trait method. Verify before simplifying to `self.send(..)`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// `JoinRoom` requests (routed through `join_channel_internal`) are answered with
// the same `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call. This presumably resolves to `Response`'s own
        // (inherent) `send` rather than recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2720
/// Shared implementation for joining a channel's room. It is used by any
/// handler whose response type implements `JoinChannelInternalResponse`.
///
/// The steps run in this order:
/// 1. Leave any room this session is currently in.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, its channel id and, when a LiveKit client is
///    configured and a token can be created, LiveKit connection info.
/// 4. If the database returned a joined channel, push the refreshed channel state
///    to this connection.
/// 5. Broadcast the room update to the room's participants.
/// 6. After the database guard is released, broadcast the channel's participant
///    list to the channel members.
/// 7. Refresh this user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` guard: it is either moved into
    // `channel_membership_updated` or dropped at the end of the block, so it is
    // no longer held when `channel_updated` runs below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `trace_err()?` presumably logs a token error and yields `None`, so a
        // failed token simply omits LiveKit info rather than failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        // The reply goes out before the membership and room broadcasts below.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // NOTE(review): presumably `Some` only when this join created a new
        // channel membership — confirm. `db` is moved into the callee here.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2779
2780async fn join_channel_buffer(
2781    request: proto::JoinChannelBuffer,
2782    response: Response<proto::JoinChannelBuffer>,
2783    session: Session,
2784) -> Result<()> {
2785    let db = session.db().await;
2786    let channel_id = ChannelId::from_proto(request.channel_id);
2787
2788    let open_response = db
2789        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2790        .await?;
2791
2792    let collaborators = open_response.collaborators.clone();
2793    response.send(open_response)?;
2794
2795    let update = UpdateChannelBufferCollaborators {
2796        channel_id: channel_id.to_proto(),
2797        collaborators: collaborators.clone(),
2798    };
2799    channel_buffer_updated(
2800        session.connection_id,
2801        collaborators
2802            .iter()
2803            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2804        &update,
2805        &session.peer,
2806    );
2807
2808    Ok(())
2809}
2810
2811async fn update_channel_buffer(
2812    request: proto::UpdateChannelBuffer,
2813    session: Session,
2814) -> Result<()> {
2815    let db = session.db().await;
2816    let channel_id = ChannelId::from_proto(request.channel_id);
2817
2818    let (collaborators, non_collaborators, epoch, version) = db
2819        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2820        .await?;
2821
2822    channel_buffer_updated(
2823        session.connection_id,
2824        collaborators,
2825        &proto::UpdateChannelBuffer {
2826            channel_id: channel_id.to_proto(),
2827            operations: request.operations,
2828        },
2829        &session.peer,
2830    );
2831
2832    let pool = &*session.connection_pool().await;
2833
2834    broadcast(
2835        None,
2836        non_collaborators
2837            .iter()
2838            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2839        |peer_id| {
2840            session.peer.send(
2841                peer_id.into(),
2842                proto::UpdateChannels {
2843                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2844                        channel_id: channel_id.to_proto(),
2845                        epoch: epoch as u64,
2846                        version: version.clone(),
2847                    }],
2848                    ..Default::default()
2849                },
2850            )
2851        },
2852    );
2853
2854    Ok(())
2855}
2856
2857async fn rejoin_channel_buffers(
2858    request: proto::RejoinChannelBuffers,
2859    response: Response<proto::RejoinChannelBuffers>,
2860    session: Session,
2861) -> Result<()> {
2862    let db = session.db().await;
2863    let buffers = db
2864        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2865        .await?;
2866
2867    for rejoined_buffer in &buffers {
2868        let collaborators_to_notify = rejoined_buffer
2869            .buffer
2870            .collaborators
2871            .iter()
2872            .filter_map(|c| Some(c.peer_id?.into()));
2873        channel_buffer_updated(
2874            session.connection_id,
2875            collaborators_to_notify,
2876            &proto::UpdateChannelBufferCollaborators {
2877                channel_id: rejoined_buffer.buffer.channel_id,
2878                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2879            },
2880            &session.peer,
2881        );
2882    }
2883
2884    response.send(proto::RejoinChannelBuffersResponse {
2885        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2886    })?;
2887
2888    Ok(())
2889}
2890
2891async fn leave_channel_buffer(
2892    request: proto::LeaveChannelBuffer,
2893    response: Response<proto::LeaveChannelBuffer>,
2894    session: Session,
2895) -> Result<()> {
2896    let db = session.db().await;
2897    let channel_id = ChannelId::from_proto(request.channel_id);
2898
2899    let left_buffer = db
2900        .leave_channel_buffer(channel_id, session.connection_id)
2901        .await?;
2902
2903    response.send(Ack {})?;
2904
2905    channel_buffer_updated(
2906        session.connection_id,
2907        left_buffer.connections,
2908        &proto::UpdateChannelBufferCollaborators {
2909            channel_id: channel_id.to_proto(),
2910            collaborators: left_buffer.collaborators,
2911        },
2912        &session.peer,
2913    );
2914
2915    Ok(())
2916}
2917
2918fn channel_buffer_updated<T: EnvelopedMessage>(
2919    sender_id: ConnectionId,
2920    collaborators: impl IntoIterator<Item = ConnectionId>,
2921    message: &T,
2922    peer: &Peer,
2923) {
2924    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2925        peer.send(peer_id.into(), message.clone())
2926    });
2927}
2928
2929fn send_notifications(
2930    connection_pool: &ConnectionPool,
2931    peer: &Peer,
2932    notifications: db::NotificationBatch,
2933) {
2934    for (user_id, notification) in notifications {
2935        for connection_id in connection_pool.user_connection_ids(user_id) {
2936            if let Err(error) = peer.send(
2937                connection_id,
2938                proto::AddNotification {
2939                    notification: Some(notification.clone()),
2940                },
2941            ) {
2942                tracing::error!(
2943                    "failed to send notification to {:?} {}",
2944                    connection_id,
2945                    error
2946                );
2947            }
2948        }
2949    }
2950}
2951
2952async fn send_channel_message(
2953    request: proto::SendChannelMessage,
2954    response: Response<proto::SendChannelMessage>,
2955    session: Session,
2956) -> Result<()> {
2957    // Validate the message body.
2958    let body = request.body.trim().to_string();
2959    if body.len() > MAX_MESSAGE_LEN {
2960        return Err(anyhow!("message is too long"))?;
2961    }
2962    if body.is_empty() {
2963        return Err(anyhow!("message can't be blank"))?;
2964    }
2965
2966    // TODO: adjust mentions if body is trimmed
2967
2968    let timestamp = OffsetDateTime::now_utc();
2969    let nonce = request
2970        .nonce
2971        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2972
2973    let channel_id = ChannelId::from_proto(request.channel_id);
2974    let CreatedChannelMessage {
2975        message_id,
2976        participant_connection_ids,
2977        channel_members,
2978        notifications,
2979    } = session
2980        .db()
2981        .await
2982        .create_channel_message(
2983            channel_id,
2984            session.user_id,
2985            &body,
2986            &request.mentions,
2987            timestamp,
2988            nonce.clone().into(),
2989        )
2990        .await?;
2991    let message = proto::ChannelMessage {
2992        sender_id: session.user_id.to_proto(),
2993        id: message_id.to_proto(),
2994        body,
2995        mentions: request.mentions,
2996        timestamp: timestamp.unix_timestamp() as u64,
2997        nonce: Some(nonce),
2998    };
2999    broadcast(
3000        Some(session.connection_id),
3001        participant_connection_ids,
3002        |connection| {
3003            session.peer.send(
3004                connection,
3005                proto::ChannelMessageSent {
3006                    channel_id: channel_id.to_proto(),
3007                    message: Some(message.clone()),
3008                },
3009            )
3010        },
3011    );
3012    response.send(proto::SendChannelMessageResponse {
3013        message: Some(message),
3014    })?;
3015
3016    let pool = &*session.connection_pool().await;
3017    broadcast(
3018        None,
3019        channel_members
3020            .iter()
3021            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3022        |peer_id| {
3023            session.peer.send(
3024                peer_id.into(),
3025                proto::UpdateChannels {
3026                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
3027                        channel_id: channel_id.to_proto(),
3028                        message_id: message_id.to_proto(),
3029                    }],
3030                    ..Default::default()
3031                },
3032            )
3033        },
3034    );
3035    send_notifications(pool, &session.peer, notifications);
3036
3037    Ok(())
3038}
3039
3040async fn remove_channel_message(
3041    request: proto::RemoveChannelMessage,
3042    response: Response<proto::RemoveChannelMessage>,
3043    session: Session,
3044) -> Result<()> {
3045    let channel_id = ChannelId::from_proto(request.channel_id);
3046    let message_id = MessageId::from_proto(request.message_id);
3047    let connection_ids = session
3048        .db()
3049        .await
3050        .remove_channel_message(channel_id, message_id, session.user_id)
3051        .await?;
3052    broadcast(Some(session.connection_id), connection_ids, |connection| {
3053        session.peer.send(connection, request.clone())
3054    });
3055    response.send(proto::Ack {})?;
3056    Ok(())
3057}
3058
3059async fn acknowledge_channel_message(
3060    request: proto::AckChannelMessage,
3061    session: Session,
3062) -> Result<()> {
3063    let channel_id = ChannelId::from_proto(request.channel_id);
3064    let message_id = MessageId::from_proto(request.message_id);
3065    let notifications = session
3066        .db()
3067        .await
3068        .observe_channel_message(channel_id, session.user_id, message_id)
3069        .await?;
3070    send_notifications(
3071        &*session.connection_pool().await,
3072        &session.peer,
3073        notifications,
3074    );
3075    Ok(())
3076}
3077
3078async fn acknowledge_buffer_version(
3079    request: proto::AckBufferOperation,
3080    session: Session,
3081) -> Result<()> {
3082    let buffer_id = BufferId::from_proto(request.buffer_id);
3083    session
3084        .db()
3085        .await
3086        .observe_buffer_version(
3087            buffer_id,
3088            session.user_id,
3089            request.epoch as i32,
3090            &request.version,
3091        )
3092        .await?;
3093    Ok(())
3094}
3095
3096async fn join_channel_chat(
3097    request: proto::JoinChannelChat,
3098    response: Response<proto::JoinChannelChat>,
3099    session: Session,
3100) -> Result<()> {
3101    let channel_id = ChannelId::from_proto(request.channel_id);
3102
3103    let db = session.db().await;
3104    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3105        .await?;
3106    let messages = db
3107        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3108        .await?;
3109    response.send(proto::JoinChannelChatResponse {
3110        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3111        messages,
3112    })?;
3113    Ok(())
3114}
3115
3116async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3117    let channel_id = ChannelId::from_proto(request.channel_id);
3118    session
3119        .db()
3120        .await
3121        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3122        .await?;
3123    Ok(())
3124}
3125
3126async fn get_channel_messages(
3127    request: proto::GetChannelMessages,
3128    response: Response<proto::GetChannelMessages>,
3129    session: Session,
3130) -> Result<()> {
3131    let channel_id = ChannelId::from_proto(request.channel_id);
3132    let messages = session
3133        .db()
3134        .await
3135        .get_channel_messages(
3136            channel_id,
3137            session.user_id,
3138            MESSAGE_COUNT_PER_PAGE,
3139            Some(MessageId::from_proto(request.before_message_id)),
3140        )
3141        .await?;
3142    response.send(proto::GetChannelMessagesResponse {
3143        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3144        messages,
3145    })?;
3146    Ok(())
3147}
3148
3149async fn get_channel_messages_by_id(
3150    request: proto::GetChannelMessagesById,
3151    response: Response<proto::GetChannelMessagesById>,
3152    session: Session,
3153) -> Result<()> {
3154    let message_ids = request
3155        .message_ids
3156        .iter()
3157        .map(|id| MessageId::from_proto(*id))
3158        .collect::<Vec<_>>();
3159    let messages = session
3160        .db()
3161        .await
3162        .get_channel_messages_by_id(session.user_id, &message_ids)
3163        .await?;
3164    response.send(proto::GetChannelMessagesResponse {
3165        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3166        messages,
3167    })?;
3168    Ok(())
3169}
3170
3171async fn get_notifications(
3172    request: proto::GetNotifications,
3173    response: Response<proto::GetNotifications>,
3174    session: Session,
3175) -> Result<()> {
3176    let notifications = session
3177        .db()
3178        .await
3179        .get_notifications(
3180            session.user_id,
3181            NOTIFICATION_COUNT_PER_PAGE,
3182            request
3183                .before_id
3184                .map(|id| db::NotificationId::from_proto(id)),
3185        )
3186        .await?;
3187    response.send(proto::GetNotificationsResponse {
3188        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3189        notifications,
3190    })?;
3191    Ok(())
3192}
3193
3194async fn mark_notification_as_read(
3195    request: proto::MarkNotificationRead,
3196    response: Response<proto::MarkNotificationRead>,
3197    session: Session,
3198) -> Result<()> {
3199    let database = &session.db().await;
3200    let notifications = database
3201        .mark_notification_as_read_by_id(
3202            session.user_id,
3203            NotificationId::from_proto(request.notification_id),
3204        )
3205        .await?;
3206    send_notifications(
3207        &*session.connection_pool().await,
3208        &session.peer,
3209        notifications,
3210    );
3211    response.send(proto::Ack {})?;
3212    Ok(())
3213}
3214
3215async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3216    let project_id = ProjectId::from_proto(request.project_id);
3217    let project_connection_ids = session
3218        .db()
3219        .await
3220        .project_connection_ids(project_id, session.connection_id)
3221        .await?;
3222    broadcast(
3223        Some(session.connection_id),
3224        project_connection_ids.iter().copied(),
3225        |connection_id| {
3226            session
3227                .peer
3228                .forward_send(session.connection_id, connection_id, request.clone())
3229        },
3230    );
3231    Ok(())
3232}
3233
3234async fn get_private_user_info(
3235    _request: proto::GetPrivateUserInfo,
3236    response: Response<proto::GetPrivateUserInfo>,
3237    session: Session,
3238) -> Result<()> {
3239    let db = session.db().await;
3240
3241    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3242    let user = db
3243        .get_user_by_id(session.user_id)
3244        .await?
3245        .ok_or_else(|| anyhow!("user not found"))?;
3246    let flags = db.get_user_flags(session.user_id).await?;
3247
3248    response.send(proto::GetPrivateUserInfoResponse {
3249        metrics_id,
3250        staff: user.admin,
3251        flags,
3252    })?;
3253    Ok(())
3254}
3255
3256fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3257    match message {
3258        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3259        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3260        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3261        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3262        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3263            code: frame.code.into(),
3264            reason: frame.reason,
3265        })),
3266    }
3267}
3268
3269fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3270    match message {
3271        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3272        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3273        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3274        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3275        AxumMessage::Close(frame) => {
3276            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3277                code: frame.code.into(),
3278                reason: frame.reason,
3279            }))
3280        }
3281    }
3282}
3283
3284fn build_initial_channels_update(
3285    channels: ChannelsForUser,
3286    channel_invites: Vec<db::Channel>,
3287) -> proto::UpdateChannels {
3288    let mut update = proto::UpdateChannels::default();
3289
3290    for channel in channels.channels.channels {
3291        update.channels.push(proto::Channel {
3292            id: channel.id.to_proto(),
3293            name: channel.name,
3294            visibility: channel.visibility.into(),
3295        });
3296    }
3297
3298    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3299    update.unseen_channel_messages = channels.channel_messages;
3300    update.insert_edge = channels.channels.edges;
3301
3302    for (channel_id, participants) in channels.channel_participants {
3303        update
3304            .channel_participants
3305            .push(proto::ChannelParticipants {
3306                channel_id: channel_id.to_proto(),
3307                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3308            });
3309    }
3310
3311    update
3312        .channel_permissions
3313        .extend(
3314            channels
3315                .channels_with_admin_privileges
3316                .into_iter()
3317                .map(|id| proto::ChannelPermission {
3318                    channel_id: id.to_proto(),
3319                    role: proto::ChannelRole::Admin.into(),
3320                }),
3321        );
3322
3323    for channel in channel_invites {
3324        update.channel_invitations.push(proto::Channel {
3325            id: channel.id.to_proto(),
3326            name: channel.name,
3327            // TODO: Visibility
3328            visibility: ChannelVisibility::Public.into(),
3329        });
3330    }
3331
3332    update
3333}
3334
3335fn build_initial_contacts_update(
3336    contacts: Vec<db::Contact>,
3337    pool: &ConnectionPool,
3338) -> proto::UpdateContacts {
3339    let mut update = proto::UpdateContacts::default();
3340
3341    for contact in contacts {
3342        match contact {
3343            db::Contact::Accepted { user_id, busy } => {
3344                update.contacts.push(contact_for_user(user_id, busy, &pool));
3345            }
3346            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3347            db::Contact::Incoming { user_id } => {
3348                update
3349                    .incoming_requests
3350                    .push(proto::IncomingContactRequest {
3351                        requester_id: user_id.to_proto(),
3352                    })
3353            }
3354        }
3355    }
3356
3357    update
3358}
3359
3360fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3361    proto::Contact {
3362        user_id: user_id.to_proto(),
3363        online: pool.is_user_online(user_id),
3364        busy,
3365    }
3366}
3367
3368fn room_updated(room: &proto::Room, peer: &Peer) {
3369    broadcast(
3370        None,
3371        room.participants
3372            .iter()
3373            .filter_map(|participant| Some(participant.peer_id?.into())),
3374        |peer_id| {
3375            peer.send(
3376                peer_id.into(),
3377                proto::RoomUpdated {
3378                    room: Some(room.clone()),
3379                },
3380            )
3381        },
3382    );
3383}
3384
3385fn channel_updated(
3386    channel_id: ChannelId,
3387    room: &proto::Room,
3388    channel_members: &[UserId],
3389    peer: &Peer,
3390    pool: &ConnectionPool,
3391) {
3392    let participants = room
3393        .participants
3394        .iter()
3395        .map(|p| p.user_id)
3396        .collect::<Vec<_>>();
3397
3398    broadcast(
3399        None,
3400        channel_members
3401            .iter()
3402            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3403        |peer_id| {
3404            peer.send(
3405                peer_id.into(),
3406                proto::UpdateChannels {
3407                    channel_participants: vec![proto::ChannelParticipants {
3408                        channel_id: channel_id.to_proto(),
3409                        participant_user_ids: participants.clone(),
3410                    }],
3411                    ..Default::default()
3412                },
3413            )
3414        },
3415    );
3416}
3417
3418async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3419    let db = session.db().await;
3420
3421    let contacts = db.get_contacts(user_id).await?;
3422    let busy = db.is_user_busy(user_id).await?;
3423
3424    let pool = session.connection_pool().await;
3425    let updated_contact = contact_for_user(user_id, busy, &pool);
3426    for contact in contacts {
3427        if let db::Contact::Accepted {
3428            user_id: contact_user_id,
3429            ..
3430        } = contact
3431        {
3432            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3433                session
3434                    .peer
3435                    .send(
3436                        contact_conn_id,
3437                        proto::UpdateContacts {
3438                            contacts: vec![updated_contact.clone()],
3439                            remove_contacts: Default::default(),
3440                            incoming_requests: Default::default(),
3441                            remove_incoming_requests: Default::default(),
3442                            outgoing_requests: Default::default(),
3443                            remove_outgoing_requests: Default::default(),
3444                        },
3445                    )
3446                    .trace_err();
3447            }
3448        }
3449    }
3450    Ok(())
3451}
3452
3453async fn leave_room_for_session(session: &Session) -> Result<()> {
3454    let mut contacts_to_update = HashSet::default();
3455
3456    let room_id;
3457    let canceled_calls_to_user_ids;
3458    let live_kit_room;
3459    let delete_live_kit_room;
3460    let room;
3461    let channel_members;
3462    let channel_id;
3463
3464    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3465        contacts_to_update.insert(session.user_id);
3466
3467        for project in left_room.left_projects.values() {
3468            project_left(project, session);
3469        }
3470
3471        room_id = RoomId::from_proto(left_room.room.id);
3472        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3473        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3474        delete_live_kit_room = left_room.deleted;
3475        room = mem::take(&mut left_room.room);
3476        channel_members = mem::take(&mut left_room.channel_members);
3477        channel_id = left_room.channel_id;
3478
3479        room_updated(&room, &session.peer);
3480    } else {
3481        return Ok(());
3482    }
3483
3484    if let Some(channel_id) = channel_id {
3485        channel_updated(
3486            channel_id,
3487            &room,
3488            &channel_members,
3489            &session.peer,
3490            &*session.connection_pool().await,
3491        );
3492    }
3493
3494    {
3495        let pool = session.connection_pool().await;
3496        for canceled_user_id in canceled_calls_to_user_ids {
3497            for connection_id in pool.user_connection_ids(canceled_user_id) {
3498                session
3499                    .peer
3500                    .send(
3501                        connection_id,
3502                        proto::CallCanceled {
3503                            room_id: room_id.to_proto(),
3504                        },
3505                    )
3506                    .trace_err();
3507            }
3508            contacts_to_update.insert(canceled_user_id);
3509        }
3510    }
3511
3512    for contact_user_id in contacts_to_update {
3513        update_user_contacts(contact_user_id, &session).await?;
3514    }
3515
3516    if let Some(live_kit) = session.live_kit_client.as_ref() {
3517        live_kit
3518            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3519            .await
3520            .trace_err();
3521
3522        if delete_live_kit_room {
3523            live_kit.delete_room(live_kit_room).await.trace_err();
3524        }
3525    }
3526
3527    Ok(())
3528}
3529
3530async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3531    let left_channel_buffers = session
3532        .db()
3533        .await
3534        .leave_channel_buffers(session.connection_id)
3535        .await?;
3536
3537    for left_buffer in left_channel_buffers {
3538        channel_buffer_updated(
3539            session.connection_id,
3540            left_buffer.connections,
3541            &proto::UpdateChannelBufferCollaborators {
3542                channel_id: left_buffer.channel_id.to_proto(),
3543                collaborators: left_buffer.collaborators,
3544            },
3545            &session.peer,
3546        );
3547    }
3548
3549    Ok(())
3550}
3551
3552fn project_left(project: &db::LeftProject, session: &Session) {
3553    for connection_id in &project.connection_ids {
3554        if project.host_user_id == session.user_id {
3555            session
3556                .peer
3557                .send(
3558                    *connection_id,
3559                    proto::UnshareProject {
3560                        project_id: project.id.to_proto(),
3561                    },
3562                )
3563                .trace_err();
3564        } else {
3565            session
3566                .peer
3567                .send(
3568                    *connection_id,
3569                    proto::RemoveProjectCollaborator {
3570                        project_id: project.id.to_proto(),
3571                        peer_id: Some(session.connection_id.into()),
3572                    },
3573                )
3574                .trace_err();
3575        }
3576    }
3577}
3578
3579pub trait ResultExt {
3580    type Ok;
3581
3582    fn trace_err(self) -> Option<Self::Ok>;
3583}
3584
3585impl<T, E> ResultExt for Result<T, E>
3586where
3587    E: std::fmt::Debug,
3588{
3589    type Ok = T;
3590
3591    fn trace_err(self) -> Option<T> {
3592        match self {
3593            Ok(value) => Some(value),
3594            Err(error) => {
3595                tracing::error!("{:?}", error);
3596                None
3597            }
3598        }
3599    }
3600}