rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId,
   7        ProjectId, RoomId, ServerId, User, UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom,
  42        LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66use util::channel::RELEASE_CHANNEL_NAME;
  67
  /// 30-second timeout. Not referenced in the visible part of this file;
  /// presumably the window a disconnected client has to reconnect — TODO confirm.
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  /// How long the task spawned by `Server::start` sleeps before it starts
  /// clearing stale rooms, channel buffers and servers.
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  // NOTE(review): the two limits below are not used in the visible part of this
  // file; presumably they bound chat-message paging and message length — confirm.
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73
  // Prometheus integer gauges, created lazily on first access through
  // `lazy_static!`. The `unwrap` calls panic if gauge registration fails.
  74lazy_static! {
  75    static ref METRIC_CONNECTIONS: IntGauge =
  76        register_int_gauge!("connections", "number of connections").unwrap();
  77    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  78        "shared_projects",
  79        "number of open projects with one or more guests"
  80    )
  81    .unwrap();
  82}
  83
  /// Type-erased message handler as stored in `Server::handlers`: it receives a
  /// boxed envelope of any message type together with the `Session`, and returns
  /// a boxed `'static` future that performs the handling.
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  /// Reply handle passed to request handlers registered through
  /// `Server::add_request_handler`; consumed by `Response::send`.
  87struct Response<R> {
        /// Peer through which the reply is delivered.
  88    peer: Arc<Peer>,
        /// Receipt of the request being answered.
  89    receipt: Receipt<R>,
        /// Set to `true` by `send`; `add_request_handler` inspects it to detect
        /// handlers that returned `Ok` without replying.
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
 /// Per-connection context handed to every message handler. It is built once in
 /// `Server::handle_connection` and cloned for each incoming message.
 101#[derive(Clone)]
 102struct Session {
        /// User that owns this connection.
 103    user_id: UserId,
        /// Identifier the peer assigned to this connection.
 104    connection_id: ConnectionId,
        /// Database handle behind an async mutex; access it through `Session::db`.
 105    db: Arc<tokio::sync::Mutex<DbHandle>>,
        /// Peer shared with the owning `Server`.
 106    peer: Arc<Peer>,
        /// Connection pool shared with the owning `Server`.
 107    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
        /// LiveKit API client, when one is configured in the app state.
 108    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
        /// Executor the connection was started with.
 109    executor: Executor,
 110}
 111
 112impl Session {
 113    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.db.lock().await;
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        guard
 120    }
 121
 122    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.connection_pool.lock();
 126        ConnectionPoolGuard {
 127            guard,
 128            _not_send: PhantomData,
 129        }
 130    }
 131}
 132
 133impl fmt::Debug for Session {
 134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 135        f.debug_struct("Session")
 136            .field("user_id", &self.user_id)
 137            .field("connection_id", &self.connection_id)
 138            .finish()
 139    }
 140}
 141
 /// Newtype around a shared `Database`; it dereferences to `Database`.
 142struct DbHandle(Arc<Database>);
 143
 144impl Deref for DbHandle {
 145    type Target = Database;
 146
 147    fn deref(&self) -> &Self::Target {
 148        self.0.as_ref()
 149    }
 150}
 151
 /// RPC server state: the peer, the pool of live connections, and the table of
 /// message handlers that `handle_connection` dispatches to.
 152pub struct Server {
        /// Current server id; behind a mutex so the test-only `reset` can replace it.
 153    id: parking_lot::Mutex<ServerId>,
        /// Peer used to register connections and to send messages.
 154    peer: Arc<Peer>,
        /// Live connections, shared with every `Session`.
 155    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
        /// Application state (database, config, LiveKit client).
 156    app_state: Arc<AppState>,
        /// Executor used to sleep and to spawn detached tasks.
 157    executor: Executor,
        /// One handler per message type, keyed by the message's `TypeId`.
 158    handlers: HashMap<TypeId, MessageHandler>,
        /// Signalled by `teardown`; each connection loop subscribes and exits on change.
 159    teardown: watch::Sender<()>,
 160}
 161
 /// Lock guard over the shared `ConnectionPool`.
 162pub(crate) struct ConnectionPoolGuard<'a> {
        /// The underlying `parking_lot` mutex guard.
 163    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
        /// `Rc` is neither `Send` nor `Sync`, so this marker keeps the guard from
        /// being moved to or shared with another thread.
 164    _not_send: PhantomData<Rc<()>>,
 165}
 166
 /// Serializable view of the server: the peer plus the locked connection pool.
 /// The pool lock stays held for as long as the snapshot is alive.
 167#[derive(Serialize)]
 168pub struct ServerSnapshot<'a> {
        /// Borrowed peer, serialized with its own `Serialize` implementation.
 169    peer: &'a Peer,
        /// Serialized through `serialize_deref`, i.e. as the guarded value rather
        /// than as the guard itself.
 170    #[serde(serialize_with = "serialize_deref")]
 171    connection_pool: ConnectionPoolGuard<'a>,
 172}
 173
 174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 175where
 176    S: Serializer,
 177    T: Deref<Target = U>,
 178    U: Serialize,
 179{
 180    Serialize::serialize(value.deref(), serializer)
 181}
 182
 183impl Server {
 184    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 185        let mut server = Self {
 186            id: parking_lot::Mutex::new(id),
 187            peer: Peer::new(id.0 as u32),
 188            app_state,
 189            executor,
 190            connection_pool: Default::default(),
 191            handlers: Default::default(),
 192            teardown: watch::channel(()).0,
 193        };
 194
 195        server
 196            .add_request_handler(ping)
 197            .add_request_handler(create_room)
 198            .add_request_handler(join_room)
 199            .add_request_handler(rejoin_room)
 200            .add_request_handler(leave_room)
 201            .add_request_handler(call)
 202            .add_request_handler(cancel_call)
 203            .add_message_handler(decline_call)
 204            .add_request_handler(update_participant_location)
 205            .add_request_handler(share_project)
 206            .add_message_handler(unshare_project)
 207            .add_request_handler(join_project)
 208            .add_message_handler(leave_project)
 209            .add_request_handler(update_project)
 210            .add_request_handler(update_worktree)
 211            .add_message_handler(start_language_server)
 212            .add_message_handler(update_language_server)
 213            .add_message_handler(update_diagnostic_summary)
 214            .add_message_handler(update_worktree_settings)
 215            .add_message_handler(refresh_inlay_hints)
 216            .add_request_handler(forward_project_request::<proto::GetHover>)
 217            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 219            .add_request_handler(forward_project_request::<proto::GetReferences>)
 220            .add_request_handler(forward_project_request::<proto::SearchProject>)
 221            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 222            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 225            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 226            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 227            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 228            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 229            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 230            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 231            .add_request_handler(forward_project_request::<proto::PerformRename>)
 232            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 233            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 234            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 235            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 236            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 240            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 241            .add_request_handler(forward_project_request::<proto::InlayHints>)
 242            .add_message_handler(create_buffer_for_peer)
 243            .add_request_handler(update_buffer)
 244            .add_message_handler(update_buffer_file)
 245            .add_message_handler(buffer_reloaded)
 246            .add_message_handler(buffer_saved)
 247            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 248            .add_request_handler(get_users)
 249            .add_request_handler(fuzzy_search_users)
 250            .add_request_handler(request_contact)
 251            .add_request_handler(remove_contact)
 252            .add_request_handler(respond_to_contact_request)
 253            .add_request_handler(create_channel)
 254            .add_request_handler(delete_channel)
 255            .add_request_handler(invite_channel_member)
 256            .add_request_handler(remove_channel_member)
 257            .add_request_handler(set_channel_member_role)
 258            .add_request_handler(set_channel_visibility)
 259            .add_request_handler(rename_channel)
 260            .add_request_handler(join_channel_buffer)
 261            .add_request_handler(leave_channel_buffer)
 262            .add_message_handler(update_channel_buffer)
 263            .add_request_handler(rejoin_channel_buffers)
 264            .add_request_handler(get_channel_members)
 265            .add_request_handler(respond_to_channel_invite)
 266            .add_request_handler(join_channel)
 267            .add_request_handler(join_channel_chat)
 268            .add_message_handler(leave_channel_chat)
 269            .add_request_handler(send_channel_message)
 270            .add_request_handler(remove_channel_message)
 271            .add_request_handler(get_channel_messages)
 272            .add_request_handler(link_channel)
 273            .add_request_handler(unlink_channel)
 274            .add_request_handler(move_channel)
 275            .add_request_handler(follow)
 276            .add_message_handler(unfollow)
 277            .add_message_handler(update_followers)
 278            .add_message_handler(update_diff_base)
 279            .add_request_handler(get_private_user_info)
 280            .add_message_handler(acknowledge_channel_message)
 281            .add_message_handler(acknowledge_buffer_version);
 282
 283        Arc::new(server)
 284    }
 285
 /// Spawns a detached background task that cleans up resources left behind by
 /// stale servers, then returns `Ok(())` immediately without waiting for it.
 ///
 /// After sleeping for `CLEANUP_TIMEOUT` the task:
 /// 1. asks the database for stale room and channel ids for this environment;
 /// 2. clears stale channel-buffer collaborators and sends the refreshed
 ///    collaborator list to each remaining connection;
 /// 3. clears stale room participants, announces the refreshed room, sends
 ///    `CallCanceled` for canceled calls, and pushes contact updates;
 /// 4. deletes the LiveKit room when the refreshed room has no participants;
 /// 5. deletes the stale server records.
 ///
 /// Failures of individual steps are passed through `trace_err()` and the task
 /// carries on with the next item instead of aborting.
 286    pub async fn start(&self) -> Result<()> {
 287        let server_id = *self.id.lock();
 288        let app_state = self.app_state.clone();
 289        let peer = self.peer.clone();
            // The sleep future is created here but only awaited inside the task.
 290        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 291        let pool = self.connection_pool.clone();
 292        let live_kit_client = self.app_state.live_kit_client.clone();
 293
 294        let span = info_span!("start server");
 295        self.executor.spawn_detached(
 296            async move {
 297                tracing::info!("waiting for cleanup timeout");
 298                timeout.await;
 299                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 300                if let Some((room_ids, channel_ids)) = app_state
 301                    .db
 302                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 303                    .await
 304                    .trace_err()
 305                {
 306                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 307                    tracing::info!(
 308                        stale_channel_buffer_count = channel_ids.len(),
 309                        "retrieved stale channel buffers"
 310                    );
 311
                        // Channel buffers: drop stale collaborators and notify
                        // the connections returned by the database.
 312                    for channel_id in channel_ids {
 313                        if let Some(refreshed_channel_buffer) = app_state
 314                            .db
 315                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 316                            .await
 317                            .trace_err()
 318                        {
 319                            for connection_id in refreshed_channel_buffer.connection_ids {
 320                                peer.send(
 321                                    connection_id,
 322                                    proto::UpdateChannelBufferCollaborators {
 323                                        channel_id: channel_id.to_proto(),
 324                                        collaborators: refreshed_channel_buffer
 325                                            .collaborators
 326                                            .clone(),
 327                                    },
 328                                )
 329                                .trace_err();
 330                            }
 331                        }
 332                    }
 333
 334                    for room_id in room_ids {
                            // These defaults are kept when refreshing the room
                            // fails, which turns the follow-up steps into no-ops.
 335                        let mut contacts_to_update = HashSet::default();
 336                        let mut canceled_calls_to_user_ids = Vec::new();
 337                        let mut live_kit_room = String::new();
 338                        let mut delete_live_kit_room = false;
 339
 340                        if let Some(mut refreshed_room) = app_state
 341                            .db
 342                            .clear_stale_room_participants(room_id, server_id)
 343                            .await
 344                            .trace_err()
 345                        {
 346                            tracing::info!(
 347                                room_id = room_id.0,
 348                                new_participant_count = refreshed_room.room.participants.len(),
 349                                "refreshed room"
 350                            );
 351                            room_updated(&refreshed_room.room, &peer);
 352                            if let Some(channel_id) = refreshed_room.channel_id {
 353                                channel_updated(
 354                                    channel_id,
 355                                    &refreshed_room.room,
 356                                    &refreshed_room.channel_members,
 357                                    &peer,
 358                                    &*pool.lock(),
 359                                );
 360                            }
 361                            contacts_to_update
 362                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 363                            contacts_to_update
 364                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 365                            canceled_calls_to_user_ids =
 366                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 367                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 368                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 369                        }
 370
                            // Scoped block so the pool lock is released before the
                            // database awaits that follow.
 371                        {
 372                            let pool = pool.lock();
 373                            for canceled_user_id in canceled_calls_to_user_ids {
 374                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 375                                    peer.send(
 376                                        connection_id,
 377                                        proto::CallCanceled {
 378                                            room_id: room_id.to_proto(),
 379                                        },
 380                                    )
 381                                    .trace_err();
 382                                }
 383                            }
 384                        }
 385
 386                        for user_id in contacts_to_update {
 387                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 388                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                                // Proceed only when both lookups succeeded.
 389                            if let Some((busy, contacts)) = busy.zip(contacts) {
 390                                let pool = pool.lock();
 391                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 392                                for contact in contacts {
                                        // Only accepted contacts receive the update.
 393                                    if let db::Contact::Accepted {
 394                                        user_id: contact_user_id,
 395                                        ..
 396                                    } = contact
 397                                    {
 398                                        for contact_conn_id in
 399                                            pool.user_connection_ids(contact_user_id)
 400                                        {
 401                                            peer.send(
 402                                                contact_conn_id,
 403                                                proto::UpdateContacts {
 404                                                    contacts: vec![updated_contact.clone()],
 405                                                    remove_contacts: Default::default(),
 406                                                    incoming_requests: Default::default(),
 407                                                    remove_incoming_requests: Default::default(),
 408                                                    outgoing_requests: Default::default(),
 409                                                    remove_outgoing_requests: Default::default(),
 410                                                },
 411                                            )
 412                                            .trace_err();
 413                                        }
 414                                    }
 415                                }
 416                            }
 417                        }
 418
 419                        if let Some(live_kit) = live_kit_client.as_ref() {
 420                            if delete_live_kit_room {
 421                                live_kit.delete_room(live_kit_room).await.trace_err();
 422                            }
 423                        }
 424                    }
 425                }
 426
                    // Runs even when the stale-resource lookup above failed.
 427                app_state
 428                    .db
 429                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 430                    .await
 431                    .trace_err();
 432            }
 433            .instrument(span),
 434        );
 435        Ok(())
 436    }
 437
 438    pub fn teardown(&self) {
 439        self.peer.teardown();
 440        self.connection_pool.lock().reset();
 441        let _ = self.teardown.send(());
 442    }
 443
 /// Test-only: tears the server down, then gives both the server and its peer
 /// the new `id`.
 #[cfg(test)]
 pub fn reset(&self, id: ServerId) {
     self.teardown();
     {
         let mut current_id = self.id.lock();
         *current_id = id;
     }
     self.peer.reset(id.0 as u32);
 }
 450
 /// Test-only: returns a copy of the current server id.
 #[cfg(test)]
 pub fn id(&self) -> ServerId {
     let current_id = self.id.lock();
     *current_id
 }
 455
 456    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 457    where
 458        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 459        Fut: 'static + Send + Future<Output = Result<()>>,
 460        M: EnvelopedMessage,
 461    {
 462        let prev_handler = self.handlers.insert(
 463            TypeId::of::<M>(),
 464            Box::new(move |envelope, session| {
 465                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 466                let span = info_span!(
 467                    "handle message",
 468                    payload_type = envelope.payload_type_name()
 469                );
 470                span.in_scope(|| {
 471                    tracing::info!(
 472                        payload_type = envelope.payload_type_name(),
 473                        "message received"
 474                    );
 475                });
 476                let start_time = Instant::now();
 477                let future = (handler)(*envelope, session);
 478                async move {
 479                    let result = future.await;
 480                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 481                    match result {
 482                        Err(error) => {
 483                            tracing::error!(%error, ?duration_ms, "error handling message")
 484                        }
 485                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 486                    }
 487                }
 488                .instrument(span)
 489                .boxed()
 490            }),
 491        );
 492        if prev_handler.is_some() {
 493            panic!("registered a handler for the same message twice");
 494        }
 495        self
 496    }
 497
 498    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 499    where
 500        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 501        Fut: 'static + Send + Future<Output = Result<()>>,
 502        M: EnvelopedMessage,
 503    {
 504        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 505        self
 506    }
 507
 508    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 511        Fut: Send + Future<Output = Result<()>>,
 512        M: RequestMessage,
 513    {
 514        let handler = Arc::new(handler);
 515        self.add_handler(move |envelope, session| {
 516            let receipt = envelope.receipt();
 517            let handler = handler.clone();
 518            async move {
 519                let peer = session.peer.clone();
 520                let responded = Arc::new(AtomicBool::default());
 521                let response = Response {
 522                    peer: peer.clone(),
 523                    responded: responded.clone(),
 524                    receipt,
 525                };
 526                match (handler)(envelope.payload, response, session).await {
 527                    Ok(()) => {
 528                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 529                            Ok(())
 530                        } else {
 531                            Err(anyhow!("handler did not send a response"))?
 532                        }
 533                    }
 534                    Err(error) => {
 535                        peer.respond_with_error(
 536                            receipt,
 537                            proto::Error {
 538                                message: error.to_string(),
 539                            },
 540                        )?;
 541                        Err(error)
 542                    }
 543                }
 544            }
 545        })
 546    }
 547
 /// Returns a future that drives one client connection from handshake to
 /// sign-out.
 ///
 /// * `connection` - transport registered with the peer.
 /// * `address` - used only in log and span fields.
 /// * `user` - the authenticated user owning the connection.
 /// * `send_connection_id` - optional channel that receives the connection id
 ///   once the peer has assigned it.
 /// * `executor` - used for the peer's timers and for background handlers.
 ///
 /// The future sends `Hello` and the initial contact, channel and incoming-call
 /// state, then dispatches incoming messages to the registered handlers until
 /// the I/O task finishes, the incoming stream ends, or the server is torn
 /// down. On teardown it returns right away without calling `connection_lost`;
 /// otherwise it calls `connection_lost` before returning.
 ///
 /// # Errors
 /// Fails if one of the setup sends or database queries before the message
 /// loop fails. Errors from the I/O task and from `connection_lost` are only
 /// logged.
 548    pub fn handle_connection(
 549        self: &Arc<Self>,
 550        connection: Connection,
 551        address: String,
 552        user: User,
 553        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 554        executor: Executor,
 555    ) -> impl Future<Output = Result<()>> {
 556        let this = self.clone();
 557        let user_id = user.id;
 558        let login = user.github_login;
 559        let span = info_span!("handle connection", %user_id, %login, %address);
 560        let mut teardown = self.teardown.subscribe();
 561        async move {
 562            let (connection_id, handle_io, mut incoming_rx) = this
 563                .peer
 564                .add_connection(connection, {
 565                    let executor = executor.clone();
 566                    move |duration| executor.sleep(duration)
 567                });
 568
 569            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 570            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 571            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 572
 573            if let Some(send_connection_id) = send_connection_id.take() {
 574                let _ = send_connection_id.send(connection_id);
 575            }
 576
                // First connection of this user: send `ShowContacts` and record
                // that the user has connected once.
 577            if !user.connected_once {
 578                this.peer.send(connection_id, proto::ShowContacts {})?;
 579                this.app_state.db.set_user_connected_once(user_id, true).await?;
 580            }
 581
 582            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 583                this.app_state.db.get_contacts(user_id),
 584                this.app_state.db.get_channels_for_user(user_id),
 585                this.app_state.db.get_channel_invites_for_user(user_id)
 586            ).await?;
 587
                // Scoped block: the pool lock is released before the next `.await`.
 588            {
 589                let mut pool = this.connection_pool.lock();
 590                pool.add_connection(connection_id, user_id, user.admin);
 591                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 592                this.peer.send(connection_id, build_initial_channels_update(
 593                    channels_for_user,
 594                    channel_invites
 595                ))?;
 596            }
 597
 598            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 599                this.peer.send(connection_id, incoming_call)?;
 600            }
 601
 602            let session = Session {
 603                user_id,
 604                connection_id,
 605                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 606                peer: this.peer.clone(),
 607                connection_pool: this.connection_pool.clone(),
 608                live_kit_client: this.app_state.live_kit_client.clone(),
 609                executor: executor.clone(),
 610            };
 611            update_user_contacts(user_id, &session).await?;
 612
 613            let handle_io = handle_io.fuse();
 614            futures::pin_mut!(handle_io);
 615
 616            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 617            // This prevents deadlocks when e.g., client A performs a request to client B and
 618            // client B performs a request to client A. If both clients stop processing further
 619            // messages until their respective request completes, they won't have a chance to
 620            // respond to the other client's request and cause a deadlock.
 621            //
 622            // This arrangement ensures we will attempt to process earlier messages first, but fall
 623            // back to processing messages arrived later in the spirit of making progress.
 624            let mut foreground_message_handlers = FuturesUnordered::new();
                // A permit is acquired before each message is read and is released
                // when its handler completes, so at most 256 handlers are in
                // flight for this connection at once.
 625            let concurrent_handlers = Arc::new(Semaphore::new(256));
 626            loop {
 627                let next_message = async {
 628                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 629                    let message = incoming_rx.next().await;
 630                    (permit, message)
 631                }.fuse();
 632                futures::pin_mut!(next_message);
                    // `select_biased!` polls the arms in the order written:
                    // teardown, I/O completion, running handlers, new messages.
 633                futures::select_biased! {
 634                    _ = teardown.changed().fuse() => return Ok(()),
 635                    result = handle_io => {
 636                        if let Err(error) = result {
 637                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 638                        }
 639                        break;
 640                    }
 641                    _ = foreground_message_handlers.next() => {}
 642                    next_message = next_message => {
 643                        let (permit, message) = next_message;
 644                        if let Some(message) = message {
 645                            let type_name = message.payload_type_name();
 646                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 647                            let span_enter = span.enter();
 648                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 649                                let is_background = message.is_background();
 650                                let handle_message = (handler)(message, session.clone());
 651                                drop(span_enter);
 652
 653                                let handle_message = async move {
 654                                    handle_message.await;
 655                                    drop(permit);
 656                                }.instrument(span);
                                    // Background messages run as detached tasks;
                                    // foreground ones are polled by this loop.
 657                                if is_background {
 658                                    executor.spawn_detached(handle_message);
 659                                } else {
 660                                    foreground_message_handlers.push(handle_message);
 661                                }
 662                            } else {
 663                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 664                            }
 665                        } else {
 666                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 667                            break;
 668                        }
 669                    }
 670                }
 671            }
 672
                // Dropping the set discards any foreground handlers still pending.
 673            drop(foreground_message_handlers);
 674            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 675            if let Err(error) = connection_lost(session, teardown, executor).await {
 676                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 677            }
 678
 679            Ok(())
 680        }.instrument(span)
 681    }
 682
 683    pub async fn invite_code_redeemed(
 684        self: &Arc<Self>,
 685        inviter_id: UserId,
 686        invitee_id: UserId,
 687    ) -> Result<()> {
 688        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 689            if let Some(code) = &user.invite_code {
 690                let pool = self.connection_pool.lock();
 691                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 692                for connection_id in pool.user_connection_ids(inviter_id) {
 693                    self.peer.send(
 694                        connection_id,
 695                        proto::UpdateContacts {
 696                            contacts: vec![invitee_contact.clone()],
 697                            ..Default::default()
 698                        },
 699                    )?;
 700                    self.peer.send(
 701                        connection_id,
 702                        proto::UpdateInviteInfo {
 703                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 704                            count: user.invite_count as u32,
 705                        },
 706                    )?;
 707                }
 708            }
 709        }
 710        Ok(())
 711    }
 712
 713    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 714        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 715            if let Some(invite_code) = &user.invite_code {
 716                let pool = self.connection_pool.lock();
 717                for connection_id in pool.user_connection_ids(user_id) {
 718                    self.peer.send(
 719                        connection_id,
 720                        proto::UpdateInviteInfo {
 721                            url: format!(
 722                                "{}{}",
 723                                self.app_state.config.invite_link_prefix, invite_code
 724                            ),
 725                            count: user.invite_count as u32,
 726                        },
 727                    )?;
 728                }
 729            }
 730        }
 731        Ok(())
 732    }
 733
 734    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 735        ServerSnapshot {
 736            connection_pool: ConnectionPoolGuard {
 737                guard: self.connection_pool.lock(),
 738                _not_send: PhantomData,
 739            },
 740            peer: &self.peer,
 741        }
 742    }
 743}
 744
 745impl<'a> Deref for ConnectionPoolGuard<'a> {
 746    type Target = ConnectionPool;
 747
 748    fn deref(&self) -> &Self::Target {
 749        &*self.guard
 750    }
 751}
 752
 753impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 754    fn deref_mut(&mut self) -> &mut Self::Target {
 755        &mut *self.guard
 756    }
 757}
 758
 759impl<'a> Drop for ConnectionPoolGuard<'a> {
 760    fn drop(&mut self) {
 761        #[cfg(test)]
 762        self.check_invariants();
 763    }
 764}
 765
 766fn broadcast<F>(
 767    sender_id: Option<ConnectionId>,
 768    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 769    mut f: F,
 770) where
 771    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 772{
 773    for receiver_id in receiver_ids {
 774        if Some(receiver_id) != sender_id {
 775            if let Err(error) = f(receiver_id) {
 776                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 777            }
 778        }
 779    }
 780}
 781
lazy_static! {
    /// Name of the HTTP header from which `ProtocolVersion` is decoded.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 785
/// Typed value of the `x-zed-protocol-version` header: the RPC protocol
/// version announced by a client.
pub struct ProtocolVersion(u32);
 787
 788impl Header for ProtocolVersion {
 789    fn name() -> &'static HeaderName {
 790        &ZED_PROTOCOL_VERSION
 791    }
 792
 793    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 794    where
 795        Self: Sized,
 796        I: Iterator<Item = &'i axum::http::HeaderValue>,
 797    {
 798        let version = values
 799            .next()
 800            .ok_or_else(axum::headers::Error::invalid)?
 801            .to_str()
 802            .map_err(|_| axum::headers::Error::invalid())?
 803            .parse()
 804            .map_err(|_| axum::headers::Error::invalid())?;
 805        Ok(Self(version))
 806    }
 807
 808    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 809        values.extend([self.0.to_string().parse().unwrap()]);
 810    }
 811}
 812
 813pub fn routes(server: Arc<Server>) -> Router<Body> {
 814    Router::new()
 815        .route("/rpc", get(handle_websocket_request))
 816        .layer(
 817            ServiceBuilder::new()
 818                .layer(Extension(server.app_state.clone()))
 819                .layer(middleware::from_fn(auth::validate_header)),
 820        )
 821        .route("/metrics", get(handle_metrics))
 822        .layer(Extension(server))
 823}
 824
 825pub async fn handle_websocket_request(
 826    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 827    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 828    Extension(server): Extension<Arc<Server>>,
 829    Extension(user): Extension<User>,
 830    ws: WebSocketUpgrade,
 831) -> axum::response::Response {
 832    if protocol_version != rpc::PROTOCOL_VERSION {
 833        return (
 834            StatusCode::UPGRADE_REQUIRED,
 835            "client must be upgraded".to_string(),
 836        )
 837            .into_response();
 838    }
 839    let socket_address = socket_address.to_string();
 840    ws.on_upgrade(move |socket| {
 841        use util::ResultExt;
 842        let socket = socket
 843            .map_ok(to_tungstenite_message)
 844            .err_into()
 845            .with(|message| async move { Ok(to_axum_message(message)) });
 846        let connection = Connection::new(Box::pin(socket));
 847        async move {
 848            server
 849                .handle_connection(connection, socket_address, user, None, Executor::Production)
 850                .await
 851                .log_err();
 852        }
 853    })
 854}
 855
 856pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 857    let connections = server
 858        .connection_pool
 859        .lock()
 860        .connections()
 861        .filter(|connection| !connection.admin)
 862        .count();
 863
 864    METRIC_CONNECTIONS.set(connections as _);
 865
 866    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 867    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 868
 869    let encoder = prometheus::TextEncoder::new();
 870    let metric_families = prometheus::gather();
 871    let encoded_metrics = encoder
 872        .encode_to_string(&metric_families)
 873        .map_err(|err| anyhow!("{}", err))?;
 874    Ok(encoded_metrics)
 875}
 876
/// Tears down server-side state for a connection that has gone away.
///
/// The connection is disconnected from the peer, removed from the connection
/// pool and reported to the database immediately. The remaining cleanup
/// (leaving the room and channel buffers, declining a pending call, updating
/// contacts) runs only if `RECONNECT_TIMEOUT` elapses before `teardown`
/// changes.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool or if
/// `update_user_contacts` fails. Other failures are only traced.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is traced, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in declaration order, so the timeout
    // branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline on the user's behalf when none of their
            // connections remain online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled first: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 922
 923async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 924    response.send(proto::Ack {})?;
 925    Ok(())
 926}
 927
 928async fn create_room(
 929    _request: proto::CreateRoom,
 930    response: Response<proto::CreateRoom>,
 931    session: Session,
 932) -> Result<()> {
 933    let live_kit_room = nanoid::nanoid!(30);
 934
 935    let live_kit_connection_info = {
 936        let live_kit_room = live_kit_room.clone();
 937        let live_kit = session.live_kit_client.as_ref();
 938
 939        util::async_iife!({
 940            let live_kit = live_kit?;
 941
 942            let token = live_kit
 943                .room_token(&live_kit_room, &session.user_id.to_string())
 944                .trace_err()?;
 945
 946            Some(proto::LiveKitConnectionInfo {
 947                server_url: live_kit.url().into(),
 948                token,
 949            })
 950        })
 951    }
 952    .await;
 953
 954    let room = session
 955        .db()
 956        .await
 957        .create_room(
 958            session.user_id,
 959            session.connection_id,
 960            &live_kit_room,
 961            RELEASE_CHANNEL_NAME.as_str(),
 962        )
 963        .await?;
 964
 965    response.send(proto::CreateRoomResponse {
 966        room: Some(room.clone()),
 967        live_kit_connection_info,
 968    })?;
 969
 970    update_user_contacts(session.user_id, &session).await?;
 971    Ok(())
 972}
 973
 974async fn join_room(
 975    request: proto::JoinRoom,
 976    response: Response<proto::JoinRoom>,
 977    session: Session,
 978) -> Result<()> {
 979    let room_id = RoomId::from_proto(request.id);
 980
 981    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 982
 983    if let Some(channel_id) = channel_id {
 984        return join_channel_internal(channel_id, Box::new(response), session).await;
 985    }
 986
 987    let joined_room = {
 988        let room = session
 989            .db()
 990            .await
 991            .join_room(
 992                room_id,
 993                session.user_id,
 994                session.connection_id,
 995                RELEASE_CHANNEL_NAME.as_str(),
 996            )
 997            .await?;
 998        room_updated(&room.room, &session.peer);
 999        room.into_inner()
1000    };
1001
1002    for connection_id in session
1003        .connection_pool()
1004        .await
1005        .user_connection_ids(session.user_id)
1006    {
1007        session
1008            .peer
1009            .send(
1010                connection_id,
1011                proto::CallCanceled {
1012                    room_id: room_id.to_proto(),
1013                },
1014            )
1015            .trace_err();
1016    }
1017
1018    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1019        if let Some(token) = live_kit
1020            .room_token(
1021                &joined_room.room.live_kit_room,
1022                &session.user_id.to_string(),
1023            )
1024            .trace_err()
1025        {
1026            Some(proto::LiveKitConnectionInfo {
1027                server_url: live_kit.url().into(),
1028                token,
1029            })
1030        } else {
1031            None
1032        }
1033    } else {
1034        None
1035    };
1036
1037    response.send(proto::JoinRoomResponse {
1038        room: Some(joined_room.room),
1039        channel_id: None,
1040        live_kit_connection_info,
1041    })?;
1042
1043    update_user_contacts(session.user_id, &session).await?;
1044    Ok(())
1045}
1046
1047async fn rejoin_room(
1048    request: proto::RejoinRoom,
1049    response: Response<proto::RejoinRoom>,
1050    session: Session,
1051) -> Result<()> {
1052    let room;
1053    let channel_id;
1054    let channel_members;
1055    {
1056        let mut rejoined_room = session
1057            .db()
1058            .await
1059            .rejoin_room(request, session.user_id, session.connection_id)
1060            .await?;
1061
1062        response.send(proto::RejoinRoomResponse {
1063            room: Some(rejoined_room.room.clone()),
1064            reshared_projects: rejoined_room
1065                .reshared_projects
1066                .iter()
1067                .map(|project| proto::ResharedProject {
1068                    id: project.id.to_proto(),
1069                    collaborators: project
1070                        .collaborators
1071                        .iter()
1072                        .map(|collaborator| collaborator.to_proto())
1073                        .collect(),
1074                })
1075                .collect(),
1076            rejoined_projects: rejoined_room
1077                .rejoined_projects
1078                .iter()
1079                .map(|rejoined_project| proto::RejoinedProject {
1080                    id: rejoined_project.id.to_proto(),
1081                    worktrees: rejoined_project
1082                        .worktrees
1083                        .iter()
1084                        .map(|worktree| proto::WorktreeMetadata {
1085                            id: worktree.id,
1086                            root_name: worktree.root_name.clone(),
1087                            visible: worktree.visible,
1088                            abs_path: worktree.abs_path.clone(),
1089                        })
1090                        .collect(),
1091                    collaborators: rejoined_project
1092                        .collaborators
1093                        .iter()
1094                        .map(|collaborator| collaborator.to_proto())
1095                        .collect(),
1096                    language_servers: rejoined_project.language_servers.clone(),
1097                })
1098                .collect(),
1099        })?;
1100        room_updated(&rejoined_room.room, &session.peer);
1101
1102        for project in &rejoined_room.reshared_projects {
1103            for collaborator in &project.collaborators {
1104                session
1105                    .peer
1106                    .send(
1107                        collaborator.connection_id,
1108                        proto::UpdateProjectCollaborator {
1109                            project_id: project.id.to_proto(),
1110                            old_peer_id: Some(project.old_connection_id.into()),
1111                            new_peer_id: Some(session.connection_id.into()),
1112                        },
1113                    )
1114                    .trace_err();
1115            }
1116
1117            broadcast(
1118                Some(session.connection_id),
1119                project
1120                    .collaborators
1121                    .iter()
1122                    .map(|collaborator| collaborator.connection_id),
1123                |connection_id| {
1124                    session.peer.forward_send(
1125                        session.connection_id,
1126                        connection_id,
1127                        proto::UpdateProject {
1128                            project_id: project.id.to_proto(),
1129                            worktrees: project.worktrees.clone(),
1130                        },
1131                    )
1132                },
1133            );
1134        }
1135
1136        for project in &rejoined_room.rejoined_projects {
1137            for collaborator in &project.collaborators {
1138                session
1139                    .peer
1140                    .send(
1141                        collaborator.connection_id,
1142                        proto::UpdateProjectCollaborator {
1143                            project_id: project.id.to_proto(),
1144                            old_peer_id: Some(project.old_connection_id.into()),
1145                            new_peer_id: Some(session.connection_id.into()),
1146                        },
1147                    )
1148                    .trace_err();
1149            }
1150        }
1151
1152        for project in &mut rejoined_room.rejoined_projects {
1153            for worktree in mem::take(&mut project.worktrees) {
1154                #[cfg(any(test, feature = "test-support"))]
1155                const MAX_CHUNK_SIZE: usize = 2;
1156                #[cfg(not(any(test, feature = "test-support")))]
1157                const MAX_CHUNK_SIZE: usize = 256;
1158
1159                // Stream this worktree's entries.
1160                let message = proto::UpdateWorktree {
1161                    project_id: project.id.to_proto(),
1162                    worktree_id: worktree.id,
1163                    abs_path: worktree.abs_path.clone(),
1164                    root_name: worktree.root_name,
1165                    updated_entries: worktree.updated_entries,
1166                    removed_entries: worktree.removed_entries,
1167                    scan_id: worktree.scan_id,
1168                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1169                    updated_repositories: worktree.updated_repositories,
1170                    removed_repositories: worktree.removed_repositories,
1171                };
1172                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1173                    session.peer.send(session.connection_id, update.clone())?;
1174                }
1175
1176                // Stream this worktree's diagnostics.
1177                for summary in worktree.diagnostic_summaries {
1178                    session.peer.send(
1179                        session.connection_id,
1180                        proto::UpdateDiagnosticSummary {
1181                            project_id: project.id.to_proto(),
1182                            worktree_id: worktree.id,
1183                            summary: Some(summary),
1184                        },
1185                    )?;
1186                }
1187
1188                for settings_file in worktree.settings_files {
1189                    session.peer.send(
1190                        session.connection_id,
1191                        proto::UpdateWorktreeSettings {
1192                            project_id: project.id.to_proto(),
1193                            worktree_id: worktree.id,
1194                            path: settings_file.path,
1195                            content: Some(settings_file.content),
1196                        },
1197                    )?;
1198                }
1199            }
1200
1201            for language_server in &project.language_servers {
1202                session.peer.send(
1203                    session.connection_id,
1204                    proto::UpdateLanguageServer {
1205                        project_id: project.id.to_proto(),
1206                        language_server_id: language_server.id,
1207                        variant: Some(
1208                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1209                                proto::LspDiskBasedDiagnosticsUpdated {},
1210                            ),
1211                        ),
1212                    },
1213                )?;
1214            }
1215        }
1216
1217        let rejoined_room = rejoined_room.into_inner();
1218
1219        room = rejoined_room.room;
1220        channel_id = rejoined_room.channel_id;
1221        channel_members = rejoined_room.channel_members;
1222    }
1223
1224    if let Some(channel_id) = channel_id {
1225        channel_updated(
1226            channel_id,
1227            &room,
1228            &channel_members,
1229            &session.peer,
1230            &*session.connection_pool().await,
1231        );
1232    }
1233
1234    update_user_contacts(session.user_id, &session).await?;
1235    Ok(())
1236}
1237
1238async fn leave_room(
1239    _: proto::LeaveRoom,
1240    response: Response<proto::LeaveRoom>,
1241    session: Session,
1242) -> Result<()> {
1243    leave_room_for_session(&session).await?;
1244    response.send(proto::Ack {})?;
1245    Ok(())
1246}
1247
1248async fn call(
1249    request: proto::Call,
1250    response: Response<proto::Call>,
1251    session: Session,
1252) -> Result<()> {
1253    let room_id = RoomId::from_proto(request.room_id);
1254    let calling_user_id = session.user_id;
1255    let calling_connection_id = session.connection_id;
1256    let called_user_id = UserId::from_proto(request.called_user_id);
1257    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1258    if !session
1259        .db()
1260        .await
1261        .has_contact(calling_user_id, called_user_id)
1262        .await?
1263    {
1264        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1265    }
1266
1267    let incoming_call = {
1268        let (room, incoming_call) = &mut *session
1269            .db()
1270            .await
1271            .call(
1272                room_id,
1273                calling_user_id,
1274                calling_connection_id,
1275                called_user_id,
1276                initial_project_id,
1277            )
1278            .await?;
1279        room_updated(&room, &session.peer);
1280        mem::take(incoming_call)
1281    };
1282    update_user_contacts(called_user_id, &session).await?;
1283
1284    let mut calls = session
1285        .connection_pool()
1286        .await
1287        .user_connection_ids(called_user_id)
1288        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1289        .collect::<FuturesUnordered<_>>();
1290
1291    while let Some(call_response) = calls.next().await {
1292        match call_response.as_ref() {
1293            Ok(_) => {
1294                response.send(proto::Ack {})?;
1295                return Ok(());
1296            }
1297            Err(_) => {
1298                call_response.trace_err();
1299            }
1300        }
1301    }
1302
1303    {
1304        let room = session
1305            .db()
1306            .await
1307            .call_failed(room_id, called_user_id)
1308            .await?;
1309        room_updated(&room, &session.peer);
1310    }
1311    update_user_contacts(called_user_id, &session).await?;
1312
1313    Err(anyhow!("failed to ring user"))?
1314}
1315
1316async fn cancel_call(
1317    request: proto::CancelCall,
1318    response: Response<proto::CancelCall>,
1319    session: Session,
1320) -> Result<()> {
1321    let called_user_id = UserId::from_proto(request.called_user_id);
1322    let room_id = RoomId::from_proto(request.room_id);
1323    {
1324        let room = session
1325            .db()
1326            .await
1327            .cancel_call(room_id, session.connection_id, called_user_id)
1328            .await?;
1329        room_updated(&room, &session.peer);
1330    }
1331
1332    for connection_id in session
1333        .connection_pool()
1334        .await
1335        .user_connection_ids(called_user_id)
1336    {
1337        session
1338            .peer
1339            .send(
1340                connection_id,
1341                proto::CallCanceled {
1342                    room_id: room_id.to_proto(),
1343                },
1344            )
1345            .trace_err();
1346    }
1347    response.send(proto::Ack {})?;
1348
1349    update_user_contacts(called_user_id, &session).await?;
1350    Ok(())
1351}
1352
1353async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1354    let room_id = RoomId::from_proto(message.room_id);
1355    {
1356        let room = session
1357            .db()
1358            .await
1359            .decline_call(Some(room_id), session.user_id)
1360            .await?
1361            .ok_or_else(|| anyhow!("failed to decline call"))?;
1362        room_updated(&room, &session.peer);
1363    }
1364
1365    for connection_id in session
1366        .connection_pool()
1367        .await
1368        .user_connection_ids(session.user_id)
1369    {
1370        session
1371            .peer
1372            .send(
1373                connection_id,
1374                proto::CallCanceled {
1375                    room_id: room_id.to_proto(),
1376                },
1377            )
1378            .trace_err();
1379    }
1380    update_user_contacts(session.user_id, &session).await?;
1381    Ok(())
1382}
1383
1384async fn update_participant_location(
1385    request: proto::UpdateParticipantLocation,
1386    response: Response<proto::UpdateParticipantLocation>,
1387    session: Session,
1388) -> Result<()> {
1389    let room_id = RoomId::from_proto(request.room_id);
1390    let location = request
1391        .location
1392        .ok_or_else(|| anyhow!("invalid location"))?;
1393
1394    let db = session.db().await;
1395    let room = db
1396        .update_room_participant_location(room_id, session.connection_id, location)
1397        .await?;
1398
1399    room_updated(&room, &session.peer);
1400    response.send(proto::Ack {})?;
1401    Ok(())
1402}
1403
1404async fn share_project(
1405    request: proto::ShareProject,
1406    response: Response<proto::ShareProject>,
1407    session: Session,
1408) -> Result<()> {
1409    let (project_id, room) = &*session
1410        .db()
1411        .await
1412        .share_project(
1413            RoomId::from_proto(request.room_id),
1414            session.connection_id,
1415            &request.worktrees,
1416        )
1417        .await?;
1418    response.send(proto::ShareProjectResponse {
1419        project_id: project_id.to_proto(),
1420    })?;
1421    room_updated(&room, &session.peer);
1422
1423    Ok(())
1424}
1425
1426async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1427    let project_id = ProjectId::from_proto(message.project_id);
1428
1429    let (room, guest_connection_ids) = &*session
1430        .db()
1431        .await
1432        .unshare_project(project_id, session.connection_id)
1433        .await?;
1434
1435    broadcast(
1436        Some(session.connection_id),
1437        guest_connection_ids.iter().copied(),
1438        |conn_id| session.peer.send(conn_id, message.clone()),
1439    );
1440    room_updated(&room, &session.peer);
1441
1442    Ok(())
1443}
1444
1445async fn join_project(
1446    request: proto::JoinProject,
1447    response: Response<proto::JoinProject>,
1448    session: Session,
1449) -> Result<()> {
1450    let project_id = ProjectId::from_proto(request.project_id);
1451    let guest_user_id = session.user_id;
1452
1453    tracing::info!(%project_id, "join project");
1454
1455    let (project, replica_id) = &mut *session
1456        .db()
1457        .await
1458        .join_project(project_id, session.connection_id)
1459        .await?;
1460
1461    let collaborators = project
1462        .collaborators
1463        .iter()
1464        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1465        .map(|collaborator| collaborator.to_proto())
1466        .collect::<Vec<_>>();
1467
1468    let worktrees = project
1469        .worktrees
1470        .iter()
1471        .map(|(id, worktree)| proto::WorktreeMetadata {
1472            id: *id,
1473            root_name: worktree.root_name.clone(),
1474            visible: worktree.visible,
1475            abs_path: worktree.abs_path.clone(),
1476        })
1477        .collect::<Vec<_>>();
1478
1479    for collaborator in &collaborators {
1480        session
1481            .peer
1482            .send(
1483                collaborator.peer_id.unwrap().into(),
1484                proto::AddProjectCollaborator {
1485                    project_id: project_id.to_proto(),
1486                    collaborator: Some(proto::Collaborator {
1487                        peer_id: Some(session.connection_id.into()),
1488                        replica_id: replica_id.0 as u32,
1489                        user_id: guest_user_id.to_proto(),
1490                    }),
1491                },
1492            )
1493            .trace_err();
1494    }
1495
1496    // First, we send the metadata associated with each worktree.
1497    response.send(proto::JoinProjectResponse {
1498        worktrees: worktrees.clone(),
1499        replica_id: replica_id.0 as u32,
1500        collaborators: collaborators.clone(),
1501        language_servers: project.language_servers.clone(),
1502    })?;
1503
1504    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1505        #[cfg(any(test, feature = "test-support"))]
1506        const MAX_CHUNK_SIZE: usize = 2;
1507        #[cfg(not(any(test, feature = "test-support")))]
1508        const MAX_CHUNK_SIZE: usize = 256;
1509
1510        // Stream this worktree's entries.
1511        let message = proto::UpdateWorktree {
1512            project_id: project_id.to_proto(),
1513            worktree_id,
1514            abs_path: worktree.abs_path.clone(),
1515            root_name: worktree.root_name,
1516            updated_entries: worktree.entries,
1517            removed_entries: Default::default(),
1518            scan_id: worktree.scan_id,
1519            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1520            updated_repositories: worktree.repository_entries.into_values().collect(),
1521            removed_repositories: Default::default(),
1522        };
1523        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1524            session.peer.send(session.connection_id, update.clone())?;
1525        }
1526
1527        // Stream this worktree's diagnostics.
1528        for summary in worktree.diagnostic_summaries {
1529            session.peer.send(
1530                session.connection_id,
1531                proto::UpdateDiagnosticSummary {
1532                    project_id: project_id.to_proto(),
1533                    worktree_id: worktree.id,
1534                    summary: Some(summary),
1535                },
1536            )?;
1537        }
1538
1539        for settings_file in worktree.settings_files {
1540            session.peer.send(
1541                session.connection_id,
1542                proto::UpdateWorktreeSettings {
1543                    project_id: project_id.to_proto(),
1544                    worktree_id: worktree.id,
1545                    path: settings_file.path,
1546                    content: Some(settings_file.content),
1547                },
1548            )?;
1549        }
1550    }
1551
1552    for language_server in &project.language_servers {
1553        session.peer.send(
1554            session.connection_id,
1555            proto::UpdateLanguageServer {
1556                project_id: project_id.to_proto(),
1557                language_server_id: language_server.id,
1558                variant: Some(
1559                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1560                        proto::LspDiskBasedDiagnosticsUpdated {},
1561                    ),
1562                ),
1563            },
1564        )?;
1565    }
1566
1567    Ok(())
1568}
1569
1570async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1571    let sender_id = session.connection_id;
1572    let project_id = ProjectId::from_proto(request.project_id);
1573
1574    let (room, project) = &*session
1575        .db()
1576        .await
1577        .leave_project(project_id, sender_id)
1578        .await?;
1579    tracing::info!(
1580        %project_id,
1581        host_user_id = %project.host_user_id,
1582        host_connection_id = %project.host_connection_id,
1583        "leave project"
1584    );
1585
1586    project_left(&project, &session);
1587    room_updated(&room, &session.peer);
1588
1589    Ok(())
1590}
1591
1592async fn update_project(
1593    request: proto::UpdateProject,
1594    response: Response<proto::UpdateProject>,
1595    session: Session,
1596) -> Result<()> {
1597    let project_id = ProjectId::from_proto(request.project_id);
1598    let (room, guest_connection_ids) = &*session
1599        .db()
1600        .await
1601        .update_project(project_id, session.connection_id, &request.worktrees)
1602        .await?;
1603    broadcast(
1604        Some(session.connection_id),
1605        guest_connection_ids.iter().copied(),
1606        |connection_id| {
1607            session
1608                .peer
1609                .forward_send(session.connection_id, connection_id, request.clone())
1610        },
1611    );
1612    room_updated(&room, &session.peer);
1613    response.send(proto::Ack {})?;
1614
1615    Ok(())
1616}
1617
1618async fn update_worktree(
1619    request: proto::UpdateWorktree,
1620    response: Response<proto::UpdateWorktree>,
1621    session: Session,
1622) -> Result<()> {
1623    let guest_connection_ids = session
1624        .db()
1625        .await
1626        .update_worktree(&request, session.connection_id)
1627        .await?;
1628
1629    broadcast(
1630        Some(session.connection_id),
1631        guest_connection_ids.iter().copied(),
1632        |connection_id| {
1633            session
1634                .peer
1635                .forward_send(session.connection_id, connection_id, request.clone())
1636        },
1637    );
1638    response.send(proto::Ack {})?;
1639    Ok(())
1640}
1641
1642async fn update_diagnostic_summary(
1643    message: proto::UpdateDiagnosticSummary,
1644    session: Session,
1645) -> Result<()> {
1646    let guest_connection_ids = session
1647        .db()
1648        .await
1649        .update_diagnostic_summary(&message, session.connection_id)
1650        .await?;
1651
1652    broadcast(
1653        Some(session.connection_id),
1654        guest_connection_ids.iter().copied(),
1655        |connection_id| {
1656            session
1657                .peer
1658                .forward_send(session.connection_id, connection_id, message.clone())
1659        },
1660    );
1661
1662    Ok(())
1663}
1664
1665async fn update_worktree_settings(
1666    message: proto::UpdateWorktreeSettings,
1667    session: Session,
1668) -> Result<()> {
1669    let guest_connection_ids = session
1670        .db()
1671        .await
1672        .update_worktree_settings(&message, session.connection_id)
1673        .await?;
1674
1675    broadcast(
1676        Some(session.connection_id),
1677        guest_connection_ids.iter().copied(),
1678        |connection_id| {
1679            session
1680                .peer
1681                .forward_send(session.connection_id, connection_id, message.clone())
1682        },
1683    );
1684
1685    Ok(())
1686}
1687
1688async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1689    broadcast_project_message(request.project_id, request, session).await
1690}
1691
1692async fn start_language_server(
1693    request: proto::StartLanguageServer,
1694    session: Session,
1695) -> Result<()> {
1696    let guest_connection_ids = session
1697        .db()
1698        .await
1699        .start_language_server(&request, session.connection_id)
1700        .await?;
1701
1702    broadcast(
1703        Some(session.connection_id),
1704        guest_connection_ids.iter().copied(),
1705        |connection_id| {
1706            session
1707                .peer
1708                .forward_send(session.connection_id, connection_id, request.clone())
1709        },
1710    );
1711    Ok(())
1712}
1713
1714async fn update_language_server(
1715    request: proto::UpdateLanguageServer,
1716    session: Session,
1717) -> Result<()> {
1718    session.executor.record_backtrace();
1719    let project_id = ProjectId::from_proto(request.project_id);
1720    let project_connection_ids = session
1721        .db()
1722        .await
1723        .project_connection_ids(project_id, session.connection_id)
1724        .await?;
1725    broadcast(
1726        Some(session.connection_id),
1727        project_connection_ids.iter().copied(),
1728        |connection_id| {
1729            session
1730                .peer
1731                .forward_send(session.connection_id, connection_id, request.clone())
1732        },
1733    );
1734    Ok(())
1735}
1736
1737async fn forward_project_request<T>(
1738    request: T,
1739    response: Response<T>,
1740    session: Session,
1741) -> Result<()>
1742where
1743    T: EntityMessage + RequestMessage,
1744{
1745    session.executor.record_backtrace();
1746    let project_id = ProjectId::from_proto(request.remote_entity_id());
1747    let host_connection_id = {
1748        let collaborators = session
1749            .db()
1750            .await
1751            .project_collaborators(project_id, session.connection_id)
1752            .await?;
1753        collaborators
1754            .iter()
1755            .find(|collaborator| collaborator.is_host)
1756            .ok_or_else(|| anyhow!("host not found"))?
1757            .connection_id
1758    };
1759
1760    let payload = session
1761        .peer
1762        .forward_request(session.connection_id, host_connection_id, request)
1763        .await?;
1764
1765    response.send(payload)?;
1766    Ok(())
1767}
1768
1769async fn create_buffer_for_peer(
1770    request: proto::CreateBufferForPeer,
1771    session: Session,
1772) -> Result<()> {
1773    session.executor.record_backtrace();
1774    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1775    session
1776        .peer
1777        .forward_send(session.connection_id, peer_id.into(), request)?;
1778    Ok(())
1779}
1780
1781async fn update_buffer(
1782    request: proto::UpdateBuffer,
1783    response: Response<proto::UpdateBuffer>,
1784    session: Session,
1785) -> Result<()> {
1786    session.executor.record_backtrace();
1787    let project_id = ProjectId::from_proto(request.project_id);
1788    let mut guest_connection_ids;
1789    let mut host_connection_id = None;
1790    {
1791        let collaborators = session
1792            .db()
1793            .await
1794            .project_collaborators(project_id, session.connection_id)
1795            .await?;
1796        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1797        for collaborator in collaborators.iter() {
1798            if collaborator.is_host {
1799                host_connection_id = Some(collaborator.connection_id);
1800            } else {
1801                guest_connection_ids.push(collaborator.connection_id);
1802            }
1803        }
1804    }
1805    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1806
1807    session.executor.record_backtrace();
1808    broadcast(
1809        Some(session.connection_id),
1810        guest_connection_ids,
1811        |connection_id| {
1812            session
1813                .peer
1814                .forward_send(session.connection_id, connection_id, request.clone())
1815        },
1816    );
1817    if host_connection_id != session.connection_id {
1818        session
1819            .peer
1820            .forward_request(session.connection_id, host_connection_id, request.clone())
1821            .await?;
1822    }
1823
1824    response.send(proto::Ack {})?;
1825    Ok(())
1826}
1827
1828async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1829    let project_id = ProjectId::from_proto(request.project_id);
1830    let project_connection_ids = session
1831        .db()
1832        .await
1833        .project_connection_ids(project_id, session.connection_id)
1834        .await?;
1835
1836    broadcast(
1837        Some(session.connection_id),
1838        project_connection_ids.iter().copied(),
1839        |connection_id| {
1840            session
1841                .peer
1842                .forward_send(session.connection_id, connection_id, request.clone())
1843        },
1844    );
1845    Ok(())
1846}
1847
1848async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1849    let project_id = ProjectId::from_proto(request.project_id);
1850    let project_connection_ids = session
1851        .db()
1852        .await
1853        .project_connection_ids(project_id, session.connection_id)
1854        .await?;
1855    broadcast(
1856        Some(session.connection_id),
1857        project_connection_ids.iter().copied(),
1858        |connection_id| {
1859            session
1860                .peer
1861                .forward_send(session.connection_id, connection_id, request.clone())
1862        },
1863    );
1864    Ok(())
1865}
1866
1867async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1868    broadcast_project_message(request.project_id, request, session).await
1869}
1870
1871async fn broadcast_project_message<T: EnvelopedMessage>(
1872    project_id: u64,
1873    request: T,
1874    session: Session,
1875) -> Result<()> {
1876    let project_id = ProjectId::from_proto(project_id);
1877    let project_connection_ids = session
1878        .db()
1879        .await
1880        .project_connection_ids(project_id, session.connection_id)
1881        .await?;
1882    broadcast(
1883        Some(session.connection_id),
1884        project_connection_ids.iter().copied(),
1885        |connection_id| {
1886            session
1887                .peer
1888                .forward_send(session.connection_id, connection_id, request.clone())
1889        },
1890    );
1891    Ok(())
1892}
1893
1894async fn follow(
1895    request: proto::Follow,
1896    response: Response<proto::Follow>,
1897    session: Session,
1898) -> Result<()> {
1899    let room_id = RoomId::from_proto(request.room_id);
1900    let project_id = request.project_id.map(ProjectId::from_proto);
1901    let leader_id = request
1902        .leader_id
1903        .ok_or_else(|| anyhow!("invalid leader id"))?
1904        .into();
1905    let follower_id = session.connection_id;
1906
1907    session
1908        .db()
1909        .await
1910        .check_room_participants(room_id, leader_id, session.connection_id)
1911        .await?;
1912
1913    let response_payload = session
1914        .peer
1915        .forward_request(session.connection_id, leader_id, request)
1916        .await?;
1917    response.send(response_payload)?;
1918
1919    if let Some(project_id) = project_id {
1920        let room = session
1921            .db()
1922            .await
1923            .follow(room_id, project_id, leader_id, follower_id)
1924            .await?;
1925        room_updated(&room, &session.peer);
1926    }
1927
1928    Ok(())
1929}
1930
1931async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1932    let room_id = RoomId::from_proto(request.room_id);
1933    let project_id = request.project_id.map(ProjectId::from_proto);
1934    let leader_id = request
1935        .leader_id
1936        .ok_or_else(|| anyhow!("invalid leader id"))?
1937        .into();
1938    let follower_id = session.connection_id;
1939
1940    session
1941        .db()
1942        .await
1943        .check_room_participants(room_id, leader_id, session.connection_id)
1944        .await?;
1945
1946    session
1947        .peer
1948        .forward_send(session.connection_id, leader_id, request)?;
1949
1950    if let Some(project_id) = project_id {
1951        let room = session
1952            .db()
1953            .await
1954            .unfollow(room_id, project_id, leader_id, follower_id)
1955            .await?;
1956        room_updated(&room, &session.peer);
1957    }
1958
1959    Ok(())
1960}
1961
1962async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1963    let room_id = RoomId::from_proto(request.room_id);
1964    let database = session.db.lock().await;
1965
1966    let connection_ids = if let Some(project_id) = request.project_id {
1967        let project_id = ProjectId::from_proto(project_id);
1968        database
1969            .project_connection_ids(project_id, session.connection_id)
1970            .await?
1971    } else {
1972        database
1973            .room_connection_ids(room_id, session.connection_id)
1974            .await?
1975    };
1976
1977    // For now, don't send view update messages back to that view's current leader.
1978    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1979        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1980        _ => None,
1981    });
1982
1983    for follower_peer_id in request.follower_ids.iter().copied() {
1984        let follower_connection_id = follower_peer_id.into();
1985        if Some(follower_peer_id) != connection_id_to_omit
1986            && connection_ids.contains(&follower_connection_id)
1987        {
1988            session.peer.forward_send(
1989                session.connection_id,
1990                follower_connection_id,
1991                request.clone(),
1992            )?;
1993        }
1994    }
1995    Ok(())
1996}
1997
1998async fn get_users(
1999    request: proto::GetUsers,
2000    response: Response<proto::GetUsers>,
2001    session: Session,
2002) -> Result<()> {
2003    let user_ids = request
2004        .user_ids
2005        .into_iter()
2006        .map(UserId::from_proto)
2007        .collect();
2008    let users = session
2009        .db()
2010        .await
2011        .get_users_by_ids(user_ids)
2012        .await?
2013        .into_iter()
2014        .map(|user| proto::User {
2015            id: user.id.to_proto(),
2016            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2017            github_login: user.github_login,
2018        })
2019        .collect();
2020    response.send(proto::UsersResponse { users })?;
2021    Ok(())
2022}
2023
2024async fn fuzzy_search_users(
2025    request: proto::FuzzySearchUsers,
2026    response: Response<proto::FuzzySearchUsers>,
2027    session: Session,
2028) -> Result<()> {
2029    let query = request.query;
2030    let users = match query.len() {
2031        0 => vec![],
2032        1 | 2 => session
2033            .db()
2034            .await
2035            .get_user_by_github_login(&query)
2036            .await?
2037            .into_iter()
2038            .collect(),
2039        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2040    };
2041    let users = users
2042        .into_iter()
2043        .filter(|user| user.id != session.user_id)
2044        .map(|user| proto::User {
2045            id: user.id.to_proto(),
2046            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2047            github_login: user.github_login,
2048        })
2049        .collect();
2050    response.send(proto::UsersResponse { users })?;
2051    Ok(())
2052}
2053
2054async fn request_contact(
2055    request: proto::RequestContact,
2056    response: Response<proto::RequestContact>,
2057    session: Session,
2058) -> Result<()> {
2059    let requester_id = session.user_id;
2060    let responder_id = UserId::from_proto(request.responder_id);
2061    if requester_id == responder_id {
2062        return Err(anyhow!("cannot add yourself as a contact"))?;
2063    }
2064
2065    session
2066        .db()
2067        .await
2068        .send_contact_request(requester_id, responder_id)
2069        .await?;
2070
2071    // Update outgoing contact requests of requester
2072    let mut update = proto::UpdateContacts::default();
2073    update.outgoing_requests.push(responder_id.to_proto());
2074    for connection_id in session
2075        .connection_pool()
2076        .await
2077        .user_connection_ids(requester_id)
2078    {
2079        session.peer.send(connection_id, update.clone())?;
2080    }
2081
2082    // Update incoming contact requests of responder
2083    let mut update = proto::UpdateContacts::default();
2084    update
2085        .incoming_requests
2086        .push(proto::IncomingContactRequest {
2087            requester_id: requester_id.to_proto(),
2088            should_notify: true,
2089        });
2090    for connection_id in session
2091        .connection_pool()
2092        .await
2093        .user_connection_ids(responder_id)
2094    {
2095        session.peer.send(connection_id, update.clone())?;
2096    }
2097
2098    response.send(proto::Ack {})?;
2099    Ok(())
2100}
2101
2102async fn respond_to_contact_request(
2103    request: proto::RespondToContactRequest,
2104    response: Response<proto::RespondToContactRequest>,
2105    session: Session,
2106) -> Result<()> {
2107    let responder_id = session.user_id;
2108    let requester_id = UserId::from_proto(request.requester_id);
2109    let db = session.db().await;
2110    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2111        db.dismiss_contact_notification(responder_id, requester_id)
2112            .await?;
2113    } else {
2114        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2115
2116        db.respond_to_contact_request(responder_id, requester_id, accept)
2117            .await?;
2118        let requester_busy = db.is_user_busy(requester_id).await?;
2119        let responder_busy = db.is_user_busy(responder_id).await?;
2120
2121        let pool = session.connection_pool().await;
2122        // Update responder with new contact
2123        let mut update = proto::UpdateContacts::default();
2124        if accept {
2125            update
2126                .contacts
2127                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2128        }
2129        update
2130            .remove_incoming_requests
2131            .push(requester_id.to_proto());
2132        for connection_id in pool.user_connection_ids(responder_id) {
2133            session.peer.send(connection_id, update.clone())?;
2134        }
2135
2136        // Update requester with new contact
2137        let mut update = proto::UpdateContacts::default();
2138        if accept {
2139            update
2140                .contacts
2141                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2142        }
2143        update
2144            .remove_outgoing_requests
2145            .push(responder_id.to_proto());
2146        for connection_id in pool.user_connection_ids(requester_id) {
2147            session.peer.send(connection_id, update.clone())?;
2148        }
2149    }
2150
2151    response.send(proto::Ack {})?;
2152    Ok(())
2153}
2154
2155async fn remove_contact(
2156    request: proto::RemoveContact,
2157    response: Response<proto::RemoveContact>,
2158    session: Session,
2159) -> Result<()> {
2160    let requester_id = session.user_id;
2161    let responder_id = UserId::from_proto(request.user_id);
2162    let db = session.db().await;
2163    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2164
2165    let pool = session.connection_pool().await;
2166    // Update outgoing contact requests of requester
2167    let mut update = proto::UpdateContacts::default();
2168    if contact_accepted {
2169        update.remove_contacts.push(responder_id.to_proto());
2170    } else {
2171        update
2172            .remove_outgoing_requests
2173            .push(responder_id.to_proto());
2174    }
2175    for connection_id in pool.user_connection_ids(requester_id) {
2176        session.peer.send(connection_id, update.clone())?;
2177    }
2178
2179    // Update incoming contact requests of responder
2180    let mut update = proto::UpdateContacts::default();
2181    if contact_accepted {
2182        update.remove_contacts.push(requester_id.to_proto());
2183    } else {
2184        update
2185            .remove_incoming_requests
2186            .push(requester_id.to_proto());
2187    }
2188    for connection_id in pool.user_connection_ids(responder_id) {
2189        session.peer.send(connection_id, update.clone())?;
2190    }
2191
2192    response.send(proto::Ack {})?;
2193    Ok(())
2194}
2195
2196async fn create_channel(
2197    request: proto::CreateChannel,
2198    response: Response<proto::CreateChannel>,
2199    session: Session,
2200) -> Result<()> {
2201    let db = session.db().await;
2202
2203    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2204    let id = db
2205        .create_channel(&request.name, parent_id, session.user_id)
2206        .await?;
2207
2208    let channel = proto::Channel {
2209        id: id.to_proto(),
2210        name: request.name,
2211        visibility: proto::ChannelVisibility::Members as i32,
2212    };
2213
2214    response.send(proto::CreateChannelResponse {
2215        channel: Some(channel.clone()),
2216        parent_id: request.parent_id,
2217    })?;
2218
2219    let Some(parent_id) = parent_id else {
2220        return Ok(());
2221    };
2222
2223    let update = proto::UpdateChannels {
2224        channels: vec![channel],
2225        insert_edge: vec![ChannelEdge {
2226            parent_id: parent_id.to_proto(),
2227            channel_id: id.to_proto(),
2228        }],
2229        ..Default::default()
2230    };
2231
2232    let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2233
2234    let connection_pool = session.connection_pool().await;
2235    for user_id in user_ids_to_notify {
2236        for connection_id in connection_pool.user_connection_ids(user_id) {
2237            if user_id == session.user_id {
2238                continue;
2239            }
2240            session.peer.send(connection_id, update.clone())?;
2241        }
2242    }
2243
2244    Ok(())
2245}
2246
2247async fn delete_channel(
2248    request: proto::DeleteChannel,
2249    response: Response<proto::DeleteChannel>,
2250    session: Session,
2251) -> Result<()> {
2252    let db = session.db().await;
2253
2254    let channel_id = request.channel_id;
2255    let (removed_channels, member_ids) = db
2256        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2257        .await?;
2258    response.send(proto::Ack {})?;
2259
2260    // Notify members of removed channels
2261    let mut update = proto::UpdateChannels::default();
2262    update
2263        .delete_channels
2264        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2265
2266    let connection_pool = session.connection_pool().await;
2267    for member_id in member_ids {
2268        for connection_id in connection_pool.user_connection_ids(member_id) {
2269            session.peer.send(connection_id, update.clone())?;
2270        }
2271    }
2272
2273    Ok(())
2274}
2275
2276async fn invite_channel_member(
2277    request: proto::InviteChannelMember,
2278    response: Response<proto::InviteChannelMember>,
2279    session: Session,
2280) -> Result<()> {
2281    let db = session.db().await;
2282    let channel_id = ChannelId::from_proto(request.channel_id);
2283    let invitee_id = UserId::from_proto(request.user_id);
2284    db.invite_channel_member(
2285        channel_id,
2286        invitee_id,
2287        session.user_id,
2288        request.role().into(),
2289    )
2290    .await?;
2291
2292    let (channel, _) = db
2293        .get_channel(channel_id, session.user_id)
2294        .await?
2295        .ok_or_else(|| anyhow!("channel not found"))?;
2296
2297    let mut update = proto::UpdateChannels::default();
2298    update.channel_invitations.push(proto::Channel {
2299        id: channel.id.to_proto(),
2300        visibility: channel.visibility.into(),
2301        name: channel.name,
2302    });
2303    for connection_id in session
2304        .connection_pool()
2305        .await
2306        .user_connection_ids(invitee_id)
2307    {
2308        session.peer.send(connection_id, update.clone())?;
2309    }
2310
2311    response.send(proto::Ack {})?;
2312    Ok(())
2313}
2314
2315async fn remove_channel_member(
2316    request: proto::RemoveChannelMember,
2317    response: Response<proto::RemoveChannelMember>,
2318    session: Session,
2319) -> Result<()> {
2320    let db = session.db().await;
2321    let channel_id = ChannelId::from_proto(request.channel_id);
2322    let member_id = UserId::from_proto(request.user_id);
2323
2324    db.remove_channel_member(channel_id, member_id, session.user_id)
2325        .await?;
2326
2327    let mut update = proto::UpdateChannels::default();
2328    update.delete_channels.push(channel_id.to_proto());
2329
2330    for connection_id in session
2331        .connection_pool()
2332        .await
2333        .user_connection_ids(member_id)
2334    {
2335        session.peer.send(connection_id, update.clone())?;
2336    }
2337
2338    response.send(proto::Ack {})?;
2339    Ok(())
2340}
2341
2342async fn set_channel_visibility(
2343    request: proto::SetChannelVisibility,
2344    response: Response<proto::SetChannelVisibility>,
2345    session: Session,
2346) -> Result<()> {
2347    let db = session.db().await;
2348    let channel_id = ChannelId::from_proto(request.channel_id);
2349    let visibility = request.visibility().into();
2350
2351    let channel = db
2352        .set_channel_visibility(channel_id, visibility, session.user_id)
2353        .await?;
2354
2355    let mut update = proto::UpdateChannels::default();
2356    update.channels.push(proto::Channel {
2357        id: channel.id.to_proto(),
2358        name: channel.name,
2359        visibility: channel.visibility.into(),
2360    });
2361
2362    let member_ids = db.get_channel_members(channel_id).await?;
2363
2364    let connection_pool = session.connection_pool().await;
2365    for member_id in member_ids {
2366        for connection_id in connection_pool.user_connection_ids(member_id) {
2367            session.peer.send(connection_id, update.clone())?;
2368        }
2369    }
2370
2371    response.send(proto::Ack {})?;
2372    Ok(())
2373}
2374
2375async fn set_channel_member_role(
2376    request: proto::SetChannelMemberRole,
2377    response: Response<proto::SetChannelMemberRole>,
2378    session: Session,
2379) -> Result<()> {
2380    let db = session.db().await;
2381    let channel_id = ChannelId::from_proto(request.channel_id);
2382    let member_id = UserId::from_proto(request.user_id);
2383    db.set_channel_member_role(
2384        channel_id,
2385        session.user_id,
2386        member_id,
2387        request.role().into(),
2388    )
2389    .await?;
2390
2391    let (channel, has_accepted) = db
2392        .get_channel(channel_id, member_id)
2393        .await?
2394        .ok_or_else(|| anyhow!("channel not found"))?;
2395
2396    let mut update = proto::UpdateChannels::default();
2397    if has_accepted {
2398        update.channel_permissions.push(proto::ChannelPermission {
2399            channel_id: channel.id.to_proto(),
2400            role: request.role,
2401        });
2402    }
2403
2404    for connection_id in session
2405        .connection_pool()
2406        .await
2407        .user_connection_ids(member_id)
2408    {
2409        session.peer.send(connection_id, update.clone())?;
2410    }
2411
2412    response.send(proto::Ack {})?;
2413    Ok(())
2414}
2415
2416async fn rename_channel(
2417    request: proto::RenameChannel,
2418    response: Response<proto::RenameChannel>,
2419    session: Session,
2420) -> Result<()> {
2421    let db = session.db().await;
2422    let channel_id = ChannelId::from_proto(request.channel_id);
2423    let channel = db
2424        .rename_channel(channel_id, session.user_id, &request.name)
2425        .await?;
2426
2427    let channel = proto::Channel {
2428        id: channel.id.to_proto(),
2429        name: channel.name,
2430        visibility: channel.visibility.into(),
2431    };
2432    response.send(proto::RenameChannelResponse {
2433        channel: Some(channel.clone()),
2434    })?;
2435    let mut update = proto::UpdateChannels::default();
2436    update.channels.push(channel);
2437
2438    let member_ids = db.get_channel_members(channel_id).await?;
2439
2440    let connection_pool = session.connection_pool().await;
2441    for member_id in member_ids {
2442        for connection_id in connection_pool.user_connection_ids(member_id) {
2443            session.peer.send(connection_id, update.clone())?;
2444        }
2445    }
2446
2447    Ok(())
2448}
2449
2450async fn link_channel(
2451    request: proto::LinkChannel,
2452    response: Response<proto::LinkChannel>,
2453    session: Session,
2454) -> Result<()> {
2455    let db = session.db().await;
2456    let channel_id = ChannelId::from_proto(request.channel_id);
2457    let to = ChannelId::from_proto(request.to);
2458    let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2459
2460    let members = db.get_channel_members(to).await?;
2461    let connection_pool = session.connection_pool().await;
2462    let update = proto::UpdateChannels {
2463        channels: channels_to_send
2464            .channels
2465            .into_iter()
2466            .map(|channel| proto::Channel {
2467                id: channel.id.to_proto(),
2468                visibility: channel.visibility.into(),
2469                name: channel.name,
2470            })
2471            .collect(),
2472        insert_edge: channels_to_send.edges,
2473        ..Default::default()
2474    };
2475    for member_id in members {
2476        for connection_id in connection_pool.user_connection_ids(member_id) {
2477            session.peer.send(connection_id, update.clone())?;
2478        }
2479    }
2480
2481    response.send(Ack {})?;
2482
2483    Ok(())
2484}
2485
2486async fn unlink_channel(
2487    request: proto::UnlinkChannel,
2488    response: Response<proto::UnlinkChannel>,
2489    session: Session,
2490) -> Result<()> {
2491    let db = session.db().await;
2492    let channel_id = ChannelId::from_proto(request.channel_id);
2493    let from = ChannelId::from_proto(request.from);
2494
2495    db.unlink_channel(session.user_id, channel_id, from).await?;
2496
2497    let members = db.get_channel_members(from).await?;
2498
2499    let update = proto::UpdateChannels {
2500        delete_edge: vec![proto::ChannelEdge {
2501            channel_id: channel_id.to_proto(),
2502            parent_id: from.to_proto(),
2503        }],
2504        ..Default::default()
2505    };
2506    let connection_pool = session.connection_pool().await;
2507    for member_id in members {
2508        for connection_id in connection_pool.user_connection_ids(member_id) {
2509            session.peer.send(connection_id, update.clone())?;
2510        }
2511    }
2512
2513    response.send(Ack {})?;
2514
2515    Ok(())
2516}
2517
2518async fn move_channel(
2519    request: proto::MoveChannel,
2520    response: Response<proto::MoveChannel>,
2521    session: Session,
2522) -> Result<()> {
2523    let db = session.db().await;
2524    let channel_id = ChannelId::from_proto(request.channel_id);
2525    let from_parent = ChannelId::from_proto(request.from);
2526    let to = ChannelId::from_proto(request.to);
2527
2528    let channels_to_send = db
2529        .move_channel(session.user_id, channel_id, from_parent, to)
2530        .await?;
2531
2532    if channels_to_send.is_empty() {
2533        response.send(Ack {})?;
2534        return Ok(());
2535    }
2536
2537    let members_from = db.get_channel_members(from_parent).await?;
2538    let members_to = db.get_channel_members(to).await?;
2539
2540    let update = proto::UpdateChannels {
2541        delete_edge: vec![proto::ChannelEdge {
2542            channel_id: channel_id.to_proto(),
2543            parent_id: from_parent.to_proto(),
2544        }],
2545        ..Default::default()
2546    };
2547    let connection_pool = session.connection_pool().await;
2548    for member_id in members_from {
2549        for connection_id in connection_pool.user_connection_ids(member_id) {
2550            session.peer.send(connection_id, update.clone())?;
2551        }
2552    }
2553
2554    let update = proto::UpdateChannels {
2555        channels: channels_to_send
2556            .channels
2557            .into_iter()
2558            .map(|channel| proto::Channel {
2559                id: channel.id.to_proto(),
2560                visibility: channel.visibility.into(),
2561                name: channel.name,
2562            })
2563            .collect(),
2564        insert_edge: channels_to_send.edges,
2565        ..Default::default()
2566    };
2567    for member_id in members_to {
2568        for connection_id in connection_pool.user_connection_ids(member_id) {
2569            session.peer.send(connection_id, update.clone())?;
2570        }
2571    }
2572
2573    response.send(Ack {})?;
2574
2575    Ok(())
2576}
2577
2578async fn get_channel_members(
2579    request: proto::GetChannelMembers,
2580    response: Response<proto::GetChannelMembers>,
2581    session: Session,
2582) -> Result<()> {
2583    let db = session.db().await;
2584    let channel_id = ChannelId::from_proto(request.channel_id);
2585    let members = db
2586        .get_channel_participant_details(channel_id, session.user_id)
2587        .await?;
2588    response.send(proto::GetChannelMembersResponse { members })?;
2589    Ok(())
2590}
2591
2592async fn respond_to_channel_invite(
2593    request: proto::RespondToChannelInvite,
2594    response: Response<proto::RespondToChannelInvite>,
2595    session: Session,
2596) -> Result<()> {
2597    let db = session.db().await;
2598    let channel_id = ChannelId::from_proto(request.channel_id);
2599    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2600        .await?;
2601
2602    if request.accept {
2603        channel_membership_updated(db, channel_id, &session).await?;
2604    } else {
2605        let mut update = proto::UpdateChannels::default();
2606        update
2607            .remove_channel_invitations
2608            .push(channel_id.to_proto());
2609        session.peer.send(session.connection_id, update)?;
2610    }
2611    response.send(proto::Ack {})?;
2612
2613    Ok(())
2614}
2615
2616async fn channel_membership_updated(
2617    db: tokio::sync::MutexGuard<'_, DbHandle>,
2618    channel_id: ChannelId,
2619    session: &Session,
2620) -> Result<(), crate::Error> {
2621    let mut update = proto::UpdateChannels::default();
2622    update
2623        .remove_channel_invitations
2624        .push(channel_id.to_proto());
2625
2626    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2627    update.channels.extend(
2628        result
2629            .channels
2630            .channels
2631            .into_iter()
2632            .map(|channel| proto::Channel {
2633                id: channel.id.to_proto(),
2634                visibility: channel.visibility.into(),
2635                name: channel.name,
2636            }),
2637    );
2638    update.unseen_channel_messages = result.channel_messages;
2639    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2640    update.insert_edge = result.channels.edges;
2641    update
2642        .channel_participants
2643        .extend(
2644            result
2645                .channel_participants
2646                .into_iter()
2647                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2648                    channel_id: channel_id.to_proto(),
2649                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2650                }),
2651        );
2652    update
2653        .channel_permissions
2654        .extend(
2655            result
2656                .channels_with_admin_privileges
2657                .into_iter()
2658                .map(|channel_id| proto::ChannelPermission {
2659                    channel_id: channel_id.to_proto(),
2660                    role: proto::ChannelRole::Admin.into(),
2661                }),
2662        );
2663    session.peer.send(session.connection_id, update)?;
2664    Ok(())
2665}
2666
2667async fn join_channel(
2668    request: proto::JoinChannel,
2669    response: Response<proto::JoinChannel>,
2670    session: Session,
2671) -> Result<()> {
2672    let channel_id = ChannelId::from_proto(request.channel_id);
2673    join_channel_internal(channel_id, Box::new(response), session).await
2674}
2675
2676trait JoinChannelInternalResponse {
2677    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2678}
2679impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2680    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2681        Response::<proto::JoinChannel>::send(self, result)
2682    }
2683}
2684impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2685    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2686        Response::<proto::JoinRoom>::send(self, result)
2687    }
2688}
2689
2690async fn join_channel_internal(
2691    channel_id: ChannelId,
2692    response: Box<impl JoinChannelInternalResponse>,
2693    session: Session,
2694) -> Result<()> {
2695    let joined_room = {
2696        leave_room_for_session(&session).await?;
2697        let db = session.db().await;
2698
2699        let (joined_room, joined_channel) = db
2700            .join_channel(
2701                channel_id,
2702                session.user_id,
2703                session.connection_id,
2704                RELEASE_CHANNEL_NAME.as_str(),
2705            )
2706            .await?;
2707
2708        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2709            let token = live_kit
2710                .room_token(
2711                    &joined_room.room.live_kit_room,
2712                    &session.user_id.to_string(),
2713                )
2714                .trace_err()?;
2715
2716            Some(LiveKitConnectionInfo {
2717                server_url: live_kit.url().into(),
2718                token,
2719            })
2720        });
2721
2722        response.send(proto::JoinRoomResponse {
2723            room: Some(joined_room.room.clone()),
2724            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2725            live_kit_connection_info,
2726        })?;
2727
2728        if joined_channel {
2729            channel_membership_updated(db, channel_id, &session).await?
2730        }
2731
2732        room_updated(&joined_room.room, &session.peer);
2733
2734        joined_room
2735    };
2736
2737    channel_updated(
2738        channel_id,
2739        &joined_room.room,
2740        &joined_room.channel_members,
2741        &session.peer,
2742        &*session.connection_pool().await,
2743    );
2744
2745    update_user_contacts(session.user_id, &session).await?;
2746    Ok(())
2747}
2748
2749async fn join_channel_buffer(
2750    request: proto::JoinChannelBuffer,
2751    response: Response<proto::JoinChannelBuffer>,
2752    session: Session,
2753) -> Result<()> {
2754    let db = session.db().await;
2755    let channel_id = ChannelId::from_proto(request.channel_id);
2756
2757    let open_response = db
2758        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2759        .await?;
2760
2761    let collaborators = open_response.collaborators.clone();
2762    response.send(open_response)?;
2763
2764    let update = UpdateChannelBufferCollaborators {
2765        channel_id: channel_id.to_proto(),
2766        collaborators: collaborators.clone(),
2767    };
2768    channel_buffer_updated(
2769        session.connection_id,
2770        collaborators
2771            .iter()
2772            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2773        &update,
2774        &session.peer,
2775    );
2776
2777    Ok(())
2778}
2779
2780async fn update_channel_buffer(
2781    request: proto::UpdateChannelBuffer,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785    let channel_id = ChannelId::from_proto(request.channel_id);
2786
2787    let (collaborators, non_collaborators, epoch, version) = db
2788        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2789        .await?;
2790
2791    channel_buffer_updated(
2792        session.connection_id,
2793        collaborators,
2794        &proto::UpdateChannelBuffer {
2795            channel_id: channel_id.to_proto(),
2796            operations: request.operations,
2797        },
2798        &session.peer,
2799    );
2800
2801    let pool = &*session.connection_pool().await;
2802
2803    broadcast(
2804        None,
2805        non_collaborators
2806            .iter()
2807            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2808        |peer_id| {
2809            session.peer.send(
2810                peer_id.into(),
2811                proto::UpdateChannels {
2812                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2813                        channel_id: channel_id.to_proto(),
2814                        epoch: epoch as u64,
2815                        version: version.clone(),
2816                    }],
2817                    ..Default::default()
2818                },
2819            )
2820        },
2821    );
2822
2823    Ok(())
2824}
2825
2826async fn rejoin_channel_buffers(
2827    request: proto::RejoinChannelBuffers,
2828    response: Response<proto::RejoinChannelBuffers>,
2829    session: Session,
2830) -> Result<()> {
2831    let db = session.db().await;
2832    let buffers = db
2833        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2834        .await?;
2835
2836    for rejoined_buffer in &buffers {
2837        let collaborators_to_notify = rejoined_buffer
2838            .buffer
2839            .collaborators
2840            .iter()
2841            .filter_map(|c| Some(c.peer_id?.into()));
2842        channel_buffer_updated(
2843            session.connection_id,
2844            collaborators_to_notify,
2845            &proto::UpdateChannelBufferCollaborators {
2846                channel_id: rejoined_buffer.buffer.channel_id,
2847                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2848            },
2849            &session.peer,
2850        );
2851    }
2852
2853    response.send(proto::RejoinChannelBuffersResponse {
2854        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2855    })?;
2856
2857    Ok(())
2858}
2859
2860async fn leave_channel_buffer(
2861    request: proto::LeaveChannelBuffer,
2862    response: Response<proto::LeaveChannelBuffer>,
2863    session: Session,
2864) -> Result<()> {
2865    let db = session.db().await;
2866    let channel_id = ChannelId::from_proto(request.channel_id);
2867
2868    let left_buffer = db
2869        .leave_channel_buffer(channel_id, session.connection_id)
2870        .await?;
2871
2872    response.send(Ack {})?;
2873
2874    channel_buffer_updated(
2875        session.connection_id,
2876        left_buffer.connections,
2877        &proto::UpdateChannelBufferCollaborators {
2878            channel_id: channel_id.to_proto(),
2879            collaborators: left_buffer.collaborators,
2880        },
2881        &session.peer,
2882    );
2883
2884    Ok(())
2885}
2886
2887fn channel_buffer_updated<T: EnvelopedMessage>(
2888    sender_id: ConnectionId,
2889    collaborators: impl IntoIterator<Item = ConnectionId>,
2890    message: &T,
2891    peer: &Peer,
2892) {
2893    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2894        peer.send(peer_id.into(), message.clone())
2895    });
2896}
2897
2898async fn send_channel_message(
2899    request: proto::SendChannelMessage,
2900    response: Response<proto::SendChannelMessage>,
2901    session: Session,
2902) -> Result<()> {
2903    // Validate the message body.
2904    let body = request.body.trim().to_string();
2905    if body.len() > MAX_MESSAGE_LEN {
2906        return Err(anyhow!("message is too long"))?;
2907    }
2908    if body.is_empty() {
2909        return Err(anyhow!("message can't be blank"))?;
2910    }
2911
2912    let timestamp = OffsetDateTime::now_utc();
2913    let nonce = request
2914        .nonce
2915        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2916
2917    let channel_id = ChannelId::from_proto(request.channel_id);
2918    let (message_id, connection_ids, non_participants) = session
2919        .db()
2920        .await
2921        .create_channel_message(
2922            channel_id,
2923            session.user_id,
2924            &body,
2925            timestamp,
2926            nonce.clone().into(),
2927        )
2928        .await?;
2929    let message = proto::ChannelMessage {
2930        sender_id: session.user_id.to_proto(),
2931        id: message_id.to_proto(),
2932        body,
2933        timestamp: timestamp.unix_timestamp() as u64,
2934        nonce: Some(nonce),
2935    };
2936    broadcast(Some(session.connection_id), connection_ids, |connection| {
2937        session.peer.send(
2938            connection,
2939            proto::ChannelMessageSent {
2940                channel_id: channel_id.to_proto(),
2941                message: Some(message.clone()),
2942            },
2943        )
2944    });
2945    response.send(proto::SendChannelMessageResponse {
2946        message: Some(message),
2947    })?;
2948
2949    let pool = &*session.connection_pool().await;
2950    broadcast(
2951        None,
2952        non_participants
2953            .iter()
2954            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2955        |peer_id| {
2956            session.peer.send(
2957                peer_id.into(),
2958                proto::UpdateChannels {
2959                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2960                        channel_id: channel_id.to_proto(),
2961                        message_id: message_id.to_proto(),
2962                    }],
2963                    ..Default::default()
2964                },
2965            )
2966        },
2967    );
2968
2969    Ok(())
2970}
2971
2972async fn remove_channel_message(
2973    request: proto::RemoveChannelMessage,
2974    response: Response<proto::RemoveChannelMessage>,
2975    session: Session,
2976) -> Result<()> {
2977    let channel_id = ChannelId::from_proto(request.channel_id);
2978    let message_id = MessageId::from_proto(request.message_id);
2979    let connection_ids = session
2980        .db()
2981        .await
2982        .remove_channel_message(channel_id, message_id, session.user_id)
2983        .await?;
2984    broadcast(Some(session.connection_id), connection_ids, |connection| {
2985        session.peer.send(connection, request.clone())
2986    });
2987    response.send(proto::Ack {})?;
2988    Ok(())
2989}
2990
2991async fn acknowledge_channel_message(
2992    request: proto::AckChannelMessage,
2993    session: Session,
2994) -> Result<()> {
2995    let channel_id = ChannelId::from_proto(request.channel_id);
2996    let message_id = MessageId::from_proto(request.message_id);
2997    session
2998        .db()
2999        .await
3000        .observe_channel_message(channel_id, session.user_id, message_id)
3001        .await?;
3002    Ok(())
3003}
3004
3005async fn acknowledge_buffer_version(
3006    request: proto::AckBufferOperation,
3007    session: Session,
3008) -> Result<()> {
3009    let buffer_id = BufferId::from_proto(request.buffer_id);
3010    session
3011        .db()
3012        .await
3013        .observe_buffer_version(
3014            buffer_id,
3015            session.user_id,
3016            request.epoch as i32,
3017            &request.version,
3018        )
3019        .await?;
3020    Ok(())
3021}
3022
3023async fn join_channel_chat(
3024    request: proto::JoinChannelChat,
3025    response: Response<proto::JoinChannelChat>,
3026    session: Session,
3027) -> Result<()> {
3028    let channel_id = ChannelId::from_proto(request.channel_id);
3029
3030    let db = session.db().await;
3031    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3032        .await?;
3033    let messages = db
3034        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3035        .await?;
3036    response.send(proto::JoinChannelChatResponse {
3037        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3038        messages,
3039    })?;
3040    Ok(())
3041}
3042
3043async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3044    let channel_id = ChannelId::from_proto(request.channel_id);
3045    session
3046        .db()
3047        .await
3048        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3049        .await?;
3050    Ok(())
3051}
3052
3053async fn get_channel_messages(
3054    request: proto::GetChannelMessages,
3055    response: Response<proto::GetChannelMessages>,
3056    session: Session,
3057) -> Result<()> {
3058    let channel_id = ChannelId::from_proto(request.channel_id);
3059    let messages = session
3060        .db()
3061        .await
3062        .get_channel_messages(
3063            channel_id,
3064            session.user_id,
3065            MESSAGE_COUNT_PER_PAGE,
3066            Some(MessageId::from_proto(request.before_message_id)),
3067        )
3068        .await?;
3069    response.send(proto::GetChannelMessagesResponse {
3070        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3071        messages,
3072    })?;
3073    Ok(())
3074}
3075
3076async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3077    let project_id = ProjectId::from_proto(request.project_id);
3078    let project_connection_ids = session
3079        .db()
3080        .await
3081        .project_connection_ids(project_id, session.connection_id)
3082        .await?;
3083    broadcast(
3084        Some(session.connection_id),
3085        project_connection_ids.iter().copied(),
3086        |connection_id| {
3087            session
3088                .peer
3089                .forward_send(session.connection_id, connection_id, request.clone())
3090        },
3091    );
3092    Ok(())
3093}
3094
3095async fn get_private_user_info(
3096    _request: proto::GetPrivateUserInfo,
3097    response: Response<proto::GetPrivateUserInfo>,
3098    session: Session,
3099) -> Result<()> {
3100    let db = session.db().await;
3101
3102    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3103    let user = db
3104        .get_user_by_id(session.user_id)
3105        .await?
3106        .ok_or_else(|| anyhow!("user not found"))?;
3107    let flags = db.get_user_flags(session.user_id).await?;
3108
3109    response.send(proto::GetPrivateUserInfoResponse {
3110        metrics_id,
3111        staff: user.admin,
3112        flags,
3113    })?;
3114    Ok(())
3115}
3116
3117fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3118    match message {
3119        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3120        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3121        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3122        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3123        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3124            code: frame.code.into(),
3125            reason: frame.reason,
3126        })),
3127    }
3128}
3129
3130fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3131    match message {
3132        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3133        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3134        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3135        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3136        AxumMessage::Close(frame) => {
3137            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3138                code: frame.code.into(),
3139                reason: frame.reason,
3140            }))
3141        }
3142    }
3143}
3144
3145fn build_initial_channels_update(
3146    channels: ChannelsForUser,
3147    channel_invites: Vec<db::Channel>,
3148) -> proto::UpdateChannels {
3149    let mut update = proto::UpdateChannels::default();
3150
3151    for channel in channels.channels.channels {
3152        update.channels.push(proto::Channel {
3153            id: channel.id.to_proto(),
3154            name: channel.name,
3155            visibility: channel.visibility.into(),
3156        });
3157    }
3158
3159    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3160    update.unseen_channel_messages = channels.channel_messages;
3161    update.insert_edge = channels.channels.edges;
3162
3163    for (channel_id, participants) in channels.channel_participants {
3164        update
3165            .channel_participants
3166            .push(proto::ChannelParticipants {
3167                channel_id: channel_id.to_proto(),
3168                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3169            });
3170    }
3171
3172    update
3173        .channel_permissions
3174        .extend(
3175            channels
3176                .channels_with_admin_privileges
3177                .into_iter()
3178                .map(|id| proto::ChannelPermission {
3179                    channel_id: id.to_proto(),
3180                    role: proto::ChannelRole::Admin.into(),
3181                }),
3182        );
3183
3184    for channel in channel_invites {
3185        update.channel_invitations.push(proto::Channel {
3186            id: channel.id.to_proto(),
3187            name: channel.name,
3188            // TODO: Visibility
3189            visibility: ChannelVisibility::Public.into(),
3190        });
3191    }
3192
3193    update
3194}
3195
3196fn build_initial_contacts_update(
3197    contacts: Vec<db::Contact>,
3198    pool: &ConnectionPool,
3199) -> proto::UpdateContacts {
3200    let mut update = proto::UpdateContacts::default();
3201
3202    for contact in contacts {
3203        match contact {
3204            db::Contact::Accepted {
3205                user_id,
3206                should_notify,
3207                busy,
3208            } => {
3209                update
3210                    .contacts
3211                    .push(contact_for_user(user_id, should_notify, busy, &pool));
3212            }
3213            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3214            db::Contact::Incoming {
3215                user_id,
3216                should_notify,
3217            } => update
3218                .incoming_requests
3219                .push(proto::IncomingContactRequest {
3220                    requester_id: user_id.to_proto(),
3221                    should_notify,
3222                }),
3223        }
3224    }
3225
3226    update
3227}
3228
3229fn contact_for_user(
3230    user_id: UserId,
3231    should_notify: bool,
3232    busy: bool,
3233    pool: &ConnectionPool,
3234) -> proto::Contact {
3235    proto::Contact {
3236        user_id: user_id.to_proto(),
3237        online: pool.is_user_online(user_id),
3238        busy,
3239        should_notify,
3240    }
3241}
3242
3243fn room_updated(room: &proto::Room, peer: &Peer) {
3244    broadcast(
3245        None,
3246        room.participants
3247            .iter()
3248            .filter_map(|participant| Some(participant.peer_id?.into())),
3249        |peer_id| {
3250            peer.send(
3251                peer_id.into(),
3252                proto::RoomUpdated {
3253                    room: Some(room.clone()),
3254                },
3255            )
3256        },
3257    );
3258}
3259
3260fn channel_updated(
3261    channel_id: ChannelId,
3262    room: &proto::Room,
3263    channel_members: &[UserId],
3264    peer: &Peer,
3265    pool: &ConnectionPool,
3266) {
3267    let participants = room
3268        .participants
3269        .iter()
3270        .map(|p| p.user_id)
3271        .collect::<Vec<_>>();
3272
3273    broadcast(
3274        None,
3275        channel_members
3276            .iter()
3277            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3278        |peer_id| {
3279            peer.send(
3280                peer_id.into(),
3281                proto::UpdateChannels {
3282                    channel_participants: vec![proto::ChannelParticipants {
3283                        channel_id: channel_id.to_proto(),
3284                        participant_user_ids: participants.clone(),
3285                    }],
3286                    ..Default::default()
3287                },
3288            )
3289        },
3290    );
3291}
3292
3293async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3294    let db = session.db().await;
3295
3296    let contacts = db.get_contacts(user_id).await?;
3297    let busy = db.is_user_busy(user_id).await?;
3298
3299    let pool = session.connection_pool().await;
3300    let updated_contact = contact_for_user(user_id, false, busy, &pool);
3301    for contact in contacts {
3302        if let db::Contact::Accepted {
3303            user_id: contact_user_id,
3304            ..
3305        } = contact
3306        {
3307            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3308                session
3309                    .peer
3310                    .send(
3311                        contact_conn_id,
3312                        proto::UpdateContacts {
3313                            contacts: vec![updated_contact.clone()],
3314                            remove_contacts: Default::default(),
3315                            incoming_requests: Default::default(),
3316                            remove_incoming_requests: Default::default(),
3317                            outgoing_requests: Default::default(),
3318                            remove_outgoing_requests: Default::default(),
3319                        },
3320                    )
3321                    .trace_err();
3322            }
3323        }
3324    }
3325    Ok(())
3326}
3327
3328async fn leave_room_for_session(session: &Session) -> Result<()> {
3329    let mut contacts_to_update = HashSet::default();
3330
3331    let room_id;
3332    let canceled_calls_to_user_ids;
3333    let live_kit_room;
3334    let delete_live_kit_room;
3335    let room;
3336    let channel_members;
3337    let channel_id;
3338
3339    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3340        contacts_to_update.insert(session.user_id);
3341
3342        for project in left_room.left_projects.values() {
3343            project_left(project, session);
3344        }
3345
3346        room_id = RoomId::from_proto(left_room.room.id);
3347        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3348        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3349        delete_live_kit_room = left_room.deleted;
3350        room = mem::take(&mut left_room.room);
3351        channel_members = mem::take(&mut left_room.channel_members);
3352        channel_id = left_room.channel_id;
3353
3354        room_updated(&room, &session.peer);
3355    } else {
3356        return Ok(());
3357    }
3358
3359    if let Some(channel_id) = channel_id {
3360        channel_updated(
3361            channel_id,
3362            &room,
3363            &channel_members,
3364            &session.peer,
3365            &*session.connection_pool().await,
3366        );
3367    }
3368
3369    {
3370        let pool = session.connection_pool().await;
3371        for canceled_user_id in canceled_calls_to_user_ids {
3372            for connection_id in pool.user_connection_ids(canceled_user_id) {
3373                session
3374                    .peer
3375                    .send(
3376                        connection_id,
3377                        proto::CallCanceled {
3378                            room_id: room_id.to_proto(),
3379                        },
3380                    )
3381                    .trace_err();
3382            }
3383            contacts_to_update.insert(canceled_user_id);
3384        }
3385    }
3386
3387    for contact_user_id in contacts_to_update {
3388        update_user_contacts(contact_user_id, &session).await?;
3389    }
3390
3391    if let Some(live_kit) = session.live_kit_client.as_ref() {
3392        live_kit
3393            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3394            .await
3395            .trace_err();
3396
3397        if delete_live_kit_room {
3398            live_kit.delete_room(live_kit_room).await.trace_err();
3399        }
3400    }
3401
3402    Ok(())
3403}
3404
3405async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3406    let left_channel_buffers = session
3407        .db()
3408        .await
3409        .leave_channel_buffers(session.connection_id)
3410        .await?;
3411
3412    for left_buffer in left_channel_buffers {
3413        channel_buffer_updated(
3414            session.connection_id,
3415            left_buffer.connections,
3416            &proto::UpdateChannelBufferCollaborators {
3417                channel_id: left_buffer.channel_id.to_proto(),
3418                collaborators: left_buffer.collaborators,
3419            },
3420            &session.peer,
3421        );
3422    }
3423
3424    Ok(())
3425}
3426
3427fn project_left(project: &db::LeftProject, session: &Session) {
3428    for connection_id in &project.connection_ids {
3429        if project.host_user_id == session.user_id {
3430            session
3431                .peer
3432                .send(
3433                    *connection_id,
3434                    proto::UnshareProject {
3435                        project_id: project.id.to_proto(),
3436                    },
3437                )
3438                .trace_err();
3439        } else {
3440            session
3441                .peer
3442                .send(
3443                    *connection_id,
3444                    proto::RemoveProjectCollaborator {
3445                        project_id: project.id.to_proto(),
3446                        peer_id: Some(session.connection_id.into()),
3447                    },
3448                )
3449                .trace_err();
3450        }
3451    }
3452}
3453
3454pub trait ResultExt {
3455    type Ok;
3456
3457    fn trace_err(self) -> Option<Self::Ok>;
3458}
3459
3460impl<T, E> ResultExt for Result<T, E>
3461where
3462    E: std::fmt::Debug,
3463{
3464    type Ok = T;
3465
3466    fn trace_err(self) -> Option<T> {
3467        match self {
3468            Ok(value) => Some(value),
3469            Err(error) => {
3470                tracing::error!("{:?}", error);
3471                None
3472            }
3473        }
3474    }
3475}