rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage,
   7        Database, MessageId, ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
  75lazy_static! {
  76    static ref METRIC_CONNECTIONS: IntGauge =
  77        register_int_gauge!("connections", "number of connections").unwrap();
  78    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  79        "shared_projects",
  80        "number of open projects with one or more guests"
  81    )
  82    .unwrap();
  83}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
  88struct Response<R> {
  89    peer: Arc<Peer>,
  90    receipt: Receipt<R>,
  91    responded: Arc<AtomicBool>,
  92}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
 102#[derive(Clone)]
 103struct Session {
 104    user_id: UserId,
 105    connection_id: ConnectionId,
 106    db: Arc<tokio::sync::Mutex<DbHandle>>,
 107    peer: Arc<Peer>,
 108    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 109    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 110    executor: Executor,
 111}
 112
 113impl Session {
 114    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 115        #[cfg(test)]
 116        tokio::task::yield_now().await;
 117        let guard = self.db.lock().await;
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        guard
 121    }
 122
 123    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 124        #[cfg(test)]
 125        tokio::task::yield_now().await;
 126        let guard = self.connection_pool.lock();
 127        ConnectionPoolGuard {
 128            guard,
 129            _not_send: PhantomData,
 130        }
 131    }
 132}
 133
 134impl fmt::Debug for Session {
 135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 136        f.debug_struct("Session")
 137            .field("user_id", &self.user_id)
 138            .field("connection_id", &self.connection_id)
 139            .finish()
 140    }
 141}
 142
 143struct DbHandle(Arc<Database>);
 144
 145impl Deref for DbHandle {
 146    type Target = Database;
 147
 148    fn deref(&self) -> &Self::Target {
 149        self.0.as_ref()
 150    }
 151}
 152
 153pub struct Server {
 154    id: parking_lot::Mutex<ServerId>,
 155    peer: Arc<Peer>,
 156    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 157    app_state: Arc<AppState>,
 158    executor: Executor,
 159    handlers: HashMap<TypeId, MessageHandler>,
 160    teardown: watch::Sender<()>,
 161}
 162
 163pub(crate) struct ConnectionPoolGuard<'a> {
 164    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 165    _not_send: PhantomData<Rc<()>>,
 166}
 167
 168#[derive(Serialize)]
 169pub struct ServerSnapshot<'a> {
 170    peer: &'a Peer,
 171    #[serde(serialize_with = "serialize_deref")]
 172    connection_pool: ConnectionPoolGuard<'a>,
 173}
 174
 175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 176where
 177    S: Serializer,
 178    T: Deref<Target = U>,
 179    U: Serialize,
 180{
 181    Serialize::serialize(value.deref(), serializer)
 182}
 183
 184impl Server {
 185    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 186        let mut server = Self {
 187            id: parking_lot::Mutex::new(id),
 188            peer: Peer::new(id.0 as u32),
 189            app_state,
 190            executor,
 191            connection_pool: Default::default(),
 192            handlers: Default::default(),
 193            teardown: watch::channel(()).0,
 194        };
 195
 196        server
 197            .add_request_handler(ping)
 198            .add_request_handler(create_room)
 199            .add_request_handler(join_room)
 200            .add_request_handler(rejoin_room)
 201            .add_request_handler(leave_room)
 202            .add_request_handler(call)
 203            .add_request_handler(cancel_call)
 204            .add_message_handler(decline_call)
 205            .add_request_handler(update_participant_location)
 206            .add_request_handler(share_project)
 207            .add_message_handler(unshare_project)
 208            .add_request_handler(join_project)
 209            .add_message_handler(leave_project)
 210            .add_request_handler(update_project)
 211            .add_request_handler(update_worktree)
 212            .add_message_handler(start_language_server)
 213            .add_message_handler(update_language_server)
 214            .add_message_handler(update_diagnostic_summary)
 215            .add_message_handler(update_worktree_settings)
 216            .add_message_handler(refresh_inlay_hints)
 217            .add_request_handler(forward_project_request::<proto::GetHover>)
 218            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 220            .add_request_handler(forward_project_request::<proto::GetReferences>)
 221            .add_request_handler(forward_project_request::<proto::SearchProject>)
 222            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 223            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 226            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 227            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 228            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 229            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
 230            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 231            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 232            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 233            .add_request_handler(forward_project_request::<proto::PerformRename>)
 234            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 235            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 236            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 237            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 241            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 242            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 243            .add_request_handler(forward_project_request::<proto::InlayHints>)
 244            .add_message_handler(create_buffer_for_peer)
 245            .add_request_handler(update_buffer)
 246            .add_message_handler(update_buffer_file)
 247            .add_message_handler(buffer_reloaded)
 248            .add_message_handler(buffer_saved)
 249            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 250            .add_request_handler(get_users)
 251            .add_request_handler(fuzzy_search_users)
 252            .add_request_handler(request_contact)
 253            .add_request_handler(remove_contact)
 254            .add_request_handler(respond_to_contact_request)
 255            .add_request_handler(create_channel)
 256            .add_request_handler(delete_channel)
 257            .add_request_handler(invite_channel_member)
 258            .add_request_handler(remove_channel_member)
 259            .add_request_handler(set_channel_member_role)
 260            .add_request_handler(set_channel_visibility)
 261            .add_request_handler(rename_channel)
 262            .add_request_handler(join_channel_buffer)
 263            .add_request_handler(leave_channel_buffer)
 264            .add_message_handler(update_channel_buffer)
 265            .add_request_handler(rejoin_channel_buffers)
 266            .add_request_handler(get_channel_members)
 267            .add_request_handler(respond_to_channel_invite)
 268            .add_request_handler(join_channel)
 269            .add_request_handler(join_channel_chat)
 270            .add_message_handler(leave_channel_chat)
 271            .add_request_handler(send_channel_message)
 272            .add_request_handler(remove_channel_message)
 273            .add_request_handler(get_channel_messages)
 274            .add_request_handler(get_channel_messages_by_id)
 275            .add_request_handler(get_notifications)
 276            .add_request_handler(link_channel)
 277            .add_request_handler(unlink_channel)
 278            .add_request_handler(move_channel)
 279            .add_request_handler(follow)
 280            .add_message_handler(unfollow)
 281            .add_message_handler(update_followers)
 282            .add_message_handler(update_diff_base)
 283            .add_request_handler(get_private_user_info)
 284            .add_message_handler(acknowledge_channel_message)
 285            .add_message_handler(acknowledge_buffer_version);
 286
 287        Arc::new(server)
 288    }
 289
 290    pub async fn start(&self) -> Result<()> {
 291        let server_id = *self.id.lock();
 292        let app_state = self.app_state.clone();
 293        let peer = self.peer.clone();
 294        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 295        let pool = self.connection_pool.clone();
 296        let live_kit_client = self.app_state.live_kit_client.clone();
 297
 298        let span = info_span!("start server");
 299        self.executor.spawn_detached(
 300            async move {
 301                tracing::info!("waiting for cleanup timeout");
 302                timeout.await;
 303                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 304                if let Some((room_ids, channel_ids)) = app_state
 305                    .db
 306                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 307                    .await
 308                    .trace_err()
 309                {
 310                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 311                    tracing::info!(
 312                        stale_channel_buffer_count = channel_ids.len(),
 313                        "retrieved stale channel buffers"
 314                    );
 315
 316                    for channel_id in channel_ids {
 317                        if let Some(refreshed_channel_buffer) = app_state
 318                            .db
 319                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 320                            .await
 321                            .trace_err()
 322                        {
 323                            for connection_id in refreshed_channel_buffer.connection_ids {
 324                                peer.send(
 325                                    connection_id,
 326                                    proto::UpdateChannelBufferCollaborators {
 327                                        channel_id: channel_id.to_proto(),
 328                                        collaborators: refreshed_channel_buffer
 329                                            .collaborators
 330                                            .clone(),
 331                                    },
 332                                )
 333                                .trace_err();
 334                            }
 335                        }
 336                    }
 337
 338                    for room_id in room_ids {
 339                        let mut contacts_to_update = HashSet::default();
 340                        let mut canceled_calls_to_user_ids = Vec::new();
 341                        let mut live_kit_room = String::new();
 342                        let mut delete_live_kit_room = false;
 343
 344                        if let Some(mut refreshed_room) = app_state
 345                            .db
 346                            .clear_stale_room_participants(room_id, server_id)
 347                            .await
 348                            .trace_err()
 349                        {
 350                            tracing::info!(
 351                                room_id = room_id.0,
 352                                new_participant_count = refreshed_room.room.participants.len(),
 353                                "refreshed room"
 354                            );
 355                            room_updated(&refreshed_room.room, &peer);
 356                            if let Some(channel_id) = refreshed_room.channel_id {
 357                                channel_updated(
 358                                    channel_id,
 359                                    &refreshed_room.room,
 360                                    &refreshed_room.channel_members,
 361                                    &peer,
 362                                    &*pool.lock(),
 363                                );
 364                            }
 365                            contacts_to_update
 366                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 367                            contacts_to_update
 368                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 369                            canceled_calls_to_user_ids =
 370                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 371                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 372                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 373                        }
 374
 375                        {
 376                            let pool = pool.lock();
 377                            for canceled_user_id in canceled_calls_to_user_ids {
 378                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 379                                    peer.send(
 380                                        connection_id,
 381                                        proto::CallCanceled {
 382                                            room_id: room_id.to_proto(),
 383                                        },
 384                                    )
 385                                    .trace_err();
 386                                }
 387                            }
 388                        }
 389
 390                        for user_id in contacts_to_update {
 391                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 392                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 393                            if let Some((busy, contacts)) = busy.zip(contacts) {
 394                                let pool = pool.lock();
 395                                let updated_contact = contact_for_user(user_id, busy, &pool);
 396                                for contact in contacts {
 397                                    if let db::Contact::Accepted {
 398                                        user_id: contact_user_id,
 399                                        ..
 400                                    } = contact
 401                                    {
 402                                        for contact_conn_id in
 403                                            pool.user_connection_ids(contact_user_id)
 404                                        {
 405                                            peer.send(
 406                                                contact_conn_id,
 407                                                proto::UpdateContacts {
 408                                                    contacts: vec![updated_contact.clone()],
 409                                                    remove_contacts: Default::default(),
 410                                                    incoming_requests: Default::default(),
 411                                                    remove_incoming_requests: Default::default(),
 412                                                    outgoing_requests: Default::default(),
 413                                                    remove_outgoing_requests: Default::default(),
 414                                                },
 415                                            )
 416                                            .trace_err();
 417                                        }
 418                                    }
 419                                }
 420                            }
 421                        }
 422
 423                        if let Some(live_kit) = live_kit_client.as_ref() {
 424                            if delete_live_kit_room {
 425                                live_kit.delete_room(live_kit_room).await.trace_err();
 426                            }
 427                        }
 428                    }
 429                }
 430
 431                app_state
 432                    .db
 433                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 434                    .await
 435                    .trace_err();
 436            }
 437            .instrument(span),
 438        );
 439        Ok(())
 440    }
 441
 442    pub fn teardown(&self) {
 443        self.peer.teardown();
 444        self.connection_pool.lock().reset();
 445        let _ = self.teardown.send(());
 446    }
 447
 448    #[cfg(test)]
 449    pub fn reset(&self, id: ServerId) {
 450        self.teardown();
 451        *self.id.lock() = id;
 452        self.peer.reset(id.0 as u32);
 453    }
 454
 455    #[cfg(test)]
 456    pub fn id(&self) -> ServerId {
 457        *self.id.lock()
 458    }
 459
 460    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 461    where
 462        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 463        Fut: 'static + Send + Future<Output = Result<()>>,
 464        M: EnvelopedMessage,
 465    {
 466        let prev_handler = self.handlers.insert(
 467            TypeId::of::<M>(),
 468            Box::new(move |envelope, session| {
 469                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 470                let span = info_span!(
 471                    "handle message",
 472                    payload_type = envelope.payload_type_name()
 473                );
 474                span.in_scope(|| {
 475                    tracing::info!(
 476                        payload_type = envelope.payload_type_name(),
 477                        "message received"
 478                    );
 479                });
 480                let start_time = Instant::now();
 481                let future = (handler)(*envelope, session);
 482                async move {
 483                    let result = future.await;
 484                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 485                    match result {
 486                        Err(error) => {
 487                            tracing::error!(%error, ?duration_ms, "error handling message")
 488                        }
 489                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 490                    }
 491                }
 492                .instrument(span)
 493                .boxed()
 494            }),
 495        );
 496        if prev_handler.is_some() {
 497            panic!("registered a handler for the same message twice");
 498        }
 499        self
 500    }
 501
 502    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 503    where
 504        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 505        Fut: 'static + Send + Future<Output = Result<()>>,
 506        M: EnvelopedMessage,
 507    {
 508        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 509        self
 510    }
 511
 512    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 513    where
 514        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 515        Fut: Send + Future<Output = Result<()>>,
 516        M: RequestMessage,
 517    {
 518        let handler = Arc::new(handler);
 519        self.add_handler(move |envelope, session| {
 520            let receipt = envelope.receipt();
 521            let handler = handler.clone();
 522            async move {
 523                let peer = session.peer.clone();
 524                let responded = Arc::new(AtomicBool::default());
 525                let response = Response {
 526                    peer: peer.clone(),
 527                    responded: responded.clone(),
 528                    receipt,
 529                };
 530                match (handler)(envelope.payload, response, session).await {
 531                    Ok(()) => {
 532                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 533                            Ok(())
 534                        } else {
 535                            Err(anyhow!("handler did not send a response"))?
 536                        }
 537                    }
 538                    Err(error) => {
 539                        peer.respond_with_error(
 540                            receipt,
 541                            proto::Error {
 542                                message: error.to_string(),
 543                            },
 544                        )?;
 545                        Err(error)
 546                    }
 547                }
 548            }
 549        })
 550    }
 551
 552    pub fn handle_connection(
 553        self: &Arc<Self>,
 554        connection: Connection,
 555        address: String,
 556        user: User,
 557        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 558        executor: Executor,
 559    ) -> impl Future<Output = Result<()>> {
 560        let this = self.clone();
 561        let user_id = user.id;
 562        let login = user.github_login;
 563        let span = info_span!("handle connection", %user_id, %login, %address);
 564        let mut teardown = self.teardown.subscribe();
 565        async move {
 566            let (connection_id, handle_io, mut incoming_rx) = this
 567                .peer
 568                .add_connection(connection, {
 569                    let executor = executor.clone();
 570                    move |duration| executor.sleep(duration)
 571                });
 572
 573            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 574            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 575            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 576
 577            if let Some(send_connection_id) = send_connection_id.take() {
 578                let _ = send_connection_id.send(connection_id);
 579            }
 580
 581            if !user.connected_once {
 582                this.peer.send(connection_id, proto::ShowContacts {})?;
 583                this.app_state.db.set_user_connected_once(user_id, true).await?;
 584            }
 585
 586            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 587                this.app_state.db.get_contacts(user_id),
 588                this.app_state.db.get_channels_for_user(user_id),
 589                this.app_state.db.get_channel_invites_for_user(user_id),
 590            ).await?;
 591
 592            {
 593                let mut pool = this.connection_pool.lock();
 594                pool.add_connection(connection_id, user_id, user.admin);
 595                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 596                this.peer.send(connection_id, build_initial_channels_update(
 597                    channels_for_user,
 598                    channel_invites
 599                ))?;
 600            }
 601
 602            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 603                this.peer.send(connection_id, incoming_call)?;
 604            }
 605
 606            let session = Session {
 607                user_id,
 608                connection_id,
 609                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 610                peer: this.peer.clone(),
 611                connection_pool: this.connection_pool.clone(),
 612                live_kit_client: this.app_state.live_kit_client.clone(),
 613                executor: executor.clone(),
 614            };
 615            update_user_contacts(user_id, &session).await?;
 616
 617            let handle_io = handle_io.fuse();
 618            futures::pin_mut!(handle_io);
 619
 620            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 621            // This prevents deadlocks when e.g., client A performs a request to client B and
 622            // client B performs a request to client A. If both clients stop processing further
 623            // messages until their respective request completes, they won't have a chance to
 624            // respond to the other client's request and cause a deadlock.
 625            //
 626            // This arrangement ensures we will attempt to process earlier messages first, but fall
 627            // back to processing messages arrived later in the spirit of making progress.
 628            let mut foreground_message_handlers = FuturesUnordered::new();
 629            let concurrent_handlers = Arc::new(Semaphore::new(256));
 630            loop {
 631                let next_message = async {
 632                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 633                    let message = incoming_rx.next().await;
 634                    (permit, message)
 635                }.fuse();
 636                futures::pin_mut!(next_message);
 637                futures::select_biased! {
 638                    _ = teardown.changed().fuse() => return Ok(()),
 639                    result = handle_io => {
 640                        if let Err(error) = result {
 641                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 642                        }
 643                        break;
 644                    }
 645                    _ = foreground_message_handlers.next() => {}
 646                    next_message = next_message => {
 647                        let (permit, message) = next_message;
 648                        if let Some(message) = message {
 649                            let type_name = message.payload_type_name();
 650                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 651                            let span_enter = span.enter();
 652                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 653                                let is_background = message.is_background();
 654                                let handle_message = (handler)(message, session.clone());
 655                                drop(span_enter);
 656
 657                                let handle_message = async move {
 658                                    handle_message.await;
 659                                    drop(permit);
 660                                }.instrument(span);
 661                                if is_background {
 662                                    executor.spawn_detached(handle_message);
 663                                } else {
 664                                    foreground_message_handlers.push(handle_message);
 665                                }
 666                            } else {
 667                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 668                            }
 669                        } else {
 670                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 671                            break;
 672                        }
 673                    }
 674                }
 675            }
 676
 677            drop(foreground_message_handlers);
 678            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 679            if let Err(error) = connection_lost(session, teardown, executor).await {
 680                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 681            }
 682
 683            Ok(())
 684        }.instrument(span)
 685    }
 686
 687    pub async fn invite_code_redeemed(
 688        self: &Arc<Self>,
 689        inviter_id: UserId,
 690        invitee_id: UserId,
 691    ) -> Result<()> {
 692        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 693            if let Some(code) = &user.invite_code {
 694                let pool = self.connection_pool.lock();
 695                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 696                for connection_id in pool.user_connection_ids(inviter_id) {
 697                    self.peer.send(
 698                        connection_id,
 699                        proto::UpdateContacts {
 700                            contacts: vec![invitee_contact.clone()],
 701                            ..Default::default()
 702                        },
 703                    )?;
 704                    self.peer.send(
 705                        connection_id,
 706                        proto::UpdateInviteInfo {
 707                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 708                            count: user.invite_count as u32,
 709                        },
 710                    )?;
 711                }
 712            }
 713        }
 714        Ok(())
 715    }
 716
 717    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 718        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 719            if let Some(invite_code) = &user.invite_code {
 720                let pool = self.connection_pool.lock();
 721                for connection_id in pool.user_connection_ids(user_id) {
 722                    self.peer.send(
 723                        connection_id,
 724                        proto::UpdateInviteInfo {
 725                            url: format!(
 726                                "{}{}",
 727                                self.app_state.config.invite_link_prefix, invite_code
 728                            ),
 729                            count: user.invite_count as u32,
 730                        },
 731                    )?;
 732                }
 733            }
 734        }
 735        Ok(())
 736    }
 737
 738    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 739        ServerSnapshot {
 740            connection_pool: ConnectionPoolGuard {
 741                guard: self.connection_pool.lock(),
 742                _not_send: PhantomData,
 743            },
 744            peer: &self.peer,
 745        }
 746    }
 747}
 748
 749impl<'a> Deref for ConnectionPoolGuard<'a> {
 750    type Target = ConnectionPool;
 751
 752    fn deref(&self) -> &Self::Target {
 753        &*self.guard
 754    }
 755}
 756
 757impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 758    fn deref_mut(&mut self) -> &mut Self::Target {
 759        &mut *self.guard
 760    }
 761}
 762
 763impl<'a> Drop for ConnectionPoolGuard<'a> {
 764    fn drop(&mut self) {
 765        #[cfg(test)]
 766        self.check_invariants();
 767    }
 768}
 769
 770fn broadcast<F>(
 771    sender_id: Option<ConnectionId>,
 772    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 773    mut f: F,
 774) where
 775    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 776{
 777    for receiver_id in receiver_ids {
 778        if Some(receiver_id) != sender_id {
 779            if let Err(error) = f(receiver_id) {
 780                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 781            }
 782        }
 783    }
 784}
 785
 786lazy_static! {
 787    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 788}
 789
 790pub struct ProtocolVersion(u32);
 791
 792impl Header for ProtocolVersion {
 793    fn name() -> &'static HeaderName {
 794        &ZED_PROTOCOL_VERSION
 795    }
 796
 797    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 798    where
 799        Self: Sized,
 800        I: Iterator<Item = &'i axum::http::HeaderValue>,
 801    {
 802        let version = values
 803            .next()
 804            .ok_or_else(axum::headers::Error::invalid)?
 805            .to_str()
 806            .map_err(|_| axum::headers::Error::invalid())?
 807            .parse()
 808            .map_err(|_| axum::headers::Error::invalid())?;
 809        Ok(Self(version))
 810    }
 811
 812    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 813        values.extend([self.0.to_string().parse().unwrap()]);
 814    }
 815}
 816
 817pub fn routes(server: Arc<Server>) -> Router<Body> {
 818    Router::new()
 819        .route("/rpc", get(handle_websocket_request))
 820        .layer(
 821            ServiceBuilder::new()
 822                .layer(Extension(server.app_state.clone()))
 823                .layer(middleware::from_fn(auth::validate_header)),
 824        )
 825        .route("/metrics", get(handle_metrics))
 826        .layer(Extension(server))
 827}
 828
 829pub async fn handle_websocket_request(
 830    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 831    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 832    Extension(server): Extension<Arc<Server>>,
 833    Extension(user): Extension<User>,
 834    ws: WebSocketUpgrade,
 835) -> axum::response::Response {
 836    if protocol_version != rpc::PROTOCOL_VERSION {
 837        return (
 838            StatusCode::UPGRADE_REQUIRED,
 839            "client must be upgraded".to_string(),
 840        )
 841            .into_response();
 842    }
 843    let socket_address = socket_address.to_string();
 844    ws.on_upgrade(move |socket| {
 845        use util::ResultExt;
 846        let socket = socket
 847            .map_ok(to_tungstenite_message)
 848            .err_into()
 849            .with(|message| async move { Ok(to_axum_message(message)) });
 850        let connection = Connection::new(Box::pin(socket));
 851        async move {
 852            server
 853                .handle_connection(connection, socket_address, user, None, Executor::Production)
 854                .await
 855                .log_err();
 856        }
 857    })
 858}
 859
 860pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 861    let connections = server
 862        .connection_pool
 863        .lock()
 864        .connections()
 865        .filter(|connection| !connection.admin)
 866        .count();
 867
 868    METRIC_CONNECTIONS.set(connections as _);
 869
 870    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 871    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 872
 873    let encoder = prometheus::TextEncoder::new();
 874    let metric_families = prometheus::gather();
 875    let encoded_metrics = encoder
 876        .encode_to_string(&metric_families)
 877        .map_err(|err| anyhow!("{}", err))?;
 878    Ok(encoded_metrics)
 879}
 880
/// Cleans up after a connection that has gone away.
///
/// Immediately:
/// - the peer connection is disconnected;
/// - the connection is removed from the connection pool (an error here is
///   returned);
/// - the database is told the connection was lost (errors are traced and
///   ignored).
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses on `executor`. The session then leaves its
///   room and its channel buffers. If the user has no other connection online,
///   `decline_call(None, user_id)` is issued and any returned room is passed to
///   `room_updated`. Finally, contacts are refreshed for the user.
/// - `teardown` reports a change. Nothing further is done.
///   NOTE(review): presumably signalled when the server is shutting down;
///   confirm at the sender.
///
/// # Errors
/// Fails if the connection cannot be removed from the pool, or if
/// `update_user_contacts` fails after the timeout.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so if the timeout
    // and teardown are both ready, the timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not reconnect in time: release everything it held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 926
 927async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 928    response.send(proto::Ack {})?;
 929    Ok(())
 930}
 931
 932async fn create_room(
 933    _request: proto::CreateRoom,
 934    response: Response<proto::CreateRoom>,
 935    session: Session,
 936) -> Result<()> {
 937    let live_kit_room = nanoid::nanoid!(30);
 938
 939    let live_kit_connection_info = {
 940        let live_kit_room = live_kit_room.clone();
 941        let live_kit = session.live_kit_client.as_ref();
 942
 943        util::async_iife!({
 944            let live_kit = live_kit?;
 945
 946            let token = live_kit
 947                .room_token(&live_kit_room, &session.user_id.to_string())
 948                .trace_err()?;
 949
 950            Some(proto::LiveKitConnectionInfo {
 951                server_url: live_kit.url().into(),
 952                token,
 953            })
 954        })
 955    }
 956    .await;
 957
 958    let room = session
 959        .db()
 960        .await
 961        .create_room(
 962            session.user_id,
 963            session.connection_id,
 964            &live_kit_room,
 965            RELEASE_CHANNEL_NAME.as_str(),
 966        )
 967        .await?;
 968
 969    response.send(proto::CreateRoomResponse {
 970        room: Some(room.clone()),
 971        live_kit_connection_info,
 972    })?;
 973
 974    update_user_contacts(session.user_id, &session).await?;
 975    Ok(())
 976}
 977
 978async fn join_room(
 979    request: proto::JoinRoom,
 980    response: Response<proto::JoinRoom>,
 981    session: Session,
 982) -> Result<()> {
 983    let room_id = RoomId::from_proto(request.id);
 984
 985    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 986
 987    if let Some(channel_id) = channel_id {
 988        return join_channel_internal(channel_id, Box::new(response), session).await;
 989    }
 990
 991    let joined_room = {
 992        let room = session
 993            .db()
 994            .await
 995            .join_room(
 996                room_id,
 997                session.user_id,
 998                session.connection_id,
 999                RELEASE_CHANNEL_NAME.as_str(),
1000            )
1001            .await?;
1002        room_updated(&room.room, &session.peer);
1003        room.into_inner()
1004    };
1005
1006    for connection_id in session
1007        .connection_pool()
1008        .await
1009        .user_connection_ids(session.user_id)
1010    {
1011        session
1012            .peer
1013            .send(
1014                connection_id,
1015                proto::CallCanceled {
1016                    room_id: room_id.to_proto(),
1017                },
1018            )
1019            .trace_err();
1020    }
1021
1022    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1023        if let Some(token) = live_kit
1024            .room_token(
1025                &joined_room.room.live_kit_room,
1026                &session.user_id.to_string(),
1027            )
1028            .trace_err()
1029        {
1030            Some(proto::LiveKitConnectionInfo {
1031                server_url: live_kit.url().into(),
1032                token,
1033            })
1034        } else {
1035            None
1036        }
1037    } else {
1038        None
1039    };
1040
1041    response.send(proto::JoinRoomResponse {
1042        room: Some(joined_room.room),
1043        channel_id: None,
1044        live_kit_connection_info,
1045    })?;
1046
1047    update_user_contacts(session.user_id, &session).await?;
1048    Ok(())
1049}
1050
1051async fn rejoin_room(
1052    request: proto::RejoinRoom,
1053    response: Response<proto::RejoinRoom>,
1054    session: Session,
1055) -> Result<()> {
1056    let room;
1057    let channel_id;
1058    let channel_members;
1059    {
1060        let mut rejoined_room = session
1061            .db()
1062            .await
1063            .rejoin_room(request, session.user_id, session.connection_id)
1064            .await?;
1065
1066        response.send(proto::RejoinRoomResponse {
1067            room: Some(rejoined_room.room.clone()),
1068            reshared_projects: rejoined_room
1069                .reshared_projects
1070                .iter()
1071                .map(|project| proto::ResharedProject {
1072                    id: project.id.to_proto(),
1073                    collaborators: project
1074                        .collaborators
1075                        .iter()
1076                        .map(|collaborator| collaborator.to_proto())
1077                        .collect(),
1078                })
1079                .collect(),
1080            rejoined_projects: rejoined_room
1081                .rejoined_projects
1082                .iter()
1083                .map(|rejoined_project| proto::RejoinedProject {
1084                    id: rejoined_project.id.to_proto(),
1085                    worktrees: rejoined_project
1086                        .worktrees
1087                        .iter()
1088                        .map(|worktree| proto::WorktreeMetadata {
1089                            id: worktree.id,
1090                            root_name: worktree.root_name.clone(),
1091                            visible: worktree.visible,
1092                            abs_path: worktree.abs_path.clone(),
1093                        })
1094                        .collect(),
1095                    collaborators: rejoined_project
1096                        .collaborators
1097                        .iter()
1098                        .map(|collaborator| collaborator.to_proto())
1099                        .collect(),
1100                    language_servers: rejoined_project.language_servers.clone(),
1101                })
1102                .collect(),
1103        })?;
1104        room_updated(&rejoined_room.room, &session.peer);
1105
1106        for project in &rejoined_room.reshared_projects {
1107            for collaborator in &project.collaborators {
1108                session
1109                    .peer
1110                    .send(
1111                        collaborator.connection_id,
1112                        proto::UpdateProjectCollaborator {
1113                            project_id: project.id.to_proto(),
1114                            old_peer_id: Some(project.old_connection_id.into()),
1115                            new_peer_id: Some(session.connection_id.into()),
1116                        },
1117                    )
1118                    .trace_err();
1119            }
1120
1121            broadcast(
1122                Some(session.connection_id),
1123                project
1124                    .collaborators
1125                    .iter()
1126                    .map(|collaborator| collaborator.connection_id),
1127                |connection_id| {
1128                    session.peer.forward_send(
1129                        session.connection_id,
1130                        connection_id,
1131                        proto::UpdateProject {
1132                            project_id: project.id.to_proto(),
1133                            worktrees: project.worktrees.clone(),
1134                        },
1135                    )
1136                },
1137            );
1138        }
1139
1140        for project in &rejoined_room.rejoined_projects {
1141            for collaborator in &project.collaborators {
1142                session
1143                    .peer
1144                    .send(
1145                        collaborator.connection_id,
1146                        proto::UpdateProjectCollaborator {
1147                            project_id: project.id.to_proto(),
1148                            old_peer_id: Some(project.old_connection_id.into()),
1149                            new_peer_id: Some(session.connection_id.into()),
1150                        },
1151                    )
1152                    .trace_err();
1153            }
1154        }
1155
1156        for project in &mut rejoined_room.rejoined_projects {
1157            for worktree in mem::take(&mut project.worktrees) {
1158                #[cfg(any(test, feature = "test-support"))]
1159                const MAX_CHUNK_SIZE: usize = 2;
1160                #[cfg(not(any(test, feature = "test-support")))]
1161                const MAX_CHUNK_SIZE: usize = 256;
1162
1163                // Stream this worktree's entries.
1164                let message = proto::UpdateWorktree {
1165                    project_id: project.id.to_proto(),
1166                    worktree_id: worktree.id,
1167                    abs_path: worktree.abs_path.clone(),
1168                    root_name: worktree.root_name,
1169                    updated_entries: worktree.updated_entries,
1170                    removed_entries: worktree.removed_entries,
1171                    scan_id: worktree.scan_id,
1172                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1173                    updated_repositories: worktree.updated_repositories,
1174                    removed_repositories: worktree.removed_repositories,
1175                };
1176                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1177                    session.peer.send(session.connection_id, update.clone())?;
1178                }
1179
1180                // Stream this worktree's diagnostics.
1181                for summary in worktree.diagnostic_summaries {
1182                    session.peer.send(
1183                        session.connection_id,
1184                        proto::UpdateDiagnosticSummary {
1185                            project_id: project.id.to_proto(),
1186                            worktree_id: worktree.id,
1187                            summary: Some(summary),
1188                        },
1189                    )?;
1190                }
1191
1192                for settings_file in worktree.settings_files {
1193                    session.peer.send(
1194                        session.connection_id,
1195                        proto::UpdateWorktreeSettings {
1196                            project_id: project.id.to_proto(),
1197                            worktree_id: worktree.id,
1198                            path: settings_file.path,
1199                            content: Some(settings_file.content),
1200                        },
1201                    )?;
1202                }
1203            }
1204
1205            for language_server in &project.language_servers {
1206                session.peer.send(
1207                    session.connection_id,
1208                    proto::UpdateLanguageServer {
1209                        project_id: project.id.to_proto(),
1210                        language_server_id: language_server.id,
1211                        variant: Some(
1212                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1213                                proto::LspDiskBasedDiagnosticsUpdated {},
1214                            ),
1215                        ),
1216                    },
1217                )?;
1218            }
1219        }
1220
1221        let rejoined_room = rejoined_room.into_inner();
1222
1223        room = rejoined_room.room;
1224        channel_id = rejoined_room.channel_id;
1225        channel_members = rejoined_room.channel_members;
1226    }
1227
1228    if let Some(channel_id) = channel_id {
1229        channel_updated(
1230            channel_id,
1231            &room,
1232            &channel_members,
1233            &session.peer,
1234            &*session.connection_pool().await,
1235        );
1236    }
1237
1238    update_user_contacts(session.user_id, &session).await?;
1239    Ok(())
1240}
1241
1242async fn leave_room(
1243    _: proto::LeaveRoom,
1244    response: Response<proto::LeaveRoom>,
1245    session: Session,
1246) -> Result<()> {
1247    leave_room_for_session(&session).await?;
1248    response.send(proto::Ack {})?;
1249    Ok(())
1250}
1251
1252async fn call(
1253    request: proto::Call,
1254    response: Response<proto::Call>,
1255    session: Session,
1256) -> Result<()> {
1257    let room_id = RoomId::from_proto(request.room_id);
1258    let calling_user_id = session.user_id;
1259    let calling_connection_id = session.connection_id;
1260    let called_user_id = UserId::from_proto(request.called_user_id);
1261    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1262    if !session
1263        .db()
1264        .await
1265        .has_contact(calling_user_id, called_user_id)
1266        .await?
1267    {
1268        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1269    }
1270
1271    let incoming_call = {
1272        let (room, incoming_call) = &mut *session
1273            .db()
1274            .await
1275            .call(
1276                room_id,
1277                calling_user_id,
1278                calling_connection_id,
1279                called_user_id,
1280                initial_project_id,
1281            )
1282            .await?;
1283        room_updated(&room, &session.peer);
1284        mem::take(incoming_call)
1285    };
1286    update_user_contacts(called_user_id, &session).await?;
1287
1288    let mut calls = session
1289        .connection_pool()
1290        .await
1291        .user_connection_ids(called_user_id)
1292        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1293        .collect::<FuturesUnordered<_>>();
1294
1295    while let Some(call_response) = calls.next().await {
1296        match call_response.as_ref() {
1297            Ok(_) => {
1298                response.send(proto::Ack {})?;
1299                return Ok(());
1300            }
1301            Err(_) => {
1302                call_response.trace_err();
1303            }
1304        }
1305    }
1306
1307    {
1308        let room = session
1309            .db()
1310            .await
1311            .call_failed(room_id, called_user_id)
1312            .await?;
1313        room_updated(&room, &session.peer);
1314    }
1315    update_user_contacts(called_user_id, &session).await?;
1316
1317    Err(anyhow!("failed to ring user"))?
1318}
1319
1320async fn cancel_call(
1321    request: proto::CancelCall,
1322    response: Response<proto::CancelCall>,
1323    session: Session,
1324) -> Result<()> {
1325    let called_user_id = UserId::from_proto(request.called_user_id);
1326    let room_id = RoomId::from_proto(request.room_id);
1327    {
1328        let room = session
1329            .db()
1330            .await
1331            .cancel_call(room_id, session.connection_id, called_user_id)
1332            .await?;
1333        room_updated(&room, &session.peer);
1334    }
1335
1336    for connection_id in session
1337        .connection_pool()
1338        .await
1339        .user_connection_ids(called_user_id)
1340    {
1341        session
1342            .peer
1343            .send(
1344                connection_id,
1345                proto::CallCanceled {
1346                    room_id: room_id.to_proto(),
1347                },
1348            )
1349            .trace_err();
1350    }
1351    response.send(proto::Ack {})?;
1352
1353    update_user_contacts(called_user_id, &session).await?;
1354    Ok(())
1355}
1356
1357async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1358    let room_id = RoomId::from_proto(message.room_id);
1359    {
1360        let room = session
1361            .db()
1362            .await
1363            .decline_call(Some(room_id), session.user_id)
1364            .await?
1365            .ok_or_else(|| anyhow!("failed to decline call"))?;
1366        room_updated(&room, &session.peer);
1367    }
1368
1369    for connection_id in session
1370        .connection_pool()
1371        .await
1372        .user_connection_ids(session.user_id)
1373    {
1374        session
1375            .peer
1376            .send(
1377                connection_id,
1378                proto::CallCanceled {
1379                    room_id: room_id.to_proto(),
1380                },
1381            )
1382            .trace_err();
1383    }
1384    update_user_contacts(session.user_id, &session).await?;
1385    Ok(())
1386}
1387
1388async fn update_participant_location(
1389    request: proto::UpdateParticipantLocation,
1390    response: Response<proto::UpdateParticipantLocation>,
1391    session: Session,
1392) -> Result<()> {
1393    let room_id = RoomId::from_proto(request.room_id);
1394    let location = request
1395        .location
1396        .ok_or_else(|| anyhow!("invalid location"))?;
1397
1398    let db = session.db().await;
1399    let room = db
1400        .update_room_participant_location(room_id, session.connection_id, location)
1401        .await?;
1402
1403    room_updated(&room, &session.peer);
1404    response.send(proto::Ack {})?;
1405    Ok(())
1406}
1407
1408async fn share_project(
1409    request: proto::ShareProject,
1410    response: Response<proto::ShareProject>,
1411    session: Session,
1412) -> Result<()> {
1413    let (project_id, room) = &*session
1414        .db()
1415        .await
1416        .share_project(
1417            RoomId::from_proto(request.room_id),
1418            session.connection_id,
1419            &request.worktrees,
1420        )
1421        .await?;
1422    response.send(proto::ShareProjectResponse {
1423        project_id: project_id.to_proto(),
1424    })?;
1425    room_updated(&room, &session.peer);
1426
1427    Ok(())
1428}
1429
1430async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1431    let project_id = ProjectId::from_proto(message.project_id);
1432
1433    let (room, guest_connection_ids) = &*session
1434        .db()
1435        .await
1436        .unshare_project(project_id, session.connection_id)
1437        .await?;
1438
1439    broadcast(
1440        Some(session.connection_id),
1441        guest_connection_ids.iter().copied(),
1442        |conn_id| session.peer.send(conn_id, message.clone()),
1443    );
1444    room_updated(&room, &session.peer);
1445
1446    Ok(())
1447}
1448
1449async fn join_project(
1450    request: proto::JoinProject,
1451    response: Response<proto::JoinProject>,
1452    session: Session,
1453) -> Result<()> {
1454    let project_id = ProjectId::from_proto(request.project_id);
1455    let guest_user_id = session.user_id;
1456
1457    tracing::info!(%project_id, "join project");
1458
1459    let (project, replica_id) = &mut *session
1460        .db()
1461        .await
1462        .join_project(project_id, session.connection_id)
1463        .await?;
1464
1465    let collaborators = project
1466        .collaborators
1467        .iter()
1468        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1469        .map(|collaborator| collaborator.to_proto())
1470        .collect::<Vec<_>>();
1471
1472    let worktrees = project
1473        .worktrees
1474        .iter()
1475        .map(|(id, worktree)| proto::WorktreeMetadata {
1476            id: *id,
1477            root_name: worktree.root_name.clone(),
1478            visible: worktree.visible,
1479            abs_path: worktree.abs_path.clone(),
1480        })
1481        .collect::<Vec<_>>();
1482
1483    for collaborator in &collaborators {
1484        session
1485            .peer
1486            .send(
1487                collaborator.peer_id.unwrap().into(),
1488                proto::AddProjectCollaborator {
1489                    project_id: project_id.to_proto(),
1490                    collaborator: Some(proto::Collaborator {
1491                        peer_id: Some(session.connection_id.into()),
1492                        replica_id: replica_id.0 as u32,
1493                        user_id: guest_user_id.to_proto(),
1494                    }),
1495                },
1496            )
1497            .trace_err();
1498    }
1499
1500    // First, we send the metadata associated with each worktree.
1501    response.send(proto::JoinProjectResponse {
1502        worktrees: worktrees.clone(),
1503        replica_id: replica_id.0 as u32,
1504        collaborators: collaborators.clone(),
1505        language_servers: project.language_servers.clone(),
1506    })?;
1507
1508    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1509        #[cfg(any(test, feature = "test-support"))]
1510        const MAX_CHUNK_SIZE: usize = 2;
1511        #[cfg(not(any(test, feature = "test-support")))]
1512        const MAX_CHUNK_SIZE: usize = 256;
1513
1514        // Stream this worktree's entries.
1515        let message = proto::UpdateWorktree {
1516            project_id: project_id.to_proto(),
1517            worktree_id,
1518            abs_path: worktree.abs_path.clone(),
1519            root_name: worktree.root_name,
1520            updated_entries: worktree.entries,
1521            removed_entries: Default::default(),
1522            scan_id: worktree.scan_id,
1523            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1524            updated_repositories: worktree.repository_entries.into_values().collect(),
1525            removed_repositories: Default::default(),
1526        };
1527        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1528            session.peer.send(session.connection_id, update.clone())?;
1529        }
1530
1531        // Stream this worktree's diagnostics.
1532        for summary in worktree.diagnostic_summaries {
1533            session.peer.send(
1534                session.connection_id,
1535                proto::UpdateDiagnosticSummary {
1536                    project_id: project_id.to_proto(),
1537                    worktree_id: worktree.id,
1538                    summary: Some(summary),
1539                },
1540            )?;
1541        }
1542
1543        for settings_file in worktree.settings_files {
1544            session.peer.send(
1545                session.connection_id,
1546                proto::UpdateWorktreeSettings {
1547                    project_id: project_id.to_proto(),
1548                    worktree_id: worktree.id,
1549                    path: settings_file.path,
1550                    content: Some(settings_file.content),
1551                },
1552            )?;
1553        }
1554    }
1555
1556    for language_server in &project.language_servers {
1557        session.peer.send(
1558            session.connection_id,
1559            proto::UpdateLanguageServer {
1560                project_id: project_id.to_proto(),
1561                language_server_id: language_server.id,
1562                variant: Some(
1563                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1564                        proto::LspDiskBasedDiagnosticsUpdated {},
1565                    ),
1566                ),
1567            },
1568        )?;
1569    }
1570
1571    Ok(())
1572}
1573
1574async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1575    let sender_id = session.connection_id;
1576    let project_id = ProjectId::from_proto(request.project_id);
1577
1578    let (room, project) = &*session
1579        .db()
1580        .await
1581        .leave_project(project_id, sender_id)
1582        .await?;
1583    tracing::info!(
1584        %project_id,
1585        host_user_id = %project.host_user_id,
1586        host_connection_id = %project.host_connection_id,
1587        "leave project"
1588    );
1589
1590    project_left(&project, &session);
1591    room_updated(&room, &session.peer);
1592
1593    Ok(())
1594}
1595
1596async fn update_project(
1597    request: proto::UpdateProject,
1598    response: Response<proto::UpdateProject>,
1599    session: Session,
1600) -> Result<()> {
1601    let project_id = ProjectId::from_proto(request.project_id);
1602    let (room, guest_connection_ids) = &*session
1603        .db()
1604        .await
1605        .update_project(project_id, session.connection_id, &request.worktrees)
1606        .await?;
1607    broadcast(
1608        Some(session.connection_id),
1609        guest_connection_ids.iter().copied(),
1610        |connection_id| {
1611            session
1612                .peer
1613                .forward_send(session.connection_id, connection_id, request.clone())
1614        },
1615    );
1616    room_updated(&room, &session.peer);
1617    response.send(proto::Ack {})?;
1618
1619    Ok(())
1620}
1621
1622async fn update_worktree(
1623    request: proto::UpdateWorktree,
1624    response: Response<proto::UpdateWorktree>,
1625    session: Session,
1626) -> Result<()> {
1627    let guest_connection_ids = session
1628        .db()
1629        .await
1630        .update_worktree(&request, session.connection_id)
1631        .await?;
1632
1633    broadcast(
1634        Some(session.connection_id),
1635        guest_connection_ids.iter().copied(),
1636        |connection_id| {
1637            session
1638                .peer
1639                .forward_send(session.connection_id, connection_id, request.clone())
1640        },
1641    );
1642    response.send(proto::Ack {})?;
1643    Ok(())
1644}
1645
1646async fn update_diagnostic_summary(
1647    message: proto::UpdateDiagnosticSummary,
1648    session: Session,
1649) -> Result<()> {
1650    let guest_connection_ids = session
1651        .db()
1652        .await
1653        .update_diagnostic_summary(&message, session.connection_id)
1654        .await?;
1655
1656    broadcast(
1657        Some(session.connection_id),
1658        guest_connection_ids.iter().copied(),
1659        |connection_id| {
1660            session
1661                .peer
1662                .forward_send(session.connection_id, connection_id, message.clone())
1663        },
1664    );
1665
1666    Ok(())
1667}
1668
1669async fn update_worktree_settings(
1670    message: proto::UpdateWorktreeSettings,
1671    session: Session,
1672) -> Result<()> {
1673    let guest_connection_ids = session
1674        .db()
1675        .await
1676        .update_worktree_settings(&message, session.connection_id)
1677        .await?;
1678
1679    broadcast(
1680        Some(session.connection_id),
1681        guest_connection_ids.iter().copied(),
1682        |connection_id| {
1683            session
1684                .peer
1685                .forward_send(session.connection_id, connection_id, message.clone())
1686        },
1687    );
1688
1689    Ok(())
1690}
1691
1692async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1693    broadcast_project_message(request.project_id, request, session).await
1694}
1695
1696async fn start_language_server(
1697    request: proto::StartLanguageServer,
1698    session: Session,
1699) -> Result<()> {
1700    let guest_connection_ids = session
1701        .db()
1702        .await
1703        .start_language_server(&request, session.connection_id)
1704        .await?;
1705
1706    broadcast(
1707        Some(session.connection_id),
1708        guest_connection_ids.iter().copied(),
1709        |connection_id| {
1710            session
1711                .peer
1712                .forward_send(session.connection_id, connection_id, request.clone())
1713        },
1714    );
1715    Ok(())
1716}
1717
1718async fn update_language_server(
1719    request: proto::UpdateLanguageServer,
1720    session: Session,
1721) -> Result<()> {
1722    session.executor.record_backtrace();
1723    let project_id = ProjectId::from_proto(request.project_id);
1724    let project_connection_ids = session
1725        .db()
1726        .await
1727        .project_connection_ids(project_id, session.connection_id)
1728        .await?;
1729    broadcast(
1730        Some(session.connection_id),
1731        project_connection_ids.iter().copied(),
1732        |connection_id| {
1733            session
1734                .peer
1735                .forward_send(session.connection_id, connection_id, request.clone())
1736        },
1737    );
1738    Ok(())
1739}
1740
1741async fn forward_project_request<T>(
1742    request: T,
1743    response: Response<T>,
1744    session: Session,
1745) -> Result<()>
1746where
1747    T: EntityMessage + RequestMessage,
1748{
1749    session.executor.record_backtrace();
1750    let project_id = ProjectId::from_proto(request.remote_entity_id());
1751    let host_connection_id = {
1752        let collaborators = session
1753            .db()
1754            .await
1755            .project_collaborators(project_id, session.connection_id)
1756            .await?;
1757        collaborators
1758            .iter()
1759            .find(|collaborator| collaborator.is_host)
1760            .ok_or_else(|| anyhow!("host not found"))?
1761            .connection_id
1762    };
1763
1764    let payload = session
1765        .peer
1766        .forward_request(session.connection_id, host_connection_id, request)
1767        .await?;
1768
1769    response.send(payload)?;
1770    Ok(())
1771}
1772
1773async fn create_buffer_for_peer(
1774    request: proto::CreateBufferForPeer,
1775    session: Session,
1776) -> Result<()> {
1777    session.executor.record_backtrace();
1778    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1779    session
1780        .peer
1781        .forward_send(session.connection_id, peer_id.into(), request)?;
1782    Ok(())
1783}
1784
1785async fn update_buffer(
1786    request: proto::UpdateBuffer,
1787    response: Response<proto::UpdateBuffer>,
1788    session: Session,
1789) -> Result<()> {
1790    session.executor.record_backtrace();
1791    let project_id = ProjectId::from_proto(request.project_id);
1792    let mut guest_connection_ids;
1793    let mut host_connection_id = None;
1794    {
1795        let collaborators = session
1796            .db()
1797            .await
1798            .project_collaborators(project_id, session.connection_id)
1799            .await?;
1800        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1801        for collaborator in collaborators.iter() {
1802            if collaborator.is_host {
1803                host_connection_id = Some(collaborator.connection_id);
1804            } else {
1805                guest_connection_ids.push(collaborator.connection_id);
1806            }
1807        }
1808    }
1809    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1810
1811    session.executor.record_backtrace();
1812    broadcast(
1813        Some(session.connection_id),
1814        guest_connection_ids,
1815        |connection_id| {
1816            session
1817                .peer
1818                .forward_send(session.connection_id, connection_id, request.clone())
1819        },
1820    );
1821    if host_connection_id != session.connection_id {
1822        session
1823            .peer
1824            .forward_request(session.connection_id, host_connection_id, request.clone())
1825            .await?;
1826    }
1827
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839
1840    broadcast(
1841        Some(session.connection_id),
1842        project_connection_ids.iter().copied(),
1843        |connection_id| {
1844            session
1845                .peer
1846                .forward_send(session.connection_id, connection_id, request.clone())
1847        },
1848    );
1849    Ok(())
1850}
1851
1852async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1853    let project_id = ProjectId::from_proto(request.project_id);
1854    let project_connection_ids = session
1855        .db()
1856        .await
1857        .project_connection_ids(project_id, session.connection_id)
1858        .await?;
1859    broadcast(
1860        Some(session.connection_id),
1861        project_connection_ids.iter().copied(),
1862        |connection_id| {
1863            session
1864                .peer
1865                .forward_send(session.connection_id, connection_id, request.clone())
1866        },
1867    );
1868    Ok(())
1869}
1870
1871async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1872    broadcast_project_message(request.project_id, request, session).await
1873}
1874
1875async fn broadcast_project_message<T: EnvelopedMessage>(
1876    project_id: u64,
1877    request: T,
1878    session: Session,
1879) -> Result<()> {
1880    let project_id = ProjectId::from_proto(project_id);
1881    let project_connection_ids = session
1882        .db()
1883        .await
1884        .project_connection_ids(project_id, session.connection_id)
1885        .await?;
1886    broadcast(
1887        Some(session.connection_id),
1888        project_connection_ids.iter().copied(),
1889        |connection_id| {
1890            session
1891                .peer
1892                .forward_send(session.connection_id, connection_id, request.clone())
1893        },
1894    );
1895    Ok(())
1896}
1897
1898async fn follow(
1899    request: proto::Follow,
1900    response: Response<proto::Follow>,
1901    session: Session,
1902) -> Result<()> {
1903    let room_id = RoomId::from_proto(request.room_id);
1904    let project_id = request.project_id.map(ProjectId::from_proto);
1905    let leader_id = request
1906        .leader_id
1907        .ok_or_else(|| anyhow!("invalid leader id"))?
1908        .into();
1909    let follower_id = session.connection_id;
1910
1911    session
1912        .db()
1913        .await
1914        .check_room_participants(room_id, leader_id, session.connection_id)
1915        .await?;
1916
1917    let response_payload = session
1918        .peer
1919        .forward_request(session.connection_id, leader_id, request)
1920        .await?;
1921    response.send(response_payload)?;
1922
1923    if let Some(project_id) = project_id {
1924        let room = session
1925            .db()
1926            .await
1927            .follow(room_id, project_id, leader_id, follower_id)
1928            .await?;
1929        room_updated(&room, &session.peer);
1930    }
1931
1932    Ok(())
1933}
1934
1935async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1936    let room_id = RoomId::from_proto(request.room_id);
1937    let project_id = request.project_id.map(ProjectId::from_proto);
1938    let leader_id = request
1939        .leader_id
1940        .ok_or_else(|| anyhow!("invalid leader id"))?
1941        .into();
1942    let follower_id = session.connection_id;
1943
1944    session
1945        .db()
1946        .await
1947        .check_room_participants(room_id, leader_id, session.connection_id)
1948        .await?;
1949
1950    session
1951        .peer
1952        .forward_send(session.connection_id, leader_id, request)?;
1953
1954    if let Some(project_id) = project_id {
1955        let room = session
1956            .db()
1957            .await
1958            .unfollow(room_id, project_id, leader_id, follower_id)
1959            .await?;
1960        room_updated(&room, &session.peer);
1961    }
1962
1963    Ok(())
1964}
1965
1966async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1967    let room_id = RoomId::from_proto(request.room_id);
1968    let database = session.db.lock().await;
1969
1970    let connection_ids = if let Some(project_id) = request.project_id {
1971        let project_id = ProjectId::from_proto(project_id);
1972        database
1973            .project_connection_ids(project_id, session.connection_id)
1974            .await?
1975    } else {
1976        database
1977            .room_connection_ids(room_id, session.connection_id)
1978            .await?
1979    };
1980
1981    // For now, don't send view update messages back to that view's current leader.
1982    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1983        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1984        _ => None,
1985    });
1986
1987    for follower_peer_id in request.follower_ids.iter().copied() {
1988        let follower_connection_id = follower_peer_id.into();
1989        if Some(follower_peer_id) != connection_id_to_omit
1990            && connection_ids.contains(&follower_connection_id)
1991        {
1992            session.peer.forward_send(
1993                session.connection_id,
1994                follower_connection_id,
1995                request.clone(),
1996            )?;
1997        }
1998    }
1999    Ok(())
2000}
2001
2002async fn get_users(
2003    request: proto::GetUsers,
2004    response: Response<proto::GetUsers>,
2005    session: Session,
2006) -> Result<()> {
2007    let user_ids = request
2008        .user_ids
2009        .into_iter()
2010        .map(UserId::from_proto)
2011        .collect();
2012    let users = session
2013        .db()
2014        .await
2015        .get_users_by_ids(user_ids)
2016        .await?
2017        .into_iter()
2018        .map(|user| proto::User {
2019            id: user.id.to_proto(),
2020            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2021            github_login: user.github_login,
2022        })
2023        .collect();
2024    response.send(proto::UsersResponse { users })?;
2025    Ok(())
2026}
2027
2028async fn fuzzy_search_users(
2029    request: proto::FuzzySearchUsers,
2030    response: Response<proto::FuzzySearchUsers>,
2031    session: Session,
2032) -> Result<()> {
2033    let query = request.query;
2034    let users = match query.len() {
2035        0 => vec![],
2036        1 | 2 => session
2037            .db()
2038            .await
2039            .get_user_by_github_login(&query)
2040            .await?
2041            .into_iter()
2042            .collect(),
2043        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2044    };
2045    let users = users
2046        .into_iter()
2047        .filter(|user| user.id != session.user_id)
2048        .map(|user| proto::User {
2049            id: user.id.to_proto(),
2050            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2051            github_login: user.github_login,
2052        })
2053        .collect();
2054    response.send(proto::UsersResponse { users })?;
2055    Ok(())
2056}
2057
2058async fn request_contact(
2059    request: proto::RequestContact,
2060    response: Response<proto::RequestContact>,
2061    session: Session,
2062) -> Result<()> {
2063    let requester_id = session.user_id;
2064    let responder_id = UserId::from_proto(request.responder_id);
2065    if requester_id == responder_id {
2066        return Err(anyhow!("cannot add yourself as a contact"))?;
2067    }
2068
2069    let notifications = session
2070        .db()
2071        .await
2072        .send_contact_request(requester_id, responder_id)
2073        .await?;
2074
2075    // Update outgoing contact requests of requester
2076    let mut update = proto::UpdateContacts::default();
2077    update.outgoing_requests.push(responder_id.to_proto());
2078    for connection_id in session
2079        .connection_pool()
2080        .await
2081        .user_connection_ids(requester_id)
2082    {
2083        session.peer.send(connection_id, update.clone())?;
2084    }
2085
2086    // Update incoming contact requests of responder
2087    let mut update = proto::UpdateContacts::default();
2088    update
2089        .incoming_requests
2090        .push(proto::IncomingContactRequest {
2091            requester_id: requester_id.to_proto(),
2092        });
2093    let connection_pool = session.connection_pool().await;
2094    for connection_id in connection_pool.user_connection_ids(responder_id) {
2095        session.peer.send(connection_id, update.clone())?;
2096    }
2097
2098    send_notifications(&*connection_pool, &session.peer, notifications);
2099
2100    response.send(proto::Ack {})?;
2101    Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105    request: proto::RespondToContactRequest,
2106    response: Response<proto::RespondToContactRequest>,
2107    session: Session,
2108) -> Result<()> {
2109    let responder_id = session.user_id;
2110    let requester_id = UserId::from_proto(request.requester_id);
2111    let db = session.db().await;
2112    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113        db.dismiss_contact_notification(responder_id, requester_id)
2114            .await?;
2115    } else {
2116        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118        let notifications = db
2119            .respond_to_contact_request(responder_id, requester_id, accept)
2120            .await?;
2121        let requester_busy = db.is_user_busy(requester_id).await?;
2122        let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124        let pool = session.connection_pool().await;
2125        // Update responder with new contact
2126        let mut update = proto::UpdateContacts::default();
2127        if accept {
2128            update
2129                .contacts
2130                .push(contact_for_user(requester_id, requester_busy, &pool));
2131        }
2132        update
2133            .remove_incoming_requests
2134            .push(requester_id.to_proto());
2135        for connection_id in pool.user_connection_ids(responder_id) {
2136            session.peer.send(connection_id, update.clone())?;
2137        }
2138
2139        // Update requester with new contact
2140        let mut update = proto::UpdateContacts::default();
2141        if accept {
2142            update
2143                .contacts
2144                .push(contact_for_user(responder_id, responder_busy, &pool));
2145        }
2146        update
2147            .remove_outgoing_requests
2148            .push(responder_id.to_proto());
2149
2150        for connection_id in pool.user_connection_ids(requester_id) {
2151            session.peer.send(connection_id, update.clone())?;
2152        }
2153
2154        send_notifications(&*pool, &session.peer, notifications);
2155    }
2156
2157    response.send(proto::Ack {})?;
2158    Ok(())
2159}
2160
2161async fn remove_contact(
2162    request: proto::RemoveContact,
2163    response: Response<proto::RemoveContact>,
2164    session: Session,
2165) -> Result<()> {
2166    let requester_id = session.user_id;
2167    let responder_id = UserId::from_proto(request.user_id);
2168    let db = session.db().await;
2169    let (contact_accepted, deleted_notification_id) =
2170        db.remove_contact(requester_id, responder_id).await?;
2171
2172    let pool = session.connection_pool().await;
2173    // Update outgoing contact requests of requester
2174    let mut update = proto::UpdateContacts::default();
2175    if contact_accepted {
2176        update.remove_contacts.push(responder_id.to_proto());
2177    } else {
2178        update
2179            .remove_outgoing_requests
2180            .push(responder_id.to_proto());
2181    }
2182    for connection_id in pool.user_connection_ids(requester_id) {
2183        session.peer.send(connection_id, update.clone())?;
2184    }
2185
2186    // Update incoming contact requests of responder
2187    let mut update = proto::UpdateContacts::default();
2188    if contact_accepted {
2189        update.remove_contacts.push(requester_id.to_proto());
2190    } else {
2191        update
2192            .remove_incoming_requests
2193            .push(requester_id.to_proto());
2194    }
2195    for connection_id in pool.user_connection_ids(responder_id) {
2196        session.peer.send(connection_id, update.clone())?;
2197        if let Some(notification_id) = deleted_notification_id {
2198            session.peer.send(
2199                connection_id,
2200                proto::DeleteNotification {
2201                    notification_id: notification_id.to_proto(),
2202                },
2203            )?;
2204        }
2205    }
2206
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
2211async fn create_channel(
2212    request: proto::CreateChannel,
2213    response: Response<proto::CreateChannel>,
2214    session: Session,
2215) -> Result<()> {
2216    let db = session.db().await;
2217
2218    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2219    let id = db
2220        .create_channel(&request.name, parent_id, session.user_id)
2221        .await?;
2222
2223    let channel = proto::Channel {
2224        id: id.to_proto(),
2225        name: request.name,
2226        visibility: proto::ChannelVisibility::Members as i32,
2227    };
2228
2229    response.send(proto::CreateChannelResponse {
2230        channel: Some(channel.clone()),
2231        parent_id: request.parent_id,
2232    })?;
2233
2234    let Some(parent_id) = parent_id else {
2235        return Ok(());
2236    };
2237
2238    let update = proto::UpdateChannels {
2239        channels: vec![channel],
2240        insert_edge: vec![ChannelEdge {
2241            parent_id: parent_id.to_proto(),
2242            channel_id: id.to_proto(),
2243        }],
2244        ..Default::default()
2245    };
2246
2247    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2248
2249    let connection_pool = session.connection_pool().await;
2250    for user_id in user_ids_to_notify {
2251        for connection_id in connection_pool.user_connection_ids(user_id) {
2252            if user_id == session.user_id {
2253                continue;
2254            }
2255            session.peer.send(connection_id, update.clone())?;
2256        }
2257    }
2258
2259    Ok(())
2260}
2261
2262async fn delete_channel(
2263    request: proto::DeleteChannel,
2264    response: Response<proto::DeleteChannel>,
2265    session: Session,
2266) -> Result<()> {
2267    let db = session.db().await;
2268
2269    let channel_id = request.channel_id;
2270    let (removed_channels, member_ids) = db
2271        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2272        .await?;
2273    response.send(proto::Ack {})?;
2274
2275    // Notify members of removed channels
2276    let mut update = proto::UpdateChannels::default();
2277    update
2278        .delete_channels
2279        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2280
2281    let connection_pool = session.connection_pool().await;
2282    for member_id in member_ids {
2283        for connection_id in connection_pool.user_connection_ids(member_id) {
2284            session.peer.send(connection_id, update.clone())?;
2285        }
2286    }
2287
2288    Ok(())
2289}
2290
2291async fn invite_channel_member(
2292    request: proto::InviteChannelMember,
2293    response: Response<proto::InviteChannelMember>,
2294    session: Session,
2295) -> Result<()> {
2296    let db = session.db().await;
2297    let channel_id = ChannelId::from_proto(request.channel_id);
2298    let invitee_id = UserId::from_proto(request.user_id);
2299    let notifications = db
2300        .invite_channel_member(
2301            channel_id,
2302            invitee_id,
2303            session.user_id,
2304            request.role().into(),
2305        )
2306        .await?;
2307
2308    let channel = db.get_channel(channel_id, session.user_id).await?;
2309
2310    let mut update = proto::UpdateChannels::default();
2311    update.channel_invitations.push(proto::Channel {
2312        id: channel.id.to_proto(),
2313        visibility: channel.visibility.into(),
2314        name: channel.name,
2315    });
2316
2317    let pool = session.connection_pool().await;
2318    for connection_id in pool.user_connection_ids(invitee_id) {
2319        session.peer.send(connection_id, update.clone())?;
2320    }
2321
2322    send_notifications(&*pool, &session.peer, notifications);
2323
2324    response.send(proto::Ack {})?;
2325    Ok(())
2326}
2327
2328async fn remove_channel_member(
2329    request: proto::RemoveChannelMember,
2330    response: Response<proto::RemoveChannelMember>,
2331    session: Session,
2332) -> Result<()> {
2333    let db = session.db().await;
2334    let channel_id = ChannelId::from_proto(request.channel_id);
2335    let member_id = UserId::from_proto(request.user_id);
2336
2337    let removed_notification_id = db
2338        .remove_channel_member(channel_id, member_id, session.user_id)
2339        .await?;
2340
2341    let mut update = proto::UpdateChannels::default();
2342    update.delete_channels.push(channel_id.to_proto());
2343
2344    for connection_id in session
2345        .connection_pool()
2346        .await
2347        .user_connection_ids(member_id)
2348    {
2349        session.peer.send(connection_id, update.clone()).trace_err();
2350        if let Some(notification_id) = removed_notification_id {
2351            session
2352                .peer
2353                .send(
2354                    connection_id,
2355                    proto::DeleteNotification {
2356                        notification_id: notification_id.to_proto(),
2357                    },
2358                )
2359                .trace_err();
2360        }
2361    }
2362
2363    response.send(proto::Ack {})?;
2364    Ok(())
2365}
2366
2367async fn set_channel_visibility(
2368    request: proto::SetChannelVisibility,
2369    response: Response<proto::SetChannelVisibility>,
2370    session: Session,
2371) -> Result<()> {
2372    let db = session.db().await;
2373    let channel_id = ChannelId::from_proto(request.channel_id);
2374    let visibility = request.visibility().into();
2375
2376    let channel = db
2377        .set_channel_visibility(channel_id, visibility, session.user_id)
2378        .await?;
2379
2380    let mut update = proto::UpdateChannels::default();
2381    update.channels.push(proto::Channel {
2382        id: channel.id.to_proto(),
2383        name: channel.name,
2384        visibility: channel.visibility.into(),
2385    });
2386
2387    let member_ids = db.get_channel_members(channel_id).await?;
2388
2389    let connection_pool = session.connection_pool().await;
2390    for member_id in member_ids {
2391        for connection_id in connection_pool.user_connection_ids(member_id) {
2392            session.peer.send(connection_id, update.clone())?;
2393        }
2394    }
2395
2396    response.send(proto::Ack {})?;
2397    Ok(())
2398}
2399
2400async fn set_channel_member_role(
2401    request: proto::SetChannelMemberRole,
2402    response: Response<proto::SetChannelMemberRole>,
2403    session: Session,
2404) -> Result<()> {
2405    let db = session.db().await;
2406    let channel_id = ChannelId::from_proto(request.channel_id);
2407    let member_id = UserId::from_proto(request.user_id);
2408    let channel_member = db
2409        .set_channel_member_role(
2410            channel_id,
2411            session.user_id,
2412            member_id,
2413            request.role().into(),
2414        )
2415        .await?;
2416
2417    let channel = db.get_channel(channel_id, session.user_id).await?;
2418
2419    let mut update = proto::UpdateChannels::default();
2420    if channel_member.accepted {
2421        update.channel_permissions.push(proto::ChannelPermission {
2422            channel_id: channel.id.to_proto(),
2423            role: request.role,
2424        });
2425    }
2426
2427    for connection_id in session
2428        .connection_pool()
2429        .await
2430        .user_connection_ids(member_id)
2431    {
2432        session.peer.send(connection_id, update.clone())?;
2433    }
2434
2435    response.send(proto::Ack {})?;
2436    Ok(())
2437}
2438
2439async fn rename_channel(
2440    request: proto::RenameChannel,
2441    response: Response<proto::RenameChannel>,
2442    session: Session,
2443) -> Result<()> {
2444    let db = session.db().await;
2445    let channel_id = ChannelId::from_proto(request.channel_id);
2446    let channel = db
2447        .rename_channel(channel_id, session.user_id, &request.name)
2448        .await?;
2449
2450    let channel = proto::Channel {
2451        id: channel.id.to_proto(),
2452        name: channel.name,
2453        visibility: channel.visibility.into(),
2454    };
2455    response.send(proto::RenameChannelResponse {
2456        channel: Some(channel.clone()),
2457    })?;
2458    let mut update = proto::UpdateChannels::default();
2459    update.channels.push(channel);
2460
2461    let member_ids = db.get_channel_members(channel_id).await?;
2462
2463    let connection_pool = session.connection_pool().await;
2464    for member_id in member_ids {
2465        for connection_id in connection_pool.user_connection_ids(member_id) {
2466            session.peer.send(connection_id, update.clone())?;
2467        }
2468    }
2469
2470    Ok(())
2471}
2472
2473async fn link_channel(
2474    request: proto::LinkChannel,
2475    response: Response<proto::LinkChannel>,
2476    session: Session,
2477) -> Result<()> {
2478    let db = session.db().await;
2479    let channel_id = ChannelId::from_proto(request.channel_id);
2480    let to = ChannelId::from_proto(request.to);
2481    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2482
2483    let members = db.get_channel_members(to).await?;
2484    let connection_pool = session.connection_pool().await;
2485    let update = proto::UpdateChannels {
2486        channels: channels_to_send
2487            .channels
2488            .into_iter()
2489            .map(|channel| proto::Channel {
2490                id: channel.id.to_proto(),
2491                visibility: channel.visibility.into(),
2492                name: channel.name,
2493            })
2494            .collect(),
2495        insert_edge: channels_to_send.edges,
2496        ..Default::default()
2497    };
2498    for member_id in members {
2499        for connection_id in connection_pool.user_connection_ids(member_id) {
2500            session.peer.send(connection_id, update.clone())?;
2501        }
2502    }
2503
2504    response.send(Ack {})?;
2505
2506    Ok(())
2507}
2508
2509async fn unlink_channel(
2510    request: proto::UnlinkChannel,
2511    response: Response<proto::UnlinkChannel>,
2512    session: Session,
2513) -> Result<()> {
2514    let db = session.db().await;
2515    let channel_id = ChannelId::from_proto(request.channel_id);
2516    let from = ChannelId::from_proto(request.from);
2517
2518    db.unlink_channel(session.user_id, channel_id, from).await?;
2519
2520    let members = db.get_channel_members(from).await?;
2521
2522    let update = proto::UpdateChannels {
2523        delete_edge: vec![proto::ChannelEdge {
2524            channel_id: channel_id.to_proto(),
2525            parent_id: from.to_proto(),
2526        }],
2527        ..Default::default()
2528    };
2529    let connection_pool = session.connection_pool().await;
2530    for member_id in members {
2531        for connection_id in connection_pool.user_connection_ids(member_id) {
2532            session.peer.send(connection_id, update.clone())?;
2533        }
2534    }
2535
2536    response.send(Ack {})?;
2537
2538    Ok(())
2539}
2540
2541async fn move_channel(
2542    request: proto::MoveChannel,
2543    response: Response<proto::MoveChannel>,
2544    session: Session,
2545) -> Result<()> {
2546    let db = session.db().await;
2547    let channel_id = ChannelId::from_proto(request.channel_id);
2548    let from_parent = ChannelId::from_proto(request.from);
2549    let to = ChannelId::from_proto(request.to);
2550
2551    let channels_to_send = db
2552        .move_channel(session.user_id, channel_id, from_parent, to)
2553        .await?;
2554
2555    if channels_to_send.is_empty() {
2556        response.send(Ack {})?;
2557        return Ok(());
2558    }
2559
2560    let members_from = db.get_channel_members(from_parent).await?;
2561    let members_to = db.get_channel_members(to).await?;
2562
2563    let update = proto::UpdateChannels {
2564        delete_edge: vec![proto::ChannelEdge {
2565            channel_id: channel_id.to_proto(),
2566            parent_id: from_parent.to_proto(),
2567        }],
2568        ..Default::default()
2569    };
2570    let connection_pool = session.connection_pool().await;
2571    for member_id in members_from {
2572        for connection_id in connection_pool.user_connection_ids(member_id) {
2573            session.peer.send(connection_id, update.clone())?;
2574        }
2575    }
2576
2577    let update = proto::UpdateChannels {
2578        channels: channels_to_send
2579            .channels
2580            .into_iter()
2581            .map(|channel| proto::Channel {
2582                id: channel.id.to_proto(),
2583                visibility: channel.visibility.into(),
2584                name: channel.name,
2585            })
2586            .collect(),
2587        insert_edge: channels_to_send.edges,
2588        ..Default::default()
2589    };
2590    for member_id in members_to {
2591        for connection_id in connection_pool.user_connection_ids(member_id) {
2592            session.peer.send(connection_id, update.clone())?;
2593        }
2594    }
2595
2596    response.send(Ack {})?;
2597
2598    Ok(())
2599}
2600
2601async fn get_channel_members(
2602    request: proto::GetChannelMembers,
2603    response: Response<proto::GetChannelMembers>,
2604    session: Session,
2605) -> Result<()> {
2606    let db = session.db().await;
2607    let channel_id = ChannelId::from_proto(request.channel_id);
2608    let members = db
2609        .get_channel_participant_details(channel_id, session.user_id)
2610        .await?;
2611    response.send(proto::GetChannelMembersResponse { members })?;
2612    Ok(())
2613}
2614
2615async fn respond_to_channel_invite(
2616    request: proto::RespondToChannelInvite,
2617    response: Response<proto::RespondToChannelInvite>,
2618    session: Session,
2619) -> Result<()> {
2620    let db = session.db().await;
2621    let channel_id = ChannelId::from_proto(request.channel_id);
2622    let notifications = db
2623        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2624        .await?;
2625
2626    if request.accept {
2627        channel_membership_updated(db, channel_id, &session).await?;
2628    } else {
2629        let mut update = proto::UpdateChannels::default();
2630        update
2631            .remove_channel_invitations
2632            .push(channel_id.to_proto());
2633        session.peer.send(session.connection_id, update)?;
2634    }
2635
2636    send_notifications(
2637        &*session.connection_pool().await,
2638        &session.peer,
2639        notifications,
2640    );
2641    response.send(proto::Ack {})?;
2642
2643    Ok(())
2644}
2645
2646async fn channel_membership_updated(
2647    db: tokio::sync::MutexGuard<'_, DbHandle>,
2648    channel_id: ChannelId,
2649    session: &Session,
2650) -> Result<(), crate::Error> {
2651    let mut update = proto::UpdateChannels::default();
2652    update
2653        .remove_channel_invitations
2654        .push(channel_id.to_proto());
2655
2656    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2657    update.channels.extend(
2658        result
2659            .channels
2660            .channels
2661            .into_iter()
2662            .map(|channel| proto::Channel {
2663                id: channel.id.to_proto(),
2664                visibility: channel.visibility.into(),
2665                name: channel.name,
2666            }),
2667    );
2668    update.unseen_channel_messages = result.channel_messages;
2669    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2670    update.insert_edge = result.channels.edges;
2671    update
2672        .channel_participants
2673        .extend(
2674            result
2675                .channel_participants
2676                .into_iter()
2677                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2678                    channel_id: channel_id.to_proto(),
2679                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2680                }),
2681        );
2682    update
2683        .channel_permissions
2684        .extend(
2685            result
2686                .channels_with_admin_privileges
2687                .into_iter()
2688                .map(|channel_id| proto::ChannelPermission {
2689                    channel_id: channel_id.to_proto(),
2690                    role: proto::ChannelRole::Admin.into(),
2691                }),
2692        );
2693    session.peer.send(session.connection_id, update)?;
2694    Ok(())
2695}
2696
2697async fn join_channel(
2698    request: proto::JoinChannel,
2699    response: Response<proto::JoinChannel>,
2700    session: Session,
2701) -> Result<()> {
2702    let channel_id = ChannelId::from_proto(request.channel_id);
2703    join_channel_internal(channel_id, Box::new(response), session).await
2704}
2705
/// A response handle that can complete a channel join.
///
/// Both the `JoinChannel` and `JoinRoom` handlers reply with a
/// `proto::JoinRoomResponse`. This trait lets `join_channel_internal` send that
/// reply without knowing which request it is serving.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// NOTE(review): both impls below delegate via the explicit path
// `Response::<..>::send`. This relies on `Response` having its own inherent
// `send`, which takes precedence over this trait method. Without it, the call
// would recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2719
/// Joins the room associated with `channel_id` on behalf of `session`.
///
/// In order:
/// 1. Leaves any room the session is currently in.
/// 2. Calls the database's `join_channel`. This yields the joined room and an
///    optional channel id that is then passed to `channel_membership_updated`.
/// 3. Replies through `response` with the room, its channel id, and LiveKit
///    connection info when a LiveKit client is configured and a token was minted.
/// 4. Broadcasts the new room state to its participants.
/// 5. After the database guard is released, tells channel members who is in the
///    room and refreshes the user's contact entries.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the `db` guard is released before the code after this block.
    // `update_user_contacts` calls `session.db()` again and would otherwise wait
    // on this same (tokio) mutex.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // `trace_err` presumably logs the error and yields `None`. A token
            // failure therefore omits LiveKit info instead of failing the join.
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Moves `db` into the call; the lock is released when it returns.
        // Otherwise `db` is dropped at the end of this block.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2778
2779async fn join_channel_buffer(
2780    request: proto::JoinChannelBuffer,
2781    response: Response<proto::JoinChannelBuffer>,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785    let channel_id = ChannelId::from_proto(request.channel_id);
2786
2787    let open_response = db
2788        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2789        .await?;
2790
2791    let collaborators = open_response.collaborators.clone();
2792    response.send(open_response)?;
2793
2794    let update = UpdateChannelBufferCollaborators {
2795        channel_id: channel_id.to_proto(),
2796        collaborators: collaborators.clone(),
2797    };
2798    channel_buffer_updated(
2799        session.connection_id,
2800        collaborators
2801            .iter()
2802            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2803        &update,
2804        &session.peer,
2805    );
2806
2807    Ok(())
2808}
2809
2810async fn update_channel_buffer(
2811    request: proto::UpdateChannelBuffer,
2812    session: Session,
2813) -> Result<()> {
2814    let db = session.db().await;
2815    let channel_id = ChannelId::from_proto(request.channel_id);
2816
2817    let (collaborators, non_collaborators, epoch, version) = db
2818        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2819        .await?;
2820
2821    channel_buffer_updated(
2822        session.connection_id,
2823        collaborators,
2824        &proto::UpdateChannelBuffer {
2825            channel_id: channel_id.to_proto(),
2826            operations: request.operations,
2827        },
2828        &session.peer,
2829    );
2830
2831    let pool = &*session.connection_pool().await;
2832
2833    broadcast(
2834        None,
2835        non_collaborators
2836            .iter()
2837            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2838        |peer_id| {
2839            session.peer.send(
2840                peer_id.into(),
2841                proto::UpdateChannels {
2842                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2843                        channel_id: channel_id.to_proto(),
2844                        epoch: epoch as u64,
2845                        version: version.clone(),
2846                    }],
2847                    ..Default::default()
2848                },
2849            )
2850        },
2851    );
2852
2853    Ok(())
2854}
2855
2856async fn rejoin_channel_buffers(
2857    request: proto::RejoinChannelBuffers,
2858    response: Response<proto::RejoinChannelBuffers>,
2859    session: Session,
2860) -> Result<()> {
2861    let db = session.db().await;
2862    let buffers = db
2863        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2864        .await?;
2865
2866    for rejoined_buffer in &buffers {
2867        let collaborators_to_notify = rejoined_buffer
2868            .buffer
2869            .collaborators
2870            .iter()
2871            .filter_map(|c| Some(c.peer_id?.into()));
2872        channel_buffer_updated(
2873            session.connection_id,
2874            collaborators_to_notify,
2875            &proto::UpdateChannelBufferCollaborators {
2876                channel_id: rejoined_buffer.buffer.channel_id,
2877                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2878            },
2879            &session.peer,
2880        );
2881    }
2882
2883    response.send(proto::RejoinChannelBuffersResponse {
2884        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2885    })?;
2886
2887    Ok(())
2888}
2889
2890async fn leave_channel_buffer(
2891    request: proto::LeaveChannelBuffer,
2892    response: Response<proto::LeaveChannelBuffer>,
2893    session: Session,
2894) -> Result<()> {
2895    let db = session.db().await;
2896    let channel_id = ChannelId::from_proto(request.channel_id);
2897
2898    let left_buffer = db
2899        .leave_channel_buffer(channel_id, session.connection_id)
2900        .await?;
2901
2902    response.send(Ack {})?;
2903
2904    channel_buffer_updated(
2905        session.connection_id,
2906        left_buffer.connections,
2907        &proto::UpdateChannelBufferCollaborators {
2908            channel_id: channel_id.to_proto(),
2909            collaborators: left_buffer.collaborators,
2910        },
2911        &session.peer,
2912    );
2913
2914    Ok(())
2915}
2916
2917fn channel_buffer_updated<T: EnvelopedMessage>(
2918    sender_id: ConnectionId,
2919    collaborators: impl IntoIterator<Item = ConnectionId>,
2920    message: &T,
2921    peer: &Peer,
2922) {
2923    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2924        peer.send(peer_id.into(), message.clone())
2925    });
2926}
2927
2928fn send_notifications(
2929    connection_pool: &ConnectionPool,
2930    peer: &Peer,
2931    notifications: db::NotificationBatch,
2932) {
2933    for (user_id, notification) in notifications {
2934        for connection_id in connection_pool.user_connection_ids(user_id) {
2935            if let Err(error) = peer.send(
2936                connection_id,
2937                proto::NewNotification {
2938                    notification: Some(notification.clone()),
2939                },
2940            ) {
2941                tracing::error!(
2942                    "failed to send notification to {:?} {}",
2943                    connection_id,
2944                    error
2945                );
2946            }
2947        }
2948    }
2949}
2950
2951async fn send_channel_message(
2952    request: proto::SendChannelMessage,
2953    response: Response<proto::SendChannelMessage>,
2954    session: Session,
2955) -> Result<()> {
2956    // Validate the message body.
2957    let body = request.body.trim().to_string();
2958    if body.len() > MAX_MESSAGE_LEN {
2959        return Err(anyhow!("message is too long"))?;
2960    }
2961    if body.is_empty() {
2962        return Err(anyhow!("message can't be blank"))?;
2963    }
2964
2965    // TODO: adjust mentions if body is trimmed
2966
2967    let timestamp = OffsetDateTime::now_utc();
2968    let nonce = request
2969        .nonce
2970        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2971
2972    let channel_id = ChannelId::from_proto(request.channel_id);
2973    let CreatedChannelMessage {
2974        message_id,
2975        participant_connection_ids,
2976        channel_members,
2977        notifications,
2978    } = session
2979        .db()
2980        .await
2981        .create_channel_message(
2982            channel_id,
2983            session.user_id,
2984            &body,
2985            &request.mentions,
2986            timestamp,
2987            nonce.clone().into(),
2988        )
2989        .await?;
2990    let message = proto::ChannelMessage {
2991        sender_id: session.user_id.to_proto(),
2992        id: message_id.to_proto(),
2993        body,
2994        mentions: request.mentions,
2995        timestamp: timestamp.unix_timestamp() as u64,
2996        nonce: Some(nonce),
2997    };
2998    broadcast(
2999        Some(session.connection_id),
3000        participant_connection_ids,
3001        |connection| {
3002            session.peer.send(
3003                connection,
3004                proto::ChannelMessageSent {
3005                    channel_id: channel_id.to_proto(),
3006                    message: Some(message.clone()),
3007                },
3008            )
3009        },
3010    );
3011    response.send(proto::SendChannelMessageResponse {
3012        message: Some(message),
3013    })?;
3014
3015    let pool = &*session.connection_pool().await;
3016    broadcast(
3017        None,
3018        channel_members
3019            .iter()
3020            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3021        |peer_id| {
3022            session.peer.send(
3023                peer_id.into(),
3024                proto::UpdateChannels {
3025                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
3026                        channel_id: channel_id.to_proto(),
3027                        message_id: message_id.to_proto(),
3028                    }],
3029                    ..Default::default()
3030                },
3031            )
3032        },
3033    );
3034    send_notifications(pool, &session.peer, notifications);
3035
3036    Ok(())
3037}
3038
3039async fn remove_channel_message(
3040    request: proto::RemoveChannelMessage,
3041    response: Response<proto::RemoveChannelMessage>,
3042    session: Session,
3043) -> Result<()> {
3044    let channel_id = ChannelId::from_proto(request.channel_id);
3045    let message_id = MessageId::from_proto(request.message_id);
3046    let connection_ids = session
3047        .db()
3048        .await
3049        .remove_channel_message(channel_id, message_id, session.user_id)
3050        .await?;
3051    broadcast(Some(session.connection_id), connection_ids, |connection| {
3052        session.peer.send(connection, request.clone())
3053    });
3054    response.send(proto::Ack {})?;
3055    Ok(())
3056}
3057
3058async fn acknowledge_channel_message(
3059    request: proto::AckChannelMessage,
3060    session: Session,
3061) -> Result<()> {
3062    let channel_id = ChannelId::from_proto(request.channel_id);
3063    let message_id = MessageId::from_proto(request.message_id);
3064    session
3065        .db()
3066        .await
3067        .observe_channel_message(channel_id, session.user_id, message_id)
3068        .await?;
3069    Ok(())
3070}
3071
3072async fn acknowledge_buffer_version(
3073    request: proto::AckBufferOperation,
3074    session: Session,
3075) -> Result<()> {
3076    let buffer_id = BufferId::from_proto(request.buffer_id);
3077    session
3078        .db()
3079        .await
3080        .observe_buffer_version(
3081            buffer_id,
3082            session.user_id,
3083            request.epoch as i32,
3084            &request.version,
3085        )
3086        .await?;
3087    Ok(())
3088}
3089
3090async fn join_channel_chat(
3091    request: proto::JoinChannelChat,
3092    response: Response<proto::JoinChannelChat>,
3093    session: Session,
3094) -> Result<()> {
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097    let db = session.db().await;
3098    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3099        .await?;
3100    let messages = db
3101        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3102        .await?;
3103    response.send(proto::JoinChannelChatResponse {
3104        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3105        messages,
3106    })?;
3107    Ok(())
3108}
3109
3110async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3111    let channel_id = ChannelId::from_proto(request.channel_id);
3112    session
3113        .db()
3114        .await
3115        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3116        .await?;
3117    Ok(())
3118}
3119
3120async fn get_channel_messages(
3121    request: proto::GetChannelMessages,
3122    response: Response<proto::GetChannelMessages>,
3123    session: Session,
3124) -> Result<()> {
3125    let channel_id = ChannelId::from_proto(request.channel_id);
3126    let messages = session
3127        .db()
3128        .await
3129        .get_channel_messages(
3130            channel_id,
3131            session.user_id,
3132            MESSAGE_COUNT_PER_PAGE,
3133            Some(MessageId::from_proto(request.before_message_id)),
3134        )
3135        .await?;
3136    response.send(proto::GetChannelMessagesResponse {
3137        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3138        messages,
3139    })?;
3140    Ok(())
3141}
3142
3143async fn get_channel_messages_by_id(
3144    request: proto::GetChannelMessagesById,
3145    response: Response<proto::GetChannelMessagesById>,
3146    session: Session,
3147) -> Result<()> {
3148    let message_ids = request
3149        .message_ids
3150        .iter()
3151        .map(|id| MessageId::from_proto(*id))
3152        .collect::<Vec<_>>();
3153    let messages = session
3154        .db()
3155        .await
3156        .get_channel_messages_by_id(session.user_id, &message_ids)
3157        .await?;
3158    response.send(proto::GetChannelMessagesResponse {
3159        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3160        messages,
3161    })?;
3162    Ok(())
3163}
3164
3165async fn get_notifications(
3166    request: proto::GetNotifications,
3167    response: Response<proto::GetNotifications>,
3168    session: Session,
3169) -> Result<()> {
3170    let notifications = session
3171        .db()
3172        .await
3173        .get_notifications(
3174            session.user_id,
3175            NOTIFICATION_COUNT_PER_PAGE,
3176            request
3177                .before_id
3178                .map(|id| db::NotificationId::from_proto(id)),
3179        )
3180        .await?;
3181    response.send(proto::GetNotificationsResponse { notifications })?;
3182    Ok(())
3183}
3184
3185async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3186    let project_id = ProjectId::from_proto(request.project_id);
3187    let project_connection_ids = session
3188        .db()
3189        .await
3190        .project_connection_ids(project_id, session.connection_id)
3191        .await?;
3192    broadcast(
3193        Some(session.connection_id),
3194        project_connection_ids.iter().copied(),
3195        |connection_id| {
3196            session
3197                .peer
3198                .forward_send(session.connection_id, connection_id, request.clone())
3199        },
3200    );
3201    Ok(())
3202}
3203
3204async fn get_private_user_info(
3205    _request: proto::GetPrivateUserInfo,
3206    response: Response<proto::GetPrivateUserInfo>,
3207    session: Session,
3208) -> Result<()> {
3209    let db = session.db().await;
3210
3211    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3212    let user = db
3213        .get_user_by_id(session.user_id)
3214        .await?
3215        .ok_or_else(|| anyhow!("user not found"))?;
3216    let flags = db.get_user_flags(session.user_id).await?;
3217
3218    response.send(proto::GetPrivateUserInfoResponse {
3219        metrics_id,
3220        staff: user.admin,
3221        flags,
3222    })?;
3223    Ok(())
3224}
3225
3226fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3227    match message {
3228        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3229        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3230        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3231        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3232        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3233            code: frame.code.into(),
3234            reason: frame.reason,
3235        })),
3236    }
3237}
3238
3239fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3240    match message {
3241        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3242        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3243        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3244        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3245        AxumMessage::Close(frame) => {
3246            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3247                code: frame.code.into(),
3248                reason: frame.reason,
3249            }))
3250        }
3251    }
3252}
3253
3254fn build_initial_channels_update(
3255    channels: ChannelsForUser,
3256    channel_invites: Vec<db::Channel>,
3257) -> proto::UpdateChannels {
3258    let mut update = proto::UpdateChannels::default();
3259
3260    for channel in channels.channels.channels {
3261        update.channels.push(proto::Channel {
3262            id: channel.id.to_proto(),
3263            name: channel.name,
3264            visibility: channel.visibility.into(),
3265        });
3266    }
3267
3268    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3269    update.unseen_channel_messages = channels.channel_messages;
3270    update.insert_edge = channels.channels.edges;
3271
3272    for (channel_id, participants) in channels.channel_participants {
3273        update
3274            .channel_participants
3275            .push(proto::ChannelParticipants {
3276                channel_id: channel_id.to_proto(),
3277                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3278            });
3279    }
3280
3281    update
3282        .channel_permissions
3283        .extend(
3284            channels
3285                .channels_with_admin_privileges
3286                .into_iter()
3287                .map(|id| proto::ChannelPermission {
3288                    channel_id: id.to_proto(),
3289                    role: proto::ChannelRole::Admin.into(),
3290                }),
3291        );
3292
3293    for channel in channel_invites {
3294        update.channel_invitations.push(proto::Channel {
3295            id: channel.id.to_proto(),
3296            name: channel.name,
3297            // TODO: Visibility
3298            visibility: ChannelVisibility::Public.into(),
3299        });
3300    }
3301
3302    update
3303}
3304
3305fn build_initial_contacts_update(
3306    contacts: Vec<db::Contact>,
3307    pool: &ConnectionPool,
3308) -> proto::UpdateContacts {
3309    let mut update = proto::UpdateContacts::default();
3310
3311    for contact in contacts {
3312        match contact {
3313            db::Contact::Accepted { user_id, busy } => {
3314                update.contacts.push(contact_for_user(user_id, busy, &pool));
3315            }
3316            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3317            db::Contact::Incoming { user_id } => {
3318                update
3319                    .incoming_requests
3320                    .push(proto::IncomingContactRequest {
3321                        requester_id: user_id.to_proto(),
3322                    })
3323            }
3324        }
3325    }
3326
3327    update
3328}
3329
3330fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3331    proto::Contact {
3332        user_id: user_id.to_proto(),
3333        online: pool.is_user_online(user_id),
3334        busy,
3335    }
3336}
3337
3338fn room_updated(room: &proto::Room, peer: &Peer) {
3339    broadcast(
3340        None,
3341        room.participants
3342            .iter()
3343            .filter_map(|participant| Some(participant.peer_id?.into())),
3344        |peer_id| {
3345            peer.send(
3346                peer_id.into(),
3347                proto::RoomUpdated {
3348                    room: Some(room.clone()),
3349                },
3350            )
3351        },
3352    );
3353}
3354
3355fn channel_updated(
3356    channel_id: ChannelId,
3357    room: &proto::Room,
3358    channel_members: &[UserId],
3359    peer: &Peer,
3360    pool: &ConnectionPool,
3361) {
3362    let participants = room
3363        .participants
3364        .iter()
3365        .map(|p| p.user_id)
3366        .collect::<Vec<_>>();
3367
3368    broadcast(
3369        None,
3370        channel_members
3371            .iter()
3372            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3373        |peer_id| {
3374            peer.send(
3375                peer_id.into(),
3376                proto::UpdateChannels {
3377                    channel_participants: vec![proto::ChannelParticipants {
3378                        channel_id: channel_id.to_proto(),
3379                        participant_user_ids: participants.clone(),
3380                    }],
3381                    ..Default::default()
3382                },
3383            )
3384        },
3385    );
3386}
3387
3388async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3389    let db = session.db().await;
3390
3391    let contacts = db.get_contacts(user_id).await?;
3392    let busy = db.is_user_busy(user_id).await?;
3393
3394    let pool = session.connection_pool().await;
3395    let updated_contact = contact_for_user(user_id, busy, &pool);
3396    for contact in contacts {
3397        if let db::Contact::Accepted {
3398            user_id: contact_user_id,
3399            ..
3400        } = contact
3401        {
3402            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3403                session
3404                    .peer
3405                    .send(
3406                        contact_conn_id,
3407                        proto::UpdateContacts {
3408                            contacts: vec![updated_contact.clone()],
3409                            remove_contacts: Default::default(),
3410                            incoming_requests: Default::default(),
3411                            remove_incoming_requests: Default::default(),
3412                            outgoing_requests: Default::default(),
3413                            remove_outgoing_requests: Default::default(),
3414                        },
3415                    )
3416                    .trace_err();
3417            }
3418        }
3419    }
3420    Ok(())
3421}
3422
/// Removes the session's connection from the room it is in and notifies
/// everyone affected.
///
/// Returns `Ok(())` immediately if the database's `leave_room` reports no room.
/// Otherwise, in order:
/// - calls `project_left` for each project left and broadcasts the room state;
/// - if the room had a channel, tells that channel's members via `channel_updated`;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact entries for this user and those canceled users;
/// - removes the user from the LiveKit room, and deletes the LiveKit room when
///   the database marked the room deleted.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below, so these values
    // outlive the `left_room` binding and the temporaries of that statement.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the `session.db()` guard here is a temporary in the `if let`
    // scrutinee. Under Rust 2021 temporary-scope rules it stays alive until the
    // end of the whole `if let` statement, so `project_left` and `room_updated`
    // run with the db lock held.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this is taken before `room` is moved out, so the `room`
        // broadcast below carries a default (empty) `live_kit_room`. Confirm
        // that is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the awaits below.
    // `update_user_contacts` acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors go to `trace_err` and are not
    // propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3499
3500async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3501    let left_channel_buffers = session
3502        .db()
3503        .await
3504        .leave_channel_buffers(session.connection_id)
3505        .await?;
3506
3507    for left_buffer in left_channel_buffers {
3508        channel_buffer_updated(
3509            session.connection_id,
3510            left_buffer.connections,
3511            &proto::UpdateChannelBufferCollaborators {
3512                channel_id: left_buffer.channel_id.to_proto(),
3513                collaborators: left_buffer.collaborators,
3514            },
3515            &session.peer,
3516        );
3517    }
3518
3519    Ok(())
3520}
3521
3522fn project_left(project: &db::LeftProject, session: &Session) {
3523    for connection_id in &project.connection_ids {
3524        if project.host_user_id == session.user_id {
3525            session
3526                .peer
3527                .send(
3528                    *connection_id,
3529                    proto::UnshareProject {
3530                        project_id: project.id.to_proto(),
3531                    },
3532                )
3533                .trace_err();
3534        } else {
3535            session
3536                .peer
3537                .send(
3538                    *connection_id,
3539                    proto::RemoveProjectCollaborator {
3540                        project_id: project.id.to_proto(),
3541                        peer_id: Some(session.connection_id.into()),
3542                    },
3543                )
3544                .trace_err();
3545        }
3546    }
3547}
3548
/// Turns a fallible value into an `Option` of its success value, reporting the
/// failure instead of propagating it.
///
/// In this file it is applied to best-effort operations, such as `peer.send` and
/// LiveKit requests, whose failure should not abort the surrounding handler.
pub trait ResultExt {
    /// The type carried by the success case.
    type Ok;

    /// Returns `Some(value)` on success. Otherwise it reports the error in an
    /// implementation-defined way and returns `None`; the `Result` implementation
    /// logs it with `tracing`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3554
3555impl<T, E> ResultExt for Result<T, E>
3556where
3557    E: std::fmt::Debug,
3558{
3559    type Ok = T;
3560
3561    fn trace_err(self) -> Option<T> {
3562        match self {
3563            Ok(value) => Some(value),
3564            Err(error) => {
3565                tracing::error!("{:?}", error);
3566                None
3567            }
3568        }
3569    }
3570}