rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{info_span, instrument, Instrument};
  69
  70pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  71pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  72
  73const MESSAGE_COUNT_PER_PAGE: usize = 100;
  74const MAX_MESSAGE_LEN: usize = 1024;
  75const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
// Prometheus gauges, created and registered in the prometheus default registry
// on first access. `unwrap` panics if registration fails (for example, if a
// metric with the same name is already registered).
lazy_static! {
    // Gauge named "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge named "shared_projects".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  86
/// Type-erased handler stored in `Server::handlers`: receives the boxed incoming
/// envelope plus the sender's `Session` and returns a boxed `Send` future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each dispatched handler; the `Arc`-wrapped fields are
/// shared between all clones of the same connection's session.
#[derive(Clone)]
struct Session {
    /// Copied from `app_state.config.zed_environment` when the connection is set up.
    zed_environment: Arc<str>,
    /// The user that owns this connection.
    user_id: UserId,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex created per connection; access it
    /// through `Session::db()`. Handlers of the same connection therefore take
    /// turns holding it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Shared RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool (same `Arc` as `Server::connection_pool`);
    /// access it through `Session::connection_pool()`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Cloned from `app_state.live_kit_client`; `None` when that is `None`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Kept on the session but not read by any code visible in this part of the file.
    _executor: Executor,
}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
/// The RPC server: owns the peer, the connection pool and the handler table.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer that owns all client connections.
    peer: Arc<Peer>,
    /// Tracks connections per user; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown()`. Each `handle_connection` loop subscribes and
    /// returns when it changes.
    teardown: watch::Sender<()>,
}
 165
/// Lock guard over the shared `ConnectionPool`, as returned by `Session::connection_pool()`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (`Rc` is `!Send`).
/// As a result, an async block that keeps it alive across an `.await` is not
/// `Send`, and cannot become a handler's `BoxFuture`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 170
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The pool stays locked for as long as the snapshot is alive, because the
/// snapshot owns the guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the pool it dereferences to, via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
    /// Creates a server with the given id and registers every RPC handler.
    ///
    /// Request handlers (`add_request_handler`) reply through a `Response`.
    /// Message handlers (`add_message_handler`) get only the payload and the
    /// session, so they have no reply handle.
    ///
    /// # Panics
    /// Panics (via `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates or reinterprets if `ServerId.0` is
            // wider than or signed relative to u32 — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Connectivity, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing, worktrees and language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through `forward_read_only_project_request`.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            // Project requests relayed through `forward_mutating_project_request`.
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffer traffic, including messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, channel chat and notifications.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Miscellaneous.
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 296
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT`, cleans up
    /// resources the database reports as stale for this environment and server id.
    ///
    /// The task does the following; every fallible step is logged via `trace_err`
    /// and skipped on failure:
    /// 1. Fetches the stale room ids and channel-buffer ids.
    /// 2. For each channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections.
    /// 3. For each room:
    ///    - clears stale participants and broadcasts the updated room (and its
    ///      channel, if any);
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes an updated contact entry for affected users to their accepted
    ///      contacts;
    ///    - deletes the LiveKit room if no participants remain.
    /// 4. Deletes stale server records.
    ///
    /// Returns `Ok(())` immediately; the cleanup runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer, then send the
                    // refreshed collaborator list to every connection still on it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify
                        // and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool guard is released before the
                        // `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database first; the pool lock is taken only after
                            // both awaits complete, so it is never held across one.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a client is configured and
                        // the refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 448
 449    pub fn teardown(&self) {
 450        self.peer.teardown();
 451        self.connection_pool.lock().reset();
 452        let _ = self.teardown.send(());
 453    }
 454
 455    #[cfg(test)]
 456    pub fn reset(&self, id: ServerId) {
 457        self.teardown();
 458        *self.id.lock() = id;
 459        self.peer.reset(id.0 as u32);
 460    }
 461
 462    #[cfg(test)]
 463    pub fn id(&self) -> ServerId {
 464        *self.id.lock()
 465    }
 466
 467    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 468    where
 469        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 470        Fut: 'static + Send + Future<Output = Result<()>>,
 471        M: EnvelopedMessage,
 472    {
 473        let prev_handler = self.handlers.insert(
 474            TypeId::of::<M>(),
 475            Box::new(move |envelope, session| {
 476                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 477                let span = info_span!(
 478                    "handle message",
 479                    payload_type = envelope.payload_type_name()
 480                );
 481                span.in_scope(|| {
 482                    tracing::info!(
 483                        payload_type = envelope.payload_type_name(),
 484                        "message received"
 485                    );
 486                });
 487                let start_time = Instant::now();
 488                let future = (handler)(*envelope, session);
 489                async move {
 490                    let result = future.await;
 491                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 492                    match result {
 493                        Err(error) => {
 494                            tracing::error!(%error, ?duration_ms, "error handling message")
 495                        }
 496                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 497                    }
 498                }
 499                .instrument(span)
 500                .boxed()
 501            }),
 502        );
 503        if prev_handler.is_some() {
 504            panic!("registered a handler for the same message twice");
 505        }
 506        self
 507    }
 508
 509    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 512        Fut: 'static + Send + Future<Output = Result<()>>,
 513        M: EnvelopedMessage,
 514    {
 515        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 516        self
 517    }
 518
 519    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 520    where
 521        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 522        Fut: Send + Future<Output = Result<()>>,
 523        M: RequestMessage,
 524    {
 525        let handler = Arc::new(handler);
 526        self.add_handler(move |envelope, session| {
 527            let receipt = envelope.receipt();
 528            let handler = handler.clone();
 529            async move {
 530                let peer = session.peer.clone();
 531                let responded = Arc::new(AtomicBool::default());
 532                let response = Response {
 533                    peer: peer.clone(),
 534                    responded: responded.clone(),
 535                    receipt,
 536                };
 537                match (handler)(envelope.payload, response, session).await {
 538                    Ok(()) => {
 539                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 540                            Ok(())
 541                        } else {
 542                            Err(anyhow!("handler did not send a response"))?
 543                        }
 544                    }
 545                    Err(error) => {
 546                        peer.respond_with_error(
 547                            receipt,
 548                            proto::Error {
 549                                message: error.to_string(),
 550                            },
 551                        )?;
 552                        Err(error)
 553                    }
 554                }
 555            }
 556        })
 557    }
 558
    /// Drives a single authenticated client connection until it ends.
    ///
    /// Setup, in order:
    /// 1. Registers the connection with the peer and sends `Hello`.
    /// 2. Reports the new `ConnectionId` through `send_connection_id`, if given.
    /// 3. On the user's first connection, sends `ShowContacts` and records it.
    /// 4. Adds the connection to the pool and sends the initial contacts and
    ///    channels, followed by any pending incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers.
    ///
    /// When the I/O future finishes or the incoming stream ends, the loop exits:
    /// in-flight foreground handlers are dropped and `connection_lost` runs. On
    /// server teardown it returns `Ok(())` immediately without calling
    /// `connection_lost` (`Server::teardown` resets the pool itself).
    ///
    /// # Errors
    /// Returns an error if any setup step (a send or a database query) fails.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // NOTE(review): after `pool.add_connection` below, any `?` early return
            // (the sends here, `incoming_call_for_user`, `update_user_contacts`)
            // leaves this function without reaching `connection_lost`. The pool entry,
            // and the peer connection added above, may then never be cleaned up.
            // Presumably `connection_lost` is what removes them — confirm, and
            // consider routing these errors through it.
            {
                // The initial contacts update is built while holding the pool lock,
                // after this connection has been added to it.
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256. A permit is
            // taken before each message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is never closed in this function, so `acquire_owned`
                    // does not return an error here.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in listed order: teardown, I/O completion,
                // progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight;
            // background handlers were spawned detached and keep running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 694
 695    pub async fn invite_code_redeemed(
 696        self: &Arc<Self>,
 697        inviter_id: UserId,
 698        invitee_id: UserId,
 699    ) -> Result<()> {
 700        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 701            if let Some(code) = &user.invite_code {
 702                let pool = self.connection_pool.lock();
 703                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 704                for connection_id in pool.user_connection_ids(inviter_id) {
 705                    self.peer.send(
 706                        connection_id,
 707                        proto::UpdateContacts {
 708                            contacts: vec![invitee_contact.clone()],
 709                            ..Default::default()
 710                        },
 711                    )?;
 712                    self.peer.send(
 713                        connection_id,
 714                        proto::UpdateInviteInfo {
 715                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 716                            count: user.invite_count as u32,
 717                        },
 718                    )?;
 719                }
 720            }
 721        }
 722        Ok(())
 723    }
 724
 725    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 726        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 727            if let Some(invite_code) = &user.invite_code {
 728                let pool = self.connection_pool.lock();
 729                for connection_id in pool.user_connection_ids(user_id) {
 730                    self.peer.send(
 731                        connection_id,
 732                        proto::UpdateInviteInfo {
 733                            url: format!(
 734                                "{}{}",
 735                                self.app_state.config.invite_link_prefix, invite_code
 736                            ),
 737                            count: user.invite_count as u32,
 738                        },
 739                    )?;
 740                }
 741            }
 742        }
 743        Ok(())
 744    }
 745
 746    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 747        ServerSnapshot {
 748            connection_pool: ConnectionPoolGuard {
 749                guard: self.connection_pool.lock(),
 750                _not_send: PhantomData,
 751            },
 752            peer: &self.peer,
 753        }
 754    }
 755}
 756
 757impl<'a> Deref for ConnectionPoolGuard<'a> {
 758    type Target = ConnectionPool;
 759
 760    fn deref(&self) -> &Self::Target {
 761        &*self.guard
 762    }
 763}
 764
 765impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 766    fn deref_mut(&mut self) -> &mut Self::Target {
 767        &mut *self.guard
 768    }
 769}
 770
 771impl<'a> Drop for ConnectionPoolGuard<'a> {
 772    fn drop(&mut self) {
 773        #[cfg(test)]
 774        self.check_invariants();
 775    }
 776}
 777
 778fn broadcast<F>(
 779    sender_id: Option<ConnectionId>,
 780    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 781    mut f: F,
 782) where
 783    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 784{
 785    for receiver_id in receiver_ids {
 786        if Some(receiver_id) != sender_id {
 787            if let Err(error) = f(receiver_id) {
 788                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 789            }
 790        }
 791    }
 792}
 793
lazy_static! {
    // Name of the request header in which clients report their RPC protocol version.
    // It is decoded by the `ProtocolVersion` typed header.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed `x-zed-protocol-version` header holding the client's RPC protocol version.
pub struct ProtocolVersion(u32);
 799
 800impl Header for ProtocolVersion {
 801    fn name() -> &'static HeaderName {
 802        &ZED_PROTOCOL_VERSION
 803    }
 804
 805    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 806    where
 807        Self: Sized,
 808        I: Iterator<Item = &'i axum::http::HeaderValue>,
 809    {
 810        let version = values
 811            .next()
 812            .ok_or_else(axum::headers::Error::invalid)?
 813            .to_str()
 814            .map_err(|_| axum::headers::Error::invalid())?
 815            .parse()
 816            .map_err(|_| axum::headers::Error::invalid())?;
 817        Ok(Self(version))
 818    }
 819
 820    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 821        values.extend([self.0.to_string().parse().unwrap()]);
 822    }
 823}
 824
/// Builds the HTTP router for the RPC server.
///
/// - `GET /rpc` upgrades to the websocket RPC transport (`handle_websocket_request`).
///   It is wrapped by the `AppState` extension and the `auth::validate_header` middleware.
/// - `GET /metrics` serves Prometheus metrics (`handle_metrics`).
///
/// The order of calls matters. axum's `Router::layer` only wraps routes registered
/// before it, so the auth layer covers `/rpc` but not `/metrics`. The trailing
/// `Extension(server)` layer is added after both routes and reaches both.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc`, which is the only route registered so far.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 836
/// Upgrades an HTTP request to the websocket RPC transport for an authenticated user.
///
/// Requests whose `x-zed-protocol-version` differs from `rpc::PROTOCOL_VERSION` are
/// rejected with `426 Upgrade Required`. Otherwise the axum websocket is adapted to
/// tungstenite message types, wrapped in an `rpc::Connection`, and handed to
/// `Server::handle_connection`. Any error from that call goes to `log_err`.
///
/// NOTE(review): the `User` extension is presumably inserted by the auth middleware on
/// `/rpc` (see `routes`); confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt message types in both directions: incoming frames are mapped with
        // `to_tungstenite_message`, outgoing frames with `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 867
 868pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 869    let connections = server
 870        .connection_pool
 871        .lock()
 872        .connections()
 873        .filter(|connection| !connection.admin)
 874        .count();
 875
 876    METRIC_CONNECTIONS.set(connections as _);
 877
 878    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 879    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 880
 881    let encoder = prometheus::TextEncoder::new();
 882    let metric_families = prometheus::gather();
 883    let encoded_metrics = encoder
 884        .encode_to_string(&metric_families)
 885        .map_err(|err| anyhow!("{}", err))?;
 886    Ok(encoded_metrics)
 887}
 888
/// Cleans up after a session's connection has closed.
///
/// Done immediately:
/// - the connection is disconnected from the peer and removed from the connection pool
///   (an error from the pool is returned);
/// - the database is told the connection was lost (an error there is traced, not
///   returned).
///
/// Then the function waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses. The session leaves its room and channel buffers. If no
///   other connection of the user is online, `decline_call(None, user_id)` runs and any
///   returned room is broadcast. Finally `update_user_contacts` runs for the user.
/// - `teardown` changes. The wait is abandoned and none of the above cleanup runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 934
 935async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 936    response.send(proto::Ack {})?;
 937    Ok(())
 938}
 939
 940async fn create_room(
 941    _request: proto::CreateRoom,
 942    response: Response<proto::CreateRoom>,
 943    session: Session,
 944) -> Result<()> {
 945    let live_kit_room = nanoid::nanoid!(30);
 946
 947    let live_kit_connection_info = {
 948        let live_kit_room = live_kit_room.clone();
 949        let live_kit = session.live_kit_client.as_ref();
 950
 951        util::async_maybe!({
 952            let live_kit = live_kit?;
 953
 954            let token = live_kit
 955                .room_token(&live_kit_room, &session.user_id.to_string())
 956                .trace_err()?;
 957
 958            Some(proto::LiveKitConnectionInfo {
 959                server_url: live_kit.url().into(),
 960                token,
 961                can_publish: true,
 962            })
 963        })
 964    }
 965    .await;
 966
 967    let room = session
 968        .db()
 969        .await
 970        .create_room(
 971            session.user_id,
 972            session.connection_id,
 973            &live_kit_room,
 974            &session.zed_environment,
 975        )
 976        .await?;
 977
 978    response.send(proto::CreateRoomResponse {
 979        room: Some(room.clone()),
 980        live_kit_connection_info,
 981    })?;
 982
 983    update_user_contacts(session.user_id, &session).await?;
 984    Ok(())
 985}
 986
 987async fn join_room(
 988    request: proto::JoinRoom,
 989    response: Response<proto::JoinRoom>,
 990    session: Session,
 991) -> Result<()> {
 992    let room_id = RoomId::from_proto(request.id);
 993
 994    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
 995
 996    if let Some(channel_id) = channel_id {
 997        return join_channel_internal(channel_id, Box::new(response), session).await;
 998    }
 999
1000    let joined_room = {
1001        let room = session
1002            .db()
1003            .await
1004            .join_room(
1005                room_id,
1006                session.user_id,
1007                session.connection_id,
1008                session.zed_environment.as_ref(),
1009            )
1010            .await?;
1011        room_updated(&room.room, &session.peer);
1012        room.into_inner()
1013    };
1014
1015    for connection_id in session
1016        .connection_pool()
1017        .await
1018        .user_connection_ids(session.user_id)
1019    {
1020        session
1021            .peer
1022            .send(
1023                connection_id,
1024                proto::CallCanceled {
1025                    room_id: room_id.to_proto(),
1026                },
1027            )
1028            .trace_err();
1029    }
1030
1031    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1032        if let Some(token) = live_kit
1033            .room_token(
1034                &joined_room.room.live_kit_room,
1035                &session.user_id.to_string(),
1036            )
1037            .trace_err()
1038        {
1039            Some(proto::LiveKitConnectionInfo {
1040                server_url: live_kit.url().into(),
1041                token,
1042                can_publish: true,
1043            })
1044        } else {
1045            None
1046        }
1047    } else {
1048        None
1049    };
1050
1051    response.send(proto::JoinRoomResponse {
1052        room: Some(joined_room.room),
1053        channel_id: None,
1054        live_kit_connection_info,
1055    })?;
1056
1057    update_user_contacts(session.user_id, &session).await?;
1058    Ok(())
1059}
1060
1061async fn rejoin_room(
1062    request: proto::RejoinRoom,
1063    response: Response<proto::RejoinRoom>,
1064    session: Session,
1065) -> Result<()> {
1066    let room;
1067    let channel_id;
1068    let channel_members;
1069    {
1070        let mut rejoined_room = session
1071            .db()
1072            .await
1073            .rejoin_room(request, session.user_id, session.connection_id)
1074            .await?;
1075
1076        response.send(proto::RejoinRoomResponse {
1077            room: Some(rejoined_room.room.clone()),
1078            reshared_projects: rejoined_room
1079                .reshared_projects
1080                .iter()
1081                .map(|project| proto::ResharedProject {
1082                    id: project.id.to_proto(),
1083                    collaborators: project
1084                        .collaborators
1085                        .iter()
1086                        .map(|collaborator| collaborator.to_proto())
1087                        .collect(),
1088                })
1089                .collect(),
1090            rejoined_projects: rejoined_room
1091                .rejoined_projects
1092                .iter()
1093                .map(|rejoined_project| proto::RejoinedProject {
1094                    id: rejoined_project.id.to_proto(),
1095                    worktrees: rejoined_project
1096                        .worktrees
1097                        .iter()
1098                        .map(|worktree| proto::WorktreeMetadata {
1099                            id: worktree.id,
1100                            root_name: worktree.root_name.clone(),
1101                            visible: worktree.visible,
1102                            abs_path: worktree.abs_path.clone(),
1103                        })
1104                        .collect(),
1105                    collaborators: rejoined_project
1106                        .collaborators
1107                        .iter()
1108                        .map(|collaborator| collaborator.to_proto())
1109                        .collect(),
1110                    language_servers: rejoined_project.language_servers.clone(),
1111                })
1112                .collect(),
1113        })?;
1114        room_updated(&rejoined_room.room, &session.peer);
1115
1116        for project in &rejoined_room.reshared_projects {
1117            for collaborator in &project.collaborators {
1118                session
1119                    .peer
1120                    .send(
1121                        collaborator.connection_id,
1122                        proto::UpdateProjectCollaborator {
1123                            project_id: project.id.to_proto(),
1124                            old_peer_id: Some(project.old_connection_id.into()),
1125                            new_peer_id: Some(session.connection_id.into()),
1126                        },
1127                    )
1128                    .trace_err();
1129            }
1130
1131            broadcast(
1132                Some(session.connection_id),
1133                project
1134                    .collaborators
1135                    .iter()
1136                    .map(|collaborator| collaborator.connection_id),
1137                |connection_id| {
1138                    session.peer.forward_send(
1139                        session.connection_id,
1140                        connection_id,
1141                        proto::UpdateProject {
1142                            project_id: project.id.to_proto(),
1143                            worktrees: project.worktrees.clone(),
1144                        },
1145                    )
1146                },
1147            );
1148        }
1149
1150        for project in &rejoined_room.rejoined_projects {
1151            for collaborator in &project.collaborators {
1152                session
1153                    .peer
1154                    .send(
1155                        collaborator.connection_id,
1156                        proto::UpdateProjectCollaborator {
1157                            project_id: project.id.to_proto(),
1158                            old_peer_id: Some(project.old_connection_id.into()),
1159                            new_peer_id: Some(session.connection_id.into()),
1160                        },
1161                    )
1162                    .trace_err();
1163            }
1164        }
1165
1166        for project in &mut rejoined_room.rejoined_projects {
1167            for worktree in mem::take(&mut project.worktrees) {
1168                #[cfg(any(test, feature = "test-support"))]
1169                const MAX_CHUNK_SIZE: usize = 2;
1170                #[cfg(not(any(test, feature = "test-support")))]
1171                const MAX_CHUNK_SIZE: usize = 256;
1172
1173                // Stream this worktree's entries.
1174                let message = proto::UpdateWorktree {
1175                    project_id: project.id.to_proto(),
1176                    worktree_id: worktree.id,
1177                    abs_path: worktree.abs_path.clone(),
1178                    root_name: worktree.root_name,
1179                    updated_entries: worktree.updated_entries,
1180                    removed_entries: worktree.removed_entries,
1181                    scan_id: worktree.scan_id,
1182                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1183                    updated_repositories: worktree.updated_repositories,
1184                    removed_repositories: worktree.removed_repositories,
1185                };
1186                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1187                    session.peer.send(session.connection_id, update.clone())?;
1188                }
1189
1190                // Stream this worktree's diagnostics.
1191                for summary in worktree.diagnostic_summaries {
1192                    session.peer.send(
1193                        session.connection_id,
1194                        proto::UpdateDiagnosticSummary {
1195                            project_id: project.id.to_proto(),
1196                            worktree_id: worktree.id,
1197                            summary: Some(summary),
1198                        },
1199                    )?;
1200                }
1201
1202                for settings_file in worktree.settings_files {
1203                    session.peer.send(
1204                        session.connection_id,
1205                        proto::UpdateWorktreeSettings {
1206                            project_id: project.id.to_proto(),
1207                            worktree_id: worktree.id,
1208                            path: settings_file.path,
1209                            content: Some(settings_file.content),
1210                        },
1211                    )?;
1212                }
1213            }
1214
1215            for language_server in &project.language_servers {
1216                session.peer.send(
1217                    session.connection_id,
1218                    proto::UpdateLanguageServer {
1219                        project_id: project.id.to_proto(),
1220                        language_server_id: language_server.id,
1221                        variant: Some(
1222                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1223                                proto::LspDiskBasedDiagnosticsUpdated {},
1224                            ),
1225                        ),
1226                    },
1227                )?;
1228            }
1229        }
1230
1231        let rejoined_room = rejoined_room.into_inner();
1232
1233        room = rejoined_room.room;
1234        channel_id = rejoined_room.channel_id;
1235        channel_members = rejoined_room.channel_members;
1236    }
1237
1238    if let Some(channel_id) = channel_id {
1239        channel_updated(
1240            channel_id,
1241            &room,
1242            &channel_members,
1243            &session.peer,
1244            &*session.connection_pool().await,
1245        );
1246    }
1247
1248    update_user_contacts(session.user_id, &session).await?;
1249    Ok(())
1250}
1251
1252async fn leave_room(
1253    _: proto::LeaveRoom,
1254    response: Response<proto::LeaveRoom>,
1255    session: Session,
1256) -> Result<()> {
1257    leave_room_for_session(&session).await?;
1258    response.send(proto::Ack {})?;
1259    Ok(())
1260}
1261
1262async fn set_room_participant_role(
1263    request: proto::SetRoomParticipantRole,
1264    response: Response<proto::SetRoomParticipantRole>,
1265    session: Session,
1266) -> Result<()> {
1267    let (live_kit_room, can_publish) = {
1268        let room = session
1269            .db()
1270            .await
1271            .set_room_participant_role(
1272                session.user_id,
1273                RoomId::from_proto(request.room_id),
1274                UserId::from_proto(request.user_id),
1275                ChannelRole::from(request.role()),
1276            )
1277            .await?;
1278
1279        let live_kit_room = room.live_kit_room.clone();
1280        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1281        room_updated(&room, &session.peer);
1282        (live_kit_room, can_publish)
1283    };
1284
1285    if let Some(live_kit) = session.live_kit_client.as_ref() {
1286        live_kit
1287            .update_participant(
1288                live_kit_room.clone(),
1289                request.user_id.to_string(),
1290                live_kit_server::proto::ParticipantPermission {
1291                    can_subscribe: true,
1292                    can_publish,
1293                    can_publish_data: can_publish,
1294                    hidden: false,
1295                    recorder: false,
1296                },
1297            )
1298            .await
1299            .trace_err();
1300    }
1301
1302    response.send(proto::Ack {})?;
1303    Ok(())
1304}
1305
1306async fn call(
1307    request: proto::Call,
1308    response: Response<proto::Call>,
1309    session: Session,
1310) -> Result<()> {
1311    let room_id = RoomId::from_proto(request.room_id);
1312    let calling_user_id = session.user_id;
1313    let calling_connection_id = session.connection_id;
1314    let called_user_id = UserId::from_proto(request.called_user_id);
1315    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1316    if !session
1317        .db()
1318        .await
1319        .has_contact(calling_user_id, called_user_id)
1320        .await?
1321    {
1322        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1323    }
1324
1325    let incoming_call = {
1326        let (room, incoming_call) = &mut *session
1327            .db()
1328            .await
1329            .call(
1330                room_id,
1331                calling_user_id,
1332                calling_connection_id,
1333                called_user_id,
1334                initial_project_id,
1335            )
1336            .await?;
1337        room_updated(&room, &session.peer);
1338        mem::take(incoming_call)
1339    };
1340    update_user_contacts(called_user_id, &session).await?;
1341
1342    let mut calls = session
1343        .connection_pool()
1344        .await
1345        .user_connection_ids(called_user_id)
1346        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1347        .collect::<FuturesUnordered<_>>();
1348
1349    while let Some(call_response) = calls.next().await {
1350        match call_response.as_ref() {
1351            Ok(_) => {
1352                response.send(proto::Ack {})?;
1353                return Ok(());
1354            }
1355            Err(_) => {
1356                call_response.trace_err();
1357            }
1358        }
1359    }
1360
1361    {
1362        let room = session
1363            .db()
1364            .await
1365            .call_failed(room_id, called_user_id)
1366            .await?;
1367        room_updated(&room, &session.peer);
1368    }
1369    update_user_contacts(called_user_id, &session).await?;
1370
1371    Err(anyhow!("failed to ring user"))?
1372}
1373
1374async fn cancel_call(
1375    request: proto::CancelCall,
1376    response: Response<proto::CancelCall>,
1377    session: Session,
1378) -> Result<()> {
1379    let called_user_id = UserId::from_proto(request.called_user_id);
1380    let room_id = RoomId::from_proto(request.room_id);
1381    {
1382        let room = session
1383            .db()
1384            .await
1385            .cancel_call(room_id, session.connection_id, called_user_id)
1386            .await?;
1387        room_updated(&room, &session.peer);
1388    }
1389
1390    for connection_id in session
1391        .connection_pool()
1392        .await
1393        .user_connection_ids(called_user_id)
1394    {
1395        session
1396            .peer
1397            .send(
1398                connection_id,
1399                proto::CallCanceled {
1400                    room_id: room_id.to_proto(),
1401                },
1402            )
1403            .trace_err();
1404    }
1405    response.send(proto::Ack {})?;
1406
1407    update_user_contacts(called_user_id, &session).await?;
1408    Ok(())
1409}
1410
1411async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1412    let room_id = RoomId::from_proto(message.room_id);
1413    {
1414        let room = session
1415            .db()
1416            .await
1417            .decline_call(Some(room_id), session.user_id)
1418            .await?
1419            .ok_or_else(|| anyhow!("failed to decline call"))?;
1420        room_updated(&room, &session.peer);
1421    }
1422
1423    for connection_id in session
1424        .connection_pool()
1425        .await
1426        .user_connection_ids(session.user_id)
1427    {
1428        session
1429            .peer
1430            .send(
1431                connection_id,
1432                proto::CallCanceled {
1433                    room_id: room_id.to_proto(),
1434                },
1435            )
1436            .trace_err();
1437    }
1438    update_user_contacts(session.user_id, &session).await?;
1439    Ok(())
1440}
1441
1442async fn update_participant_location(
1443    request: proto::UpdateParticipantLocation,
1444    response: Response<proto::UpdateParticipantLocation>,
1445    session: Session,
1446) -> Result<()> {
1447    let room_id = RoomId::from_proto(request.room_id);
1448    let location = request
1449        .location
1450        .ok_or_else(|| anyhow!("invalid location"))?;
1451
1452    let db = session.db().await;
1453    let room = db
1454        .update_room_participant_location(room_id, session.connection_id, location)
1455        .await?;
1456
1457    room_updated(&room, &session.peer);
1458    response.send(proto::Ack {})?;
1459    Ok(())
1460}
1461
1462async fn share_project(
1463    request: proto::ShareProject,
1464    response: Response<proto::ShareProject>,
1465    session: Session,
1466) -> Result<()> {
1467    let (project_id, room) = &*session
1468        .db()
1469        .await
1470        .share_project(
1471            RoomId::from_proto(request.room_id),
1472            session.connection_id,
1473            &request.worktrees,
1474        )
1475        .await?;
1476    response.send(proto::ShareProjectResponse {
1477        project_id: project_id.to_proto(),
1478    })?;
1479    room_updated(&room, &session.peer);
1480
1481    Ok(())
1482}
1483
1484async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1485    let project_id = ProjectId::from_proto(message.project_id);
1486
1487    let (room, guest_connection_ids) = &*session
1488        .db()
1489        .await
1490        .unshare_project(project_id, session.connection_id)
1491        .await?;
1492
1493    broadcast(
1494        Some(session.connection_id),
1495        guest_connection_ids.iter().copied(),
1496        |conn_id| session.peer.send(conn_id, message.clone()),
1497    );
1498    room_updated(&room, &session.peer);
1499
1500    Ok(())
1501}
1502
1503async fn join_project(
1504    request: proto::JoinProject,
1505    response: Response<proto::JoinProject>,
1506    session: Session,
1507) -> Result<()> {
1508    let project_id = ProjectId::from_proto(request.project_id);
1509    let guest_user_id = session.user_id;
1510
1511    tracing::info!(%project_id, "join project");
1512
1513    let (project, replica_id) = &mut *session
1514        .db()
1515        .await
1516        .join_project(project_id, session.connection_id)
1517        .await?;
1518
1519    let collaborators = project
1520        .collaborators
1521        .iter()
1522        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1523        .map(|collaborator| collaborator.to_proto())
1524        .collect::<Vec<_>>();
1525
1526    let worktrees = project
1527        .worktrees
1528        .iter()
1529        .map(|(id, worktree)| proto::WorktreeMetadata {
1530            id: *id,
1531            root_name: worktree.root_name.clone(),
1532            visible: worktree.visible,
1533            abs_path: worktree.abs_path.clone(),
1534        })
1535        .collect::<Vec<_>>();
1536
1537    for collaborator in &collaborators {
1538        session
1539            .peer
1540            .send(
1541                collaborator.peer_id.unwrap().into(),
1542                proto::AddProjectCollaborator {
1543                    project_id: project_id.to_proto(),
1544                    collaborator: Some(proto::Collaborator {
1545                        peer_id: Some(session.connection_id.into()),
1546                        replica_id: replica_id.0 as u32,
1547                        user_id: guest_user_id.to_proto(),
1548                    }),
1549                },
1550            )
1551            .trace_err();
1552    }
1553
1554    // First, we send the metadata associated with each worktree.
1555    response.send(proto::JoinProjectResponse {
1556        worktrees: worktrees.clone(),
1557        replica_id: replica_id.0 as u32,
1558        collaborators: collaborators.clone(),
1559        language_servers: project.language_servers.clone(),
1560    })?;
1561
1562    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1563        #[cfg(any(test, feature = "test-support"))]
1564        const MAX_CHUNK_SIZE: usize = 2;
1565        #[cfg(not(any(test, feature = "test-support")))]
1566        const MAX_CHUNK_SIZE: usize = 256;
1567
1568        // Stream this worktree's entries.
1569        let message = proto::UpdateWorktree {
1570            project_id: project_id.to_proto(),
1571            worktree_id,
1572            abs_path: worktree.abs_path.clone(),
1573            root_name: worktree.root_name,
1574            updated_entries: worktree.entries,
1575            removed_entries: Default::default(),
1576            scan_id: worktree.scan_id,
1577            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1578            updated_repositories: worktree.repository_entries.into_values().collect(),
1579            removed_repositories: Default::default(),
1580        };
1581        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1582            session.peer.send(session.connection_id, update.clone())?;
1583        }
1584
1585        // Stream this worktree's diagnostics.
1586        for summary in worktree.diagnostic_summaries {
1587            session.peer.send(
1588                session.connection_id,
1589                proto::UpdateDiagnosticSummary {
1590                    project_id: project_id.to_proto(),
1591                    worktree_id: worktree.id,
1592                    summary: Some(summary),
1593                },
1594            )?;
1595        }
1596
1597        for settings_file in worktree.settings_files {
1598            session.peer.send(
1599                session.connection_id,
1600                proto::UpdateWorktreeSettings {
1601                    project_id: project_id.to_proto(),
1602                    worktree_id: worktree.id,
1603                    path: settings_file.path,
1604                    content: Some(settings_file.content),
1605                },
1606            )?;
1607        }
1608    }
1609
1610    for language_server in &project.language_servers {
1611        session.peer.send(
1612            session.connection_id,
1613            proto::UpdateLanguageServer {
1614                project_id: project_id.to_proto(),
1615                language_server_id: language_server.id,
1616                variant: Some(
1617                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1618                        proto::LspDiskBasedDiagnosticsUpdated {},
1619                    ),
1620                ),
1621            },
1622        )?;
1623    }
1624
1625    Ok(())
1626}
1627
1628async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1629    let sender_id = session.connection_id;
1630    let project_id = ProjectId::from_proto(request.project_id);
1631
1632    let (room, project) = &*session
1633        .db()
1634        .await
1635        .leave_project(project_id, sender_id)
1636        .await?;
1637    tracing::info!(
1638        %project_id,
1639        host_user_id = %project.host_user_id,
1640        host_connection_id = %project.host_connection_id,
1641        "leave project"
1642    );
1643
1644    project_left(&project, &session);
1645    room_updated(&room, &session.peer);
1646
1647    Ok(())
1648}
1649
1650async fn update_project(
1651    request: proto::UpdateProject,
1652    response: Response<proto::UpdateProject>,
1653    session: Session,
1654) -> Result<()> {
1655    let project_id = ProjectId::from_proto(request.project_id);
1656    let (room, guest_connection_ids) = &*session
1657        .db()
1658        .await
1659        .update_project(project_id, session.connection_id, &request.worktrees)
1660        .await?;
1661    broadcast(
1662        Some(session.connection_id),
1663        guest_connection_ids.iter().copied(),
1664        |connection_id| {
1665            session
1666                .peer
1667                .forward_send(session.connection_id, connection_id, request.clone())
1668        },
1669    );
1670    room_updated(&room, &session.peer);
1671    response.send(proto::Ack {})?;
1672
1673    Ok(())
1674}
1675
1676async fn update_worktree(
1677    request: proto::UpdateWorktree,
1678    response: Response<proto::UpdateWorktree>,
1679    session: Session,
1680) -> Result<()> {
1681    let guest_connection_ids = session
1682        .db()
1683        .await
1684        .update_worktree(&request, session.connection_id)
1685        .await?;
1686
1687    broadcast(
1688        Some(session.connection_id),
1689        guest_connection_ids.iter().copied(),
1690        |connection_id| {
1691            session
1692                .peer
1693                .forward_send(session.connection_id, connection_id, request.clone())
1694        },
1695    );
1696    response.send(proto::Ack {})?;
1697    Ok(())
1698}
1699
1700async fn update_diagnostic_summary(
1701    message: proto::UpdateDiagnosticSummary,
1702    session: Session,
1703) -> Result<()> {
1704    let guest_connection_ids = session
1705        .db()
1706        .await
1707        .update_diagnostic_summary(&message, session.connection_id)
1708        .await?;
1709
1710    broadcast(
1711        Some(session.connection_id),
1712        guest_connection_ids.iter().copied(),
1713        |connection_id| {
1714            session
1715                .peer
1716                .forward_send(session.connection_id, connection_id, message.clone())
1717        },
1718    );
1719
1720    Ok(())
1721}
1722
1723async fn update_worktree_settings(
1724    message: proto::UpdateWorktreeSettings,
1725    session: Session,
1726) -> Result<()> {
1727    let guest_connection_ids = session
1728        .db()
1729        .await
1730        .update_worktree_settings(&message, session.connection_id)
1731        .await?;
1732
1733    broadcast(
1734        Some(session.connection_id),
1735        guest_connection_ids.iter().copied(),
1736        |connection_id| {
1737            session
1738                .peer
1739                .forward_send(session.connection_id, connection_id, message.clone())
1740        },
1741    );
1742
1743    Ok(())
1744}
1745
1746async fn start_language_server(
1747    request: proto::StartLanguageServer,
1748    session: Session,
1749) -> Result<()> {
1750    let guest_connection_ids = session
1751        .db()
1752        .await
1753        .start_language_server(&request, session.connection_id)
1754        .await?;
1755
1756    broadcast(
1757        Some(session.connection_id),
1758        guest_connection_ids.iter().copied(),
1759        |connection_id| {
1760            session
1761                .peer
1762                .forward_send(session.connection_id, connection_id, request.clone())
1763        },
1764    );
1765    Ok(())
1766}
1767
1768async fn update_language_server(
1769    request: proto::UpdateLanguageServer,
1770    session: Session,
1771) -> Result<()> {
1772    let project_id = ProjectId::from_proto(request.project_id);
1773    let project_connection_ids = session
1774        .db()
1775        .await
1776        .project_connection_ids(project_id, session.connection_id)
1777        .await?;
1778    broadcast(
1779        Some(session.connection_id),
1780        project_connection_ids.iter().copied(),
1781        |connection_id| {
1782            session
1783                .peer
1784                .forward_send(session.connection_id, connection_id, request.clone())
1785        },
1786    );
1787    Ok(())
1788}
1789
1790async fn forward_read_only_project_request<T>(
1791    request: T,
1792    response: Response<T>,
1793    session: Session,
1794) -> Result<()>
1795where
1796    T: EntityMessage + RequestMessage,
1797{
1798    let project_id = ProjectId::from_proto(request.remote_entity_id());
1799    let host_connection_id = session
1800        .db()
1801        .await
1802        .host_for_read_only_project_request(project_id, session.connection_id)
1803        .await?;
1804    let payload = session
1805        .peer
1806        .forward_request(session.connection_id, host_connection_id, request)
1807        .await?;
1808    response.send(payload)?;
1809    Ok(())
1810}
1811
1812async fn forward_mutating_project_request<T>(
1813    request: T,
1814    response: Response<T>,
1815    session: Session,
1816) -> Result<()>
1817where
1818    T: EntityMessage + RequestMessage,
1819{
1820    let project_id = ProjectId::from_proto(request.remote_entity_id());
1821    let host_connection_id = session
1822        .db()
1823        .await
1824        .host_for_mutating_project_request(project_id, session.connection_id)
1825        .await?;
1826    let payload = session
1827        .peer
1828        .forward_request(session.connection_id, host_connection_id, request)
1829        .await?;
1830    response.send(payload)?;
1831    Ok(())
1832}
1833
1834async fn create_buffer_for_peer(
1835    request: proto::CreateBufferForPeer,
1836    session: Session,
1837) -> Result<()> {
1838    session
1839        .db()
1840        .await
1841        .check_user_is_project_host(
1842            ProjectId::from_proto(request.project_id),
1843            session.connection_id,
1844        )
1845        .await?;
1846    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1847    session
1848        .peer
1849        .forward_send(session.connection_id, peer_id.into(), request)?;
1850    Ok(())
1851}
1852
1853async fn update_buffer(
1854    request: proto::UpdateBuffer,
1855    response: Response<proto::UpdateBuffer>,
1856    session: Session,
1857) -> Result<()> {
1858    let project_id = ProjectId::from_proto(request.project_id);
1859    let mut guest_connection_ids;
1860    let mut host_connection_id = None;
1861
1862    let mut requires_write_permission = false;
1863
1864    for op in request.operations.iter() {
1865        match op.variant {
1866            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1867            Some(_) => requires_write_permission = true,
1868        }
1869    }
1870
1871    {
1872        let collaborators = session
1873            .db()
1874            .await
1875            .project_collaborators_for_buffer_update(
1876                project_id,
1877                session.connection_id,
1878                requires_write_permission,
1879            )
1880            .await?;
1881        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1882        for collaborator in collaborators.iter() {
1883            if collaborator.is_host {
1884                host_connection_id = Some(collaborator.connection_id);
1885            } else {
1886                guest_connection_ids.push(collaborator.connection_id);
1887            }
1888        }
1889    }
1890    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1891
1892    broadcast(
1893        Some(session.connection_id),
1894        guest_connection_ids,
1895        |connection_id| {
1896            session
1897                .peer
1898                .forward_send(session.connection_id, connection_id, request.clone())
1899        },
1900    );
1901    if host_connection_id != session.connection_id {
1902        session
1903            .peer
1904            .forward_request(session.connection_id, host_connection_id, request.clone())
1905            .await?;
1906    }
1907
1908    response.send(proto::Ack {})?;
1909    Ok(())
1910}
1911
1912async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1913    request: T,
1914    session: Session,
1915) -> Result<()> {
1916    let project_id = ProjectId::from_proto(request.remote_entity_id());
1917    let project_connection_ids = session
1918        .db()
1919        .await
1920        .project_connection_ids(project_id, session.connection_id)
1921        .await?;
1922
1923    broadcast(
1924        Some(session.connection_id),
1925        project_connection_ids.iter().copied(),
1926        |connection_id| {
1927            session
1928                .peer
1929                .forward_send(session.connection_id, connection_id, request.clone())
1930        },
1931    );
1932    Ok(())
1933}
1934
1935async fn follow(
1936    request: proto::Follow,
1937    response: Response<proto::Follow>,
1938    session: Session,
1939) -> Result<()> {
1940    let room_id = RoomId::from_proto(request.room_id);
1941    let project_id = request.project_id.map(ProjectId::from_proto);
1942    let leader_id = request
1943        .leader_id
1944        .ok_or_else(|| anyhow!("invalid leader id"))?
1945        .into();
1946    let follower_id = session.connection_id;
1947
1948    session
1949        .db()
1950        .await
1951        .check_room_participants(room_id, leader_id, session.connection_id)
1952        .await?;
1953
1954    let response_payload = session
1955        .peer
1956        .forward_request(session.connection_id, leader_id, request)
1957        .await?;
1958    response.send(response_payload)?;
1959
1960    if let Some(project_id) = project_id {
1961        let room = session
1962            .db()
1963            .await
1964            .follow(room_id, project_id, leader_id, follower_id)
1965            .await?;
1966        room_updated(&room, &session.peer);
1967    }
1968
1969    Ok(())
1970}
1971
1972async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1973    let room_id = RoomId::from_proto(request.room_id);
1974    let project_id = request.project_id.map(ProjectId::from_proto);
1975    let leader_id = request
1976        .leader_id
1977        .ok_or_else(|| anyhow!("invalid leader id"))?
1978        .into();
1979    let follower_id = session.connection_id;
1980
1981    session
1982        .db()
1983        .await
1984        .check_room_participants(room_id, leader_id, session.connection_id)
1985        .await?;
1986
1987    session
1988        .peer
1989        .forward_send(session.connection_id, leader_id, request)?;
1990
1991    if let Some(project_id) = project_id {
1992        let room = session
1993            .db()
1994            .await
1995            .unfollow(room_id, project_id, leader_id, follower_id)
1996            .await?;
1997        room_updated(&room, &session.peer);
1998    }
1999
2000    Ok(())
2001}
2002
2003async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2004    let room_id = RoomId::from_proto(request.room_id);
2005    let database = session.db.lock().await;
2006
2007    let connection_ids = if let Some(project_id) = request.project_id {
2008        let project_id = ProjectId::from_proto(project_id);
2009        database
2010            .project_connection_ids(project_id, session.connection_id)
2011            .await?
2012    } else {
2013        database
2014            .room_connection_ids(room_id, session.connection_id)
2015            .await?
2016    };
2017
2018    // For now, don't send view update messages back to that view's current leader.
2019    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2020        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2021        _ => None,
2022    });
2023
2024    for follower_peer_id in request.follower_ids.iter().copied() {
2025        let follower_connection_id = follower_peer_id.into();
2026        if Some(follower_peer_id) != connection_id_to_omit
2027            && connection_ids.contains(&follower_connection_id)
2028        {
2029            session.peer.forward_send(
2030                session.connection_id,
2031                follower_connection_id,
2032                request.clone(),
2033            )?;
2034        }
2035    }
2036    Ok(())
2037}
2038
2039async fn get_users(
2040    request: proto::GetUsers,
2041    response: Response<proto::GetUsers>,
2042    session: Session,
2043) -> Result<()> {
2044    let user_ids = request
2045        .user_ids
2046        .into_iter()
2047        .map(UserId::from_proto)
2048        .collect();
2049    let users = session
2050        .db()
2051        .await
2052        .get_users_by_ids(user_ids)
2053        .await?
2054        .into_iter()
2055        .map(|user| proto::User {
2056            id: user.id.to_proto(),
2057            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2058            github_login: user.github_login,
2059        })
2060        .collect();
2061    response.send(proto::UsersResponse { users })?;
2062    Ok(())
2063}
2064
2065async fn fuzzy_search_users(
2066    request: proto::FuzzySearchUsers,
2067    response: Response<proto::FuzzySearchUsers>,
2068    session: Session,
2069) -> Result<()> {
2070    let query = request.query;
2071    let users = match query.len() {
2072        0 => vec![],
2073        1 | 2 => session
2074            .db()
2075            .await
2076            .get_user_by_github_login(&query)
2077            .await?
2078            .into_iter()
2079            .collect(),
2080        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2081    };
2082    let users = users
2083        .into_iter()
2084        .filter(|user| user.id != session.user_id)
2085        .map(|user| proto::User {
2086            id: user.id.to_proto(),
2087            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2088            github_login: user.github_login,
2089        })
2090        .collect();
2091    response.send(proto::UsersResponse { users })?;
2092    Ok(())
2093}
2094
2095async fn request_contact(
2096    request: proto::RequestContact,
2097    response: Response<proto::RequestContact>,
2098    session: Session,
2099) -> Result<()> {
2100    let requester_id = session.user_id;
2101    let responder_id = UserId::from_proto(request.responder_id);
2102    if requester_id == responder_id {
2103        return Err(anyhow!("cannot add yourself as a contact"))?;
2104    }
2105
2106    let notifications = session
2107        .db()
2108        .await
2109        .send_contact_request(requester_id, responder_id)
2110        .await?;
2111
2112    // Update outgoing contact requests of requester
2113    let mut update = proto::UpdateContacts::default();
2114    update.outgoing_requests.push(responder_id.to_proto());
2115    for connection_id in session
2116        .connection_pool()
2117        .await
2118        .user_connection_ids(requester_id)
2119    {
2120        session.peer.send(connection_id, update.clone())?;
2121    }
2122
2123    // Update incoming contact requests of responder
2124    let mut update = proto::UpdateContacts::default();
2125    update
2126        .incoming_requests
2127        .push(proto::IncomingContactRequest {
2128            requester_id: requester_id.to_proto(),
2129        });
2130    let connection_pool = session.connection_pool().await;
2131    for connection_id in connection_pool.user_connection_ids(responder_id) {
2132        session.peer.send(connection_id, update.clone())?;
2133    }
2134
2135    send_notifications(&*connection_pool, &session.peer, notifications);
2136
2137    response.send(proto::Ack {})?;
2138    Ok(())
2139}
2140
2141async fn respond_to_contact_request(
2142    request: proto::RespondToContactRequest,
2143    response: Response<proto::RespondToContactRequest>,
2144    session: Session,
2145) -> Result<()> {
2146    let responder_id = session.user_id;
2147    let requester_id = UserId::from_proto(request.requester_id);
2148    let db = session.db().await;
2149    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2150        db.dismiss_contact_notification(responder_id, requester_id)
2151            .await?;
2152    } else {
2153        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2154
2155        let notifications = db
2156            .respond_to_contact_request(responder_id, requester_id, accept)
2157            .await?;
2158        let requester_busy = db.is_user_busy(requester_id).await?;
2159        let responder_busy = db.is_user_busy(responder_id).await?;
2160
2161        let pool = session.connection_pool().await;
2162        // Update responder with new contact
2163        let mut update = proto::UpdateContacts::default();
2164        if accept {
2165            update
2166                .contacts
2167                .push(contact_for_user(requester_id, requester_busy, &pool));
2168        }
2169        update
2170            .remove_incoming_requests
2171            .push(requester_id.to_proto());
2172        for connection_id in pool.user_connection_ids(responder_id) {
2173            session.peer.send(connection_id, update.clone())?;
2174        }
2175
2176        // Update requester with new contact
2177        let mut update = proto::UpdateContacts::default();
2178        if accept {
2179            update
2180                .contacts
2181                .push(contact_for_user(responder_id, responder_busy, &pool));
2182        }
2183        update
2184            .remove_outgoing_requests
2185            .push(responder_id.to_proto());
2186
2187        for connection_id in pool.user_connection_ids(requester_id) {
2188            session.peer.send(connection_id, update.clone())?;
2189        }
2190
2191        send_notifications(&*pool, &session.peer, notifications);
2192    }
2193
2194    response.send(proto::Ack {})?;
2195    Ok(())
2196}
2197
2198async fn remove_contact(
2199    request: proto::RemoveContact,
2200    response: Response<proto::RemoveContact>,
2201    session: Session,
2202) -> Result<()> {
2203    let requester_id = session.user_id;
2204    let responder_id = UserId::from_proto(request.user_id);
2205    let db = session.db().await;
2206    let (contact_accepted, deleted_notification_id) =
2207        db.remove_contact(requester_id, responder_id).await?;
2208
2209    let pool = session.connection_pool().await;
2210    // Update outgoing contact requests of requester
2211    let mut update = proto::UpdateContacts::default();
2212    if contact_accepted {
2213        update.remove_contacts.push(responder_id.to_proto());
2214    } else {
2215        update
2216            .remove_outgoing_requests
2217            .push(responder_id.to_proto());
2218    }
2219    for connection_id in pool.user_connection_ids(requester_id) {
2220        session.peer.send(connection_id, update.clone())?;
2221    }
2222
2223    // Update incoming contact requests of responder
2224    let mut update = proto::UpdateContacts::default();
2225    if contact_accepted {
2226        update.remove_contacts.push(requester_id.to_proto());
2227    } else {
2228        update
2229            .remove_incoming_requests
2230            .push(requester_id.to_proto());
2231    }
2232    for connection_id in pool.user_connection_ids(responder_id) {
2233        session.peer.send(connection_id, update.clone())?;
2234        if let Some(notification_id) = deleted_notification_id {
2235            session.peer.send(
2236                connection_id,
2237                proto::DeleteNotification {
2238                    notification_id: notification_id.to_proto(),
2239                },
2240            )?;
2241        }
2242    }
2243
2244    response.send(proto::Ack {})?;
2245    Ok(())
2246}
2247
2248async fn create_channel(
2249    request: proto::CreateChannel,
2250    response: Response<proto::CreateChannel>,
2251    session: Session,
2252) -> Result<()> {
2253    let db = session.db().await;
2254
2255    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2256    let CreateChannelResult {
2257        channel,
2258        participants_to_update,
2259    } = db
2260        .create_channel(&request.name, parent_id, session.user_id)
2261        .await?;
2262
2263    response.send(proto::CreateChannelResponse {
2264        channel: Some(channel.to_proto()),
2265        parent_id: request.parent_id,
2266    })?;
2267
2268    let connection_pool = session.connection_pool().await;
2269    for (user_id, channels) in participants_to_update {
2270        let update = build_channels_update(channels, vec![]);
2271        for connection_id in connection_pool.user_connection_ids(user_id) {
2272            if user_id == session.user_id {
2273                continue;
2274            }
2275            session.peer.send(connection_id, update.clone())?;
2276        }
2277    }
2278
2279    Ok(())
2280}
2281
2282async fn delete_channel(
2283    request: proto::DeleteChannel,
2284    response: Response<proto::DeleteChannel>,
2285    session: Session,
2286) -> Result<()> {
2287    let db = session.db().await;
2288
2289    let channel_id = request.channel_id;
2290    let (removed_channels, member_ids) = db
2291        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2292        .await?;
2293    response.send(proto::Ack {})?;
2294
2295    // Notify members of removed channels
2296    let mut update = proto::UpdateChannels::default();
2297    update
2298        .delete_channels
2299        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2300
2301    let connection_pool = session.connection_pool().await;
2302    for member_id in member_ids {
2303        for connection_id in connection_pool.user_connection_ids(member_id) {
2304            session.peer.send(connection_id, update.clone())?;
2305        }
2306    }
2307
2308    Ok(())
2309}
2310
2311async fn invite_channel_member(
2312    request: proto::InviteChannelMember,
2313    response: Response<proto::InviteChannelMember>,
2314    session: Session,
2315) -> Result<()> {
2316    let db = session.db().await;
2317    let channel_id = ChannelId::from_proto(request.channel_id);
2318    let invitee_id = UserId::from_proto(request.user_id);
2319    let InviteMemberResult {
2320        channel,
2321        notifications,
2322    } = db
2323        .invite_channel_member(
2324            channel_id,
2325            invitee_id,
2326            session.user_id,
2327            request.role().into(),
2328        )
2329        .await?;
2330
2331    let update = proto::UpdateChannels {
2332        channel_invitations: vec![channel.to_proto()],
2333        ..Default::default()
2334    };
2335
2336    let connection_pool = session.connection_pool().await;
2337    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2338        session.peer.send(connection_id, update.clone())?;
2339    }
2340
2341    send_notifications(&*connection_pool, &session.peer, notifications);
2342
2343    response.send(proto::Ack {})?;
2344    Ok(())
2345}
2346
2347async fn remove_channel_member(
2348    request: proto::RemoveChannelMember,
2349    response: Response<proto::RemoveChannelMember>,
2350    session: Session,
2351) -> Result<()> {
2352    let db = session.db().await;
2353    let channel_id = ChannelId::from_proto(request.channel_id);
2354    let member_id = UserId::from_proto(request.user_id);
2355
2356    let RemoveChannelMemberResult {
2357        membership_update,
2358        notification_id,
2359    } = db
2360        .remove_channel_member(channel_id, member_id, session.user_id)
2361        .await?;
2362
2363    let connection_pool = &session.connection_pool().await;
2364    notify_membership_updated(
2365        &connection_pool,
2366        membership_update,
2367        member_id,
2368        &session.peer,
2369    );
2370    for connection_id in connection_pool.user_connection_ids(member_id) {
2371        if let Some(notification_id) = notification_id {
2372            session
2373                .peer
2374                .send(
2375                    connection_id,
2376                    proto::DeleteNotification {
2377                        notification_id: notification_id.to_proto(),
2378                    },
2379                )
2380                .trace_err();
2381        }
2382    }
2383
2384    response.send(proto::Ack {})?;
2385    Ok(())
2386}
2387
2388async fn set_channel_visibility(
2389    request: proto::SetChannelVisibility,
2390    response: Response<proto::SetChannelVisibility>,
2391    session: Session,
2392) -> Result<()> {
2393    let db = session.db().await;
2394    let channel_id = ChannelId::from_proto(request.channel_id);
2395    let visibility = request.visibility().into();
2396
2397    let SetChannelVisibilityResult {
2398        participants_to_update,
2399        participants_to_remove,
2400        channels_to_remove,
2401    } = db
2402        .set_channel_visibility(channel_id, visibility, session.user_id)
2403        .await?;
2404
2405    let connection_pool = session.connection_pool().await;
2406    for (user_id, channels) in participants_to_update {
2407        let update = build_channels_update(channels, vec![]);
2408        for connection_id in connection_pool.user_connection_ids(user_id) {
2409            session.peer.send(connection_id, update.clone())?;
2410        }
2411    }
2412    for user_id in participants_to_remove {
2413        let update = proto::UpdateChannels {
2414            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2415            ..Default::default()
2416        };
2417        for connection_id in connection_pool.user_connection_ids(user_id) {
2418            session.peer.send(connection_id, update.clone())?;
2419        }
2420    }
2421
2422    response.send(proto::Ack {})?;
2423    Ok(())
2424}
2425
2426async fn set_channel_member_role(
2427    request: proto::SetChannelMemberRole,
2428    response: Response<proto::SetChannelMemberRole>,
2429    session: Session,
2430) -> Result<()> {
2431    let db = session.db().await;
2432    let channel_id = ChannelId::from_proto(request.channel_id);
2433    let member_id = UserId::from_proto(request.user_id);
2434    let result = db
2435        .set_channel_member_role(
2436            channel_id,
2437            session.user_id,
2438            member_id,
2439            request.role().into(),
2440        )
2441        .await?;
2442
2443    match result {
2444        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2445            let connection_pool = session.connection_pool().await;
2446            notify_membership_updated(
2447                &connection_pool,
2448                membership_update,
2449                member_id,
2450                &session.peer,
2451            )
2452        }
2453        db::SetMemberRoleResult::InviteUpdated(channel) => {
2454            let update = proto::UpdateChannels {
2455                channel_invitations: vec![channel.to_proto()],
2456                ..Default::default()
2457            };
2458
2459            for connection_id in session
2460                .connection_pool()
2461                .await
2462                .user_connection_ids(member_id)
2463            {
2464                session.peer.send(connection_id, update.clone())?;
2465            }
2466        }
2467    }
2468
2469    response.send(proto::Ack {})?;
2470    Ok(())
2471}
2472
2473async fn rename_channel(
2474    request: proto::RenameChannel,
2475    response: Response<proto::RenameChannel>,
2476    session: Session,
2477) -> Result<()> {
2478    let db = session.db().await;
2479    let channel_id = ChannelId::from_proto(request.channel_id);
2480    let RenameChannelResult {
2481        channel,
2482        participants_to_update,
2483    } = db
2484        .rename_channel(channel_id, session.user_id, &request.name)
2485        .await?;
2486
2487    response.send(proto::RenameChannelResponse {
2488        channel: Some(channel.to_proto()),
2489    })?;
2490
2491    let connection_pool = session.connection_pool().await;
2492    for (user_id, channel) in participants_to_update {
2493        for connection_id in connection_pool.user_connection_ids(user_id) {
2494            let update = proto::UpdateChannels {
2495                channels: vec![channel.to_proto()],
2496                ..Default::default()
2497            };
2498
2499            session.peer.send(connection_id, update.clone())?;
2500        }
2501    }
2502
2503    Ok(())
2504}
2505
2506async fn move_channel(
2507    request: proto::MoveChannel,
2508    response: Response<proto::MoveChannel>,
2509    session: Session,
2510) -> Result<()> {
2511    let channel_id = ChannelId::from_proto(request.channel_id);
2512    let to = request.to.map(ChannelId::from_proto);
2513
2514    let result = session
2515        .db()
2516        .await
2517        .move_channel(channel_id, to, session.user_id)
2518        .await?;
2519
2520    notify_channel_moved(result, session).await?;
2521
2522    response.send(Ack {})?;
2523    Ok(())
2524}
2525
2526async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2527    let Some(MoveChannelResult {
2528        participants_to_remove,
2529        participants_to_update,
2530        moved_channels,
2531    }) = result
2532    else {
2533        return Ok(());
2534    };
2535    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2536
2537    let connection_pool = session.connection_pool().await;
2538    for (user_id, channels) in participants_to_update {
2539        let mut update = build_channels_update(channels, vec![]);
2540        update.delete_channels = moved_channels.clone();
2541        for connection_id in connection_pool.user_connection_ids(user_id) {
2542            session.peer.send(connection_id, update.clone())?;
2543        }
2544    }
2545
2546    for user_id in participants_to_remove {
2547        let update = proto::UpdateChannels {
2548            delete_channels: moved_channels.clone(),
2549            ..Default::default()
2550        };
2551        for connection_id in connection_pool.user_connection_ids(user_id) {
2552            session.peer.send(connection_id, update.clone())?;
2553        }
2554    }
2555    Ok(())
2556}
2557
2558async fn get_channel_members(
2559    request: proto::GetChannelMembers,
2560    response: Response<proto::GetChannelMembers>,
2561    session: Session,
2562) -> Result<()> {
2563    let db = session.db().await;
2564    let channel_id = ChannelId::from_proto(request.channel_id);
2565    let members = db
2566        .get_channel_participant_details(channel_id, session.user_id)
2567        .await?;
2568    response.send(proto::GetChannelMembersResponse { members })?;
2569    Ok(())
2570}
2571
2572async fn respond_to_channel_invite(
2573    request: proto::RespondToChannelInvite,
2574    response: Response<proto::RespondToChannelInvite>,
2575    session: Session,
2576) -> Result<()> {
2577    let db = session.db().await;
2578    let channel_id = ChannelId::from_proto(request.channel_id);
2579    let RespondToChannelInvite {
2580        membership_update,
2581        notifications,
2582    } = db
2583        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2584        .await?;
2585
2586    let connection_pool = session.connection_pool().await;
2587    if let Some(membership_update) = membership_update {
2588        notify_membership_updated(
2589            &connection_pool,
2590            membership_update,
2591            session.user_id,
2592            &session.peer,
2593        );
2594    } else {
2595        let update = proto::UpdateChannels {
2596            remove_channel_invitations: vec![channel_id.to_proto()],
2597            ..Default::default()
2598        };
2599
2600        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2601            session.peer.send(connection_id, update.clone())?;
2602        }
2603    };
2604
2605    send_notifications(&*connection_pool, &session.peer, notifications);
2606
2607    response.send(proto::Ack {})?;
2608
2609    Ok(())
2610}
2611
2612async fn join_channel(
2613    request: proto::JoinChannel,
2614    response: Response<proto::JoinChannel>,
2615    session: Session,
2616) -> Result<()> {
2617    let channel_id = ChannelId::from_proto(request.channel_id);
2618    join_channel_internal(channel_id, Box::new(response), session).await
2619}
2620
2621trait JoinChannelInternalResponse {
2622    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2623}
2624impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2625    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2626        Response::<proto::JoinChannel>::send(self, result)
2627    }
2628}
2629impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2630    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2631        Response::<proto::JoinRoom>::send(self, result)
2632    }
2633}
2634
/// Moves this session into the room belonging to `channel_id`.
///
/// Steps, in order:
/// 1. Leaves whatever room this session is currently in.
/// 2. Joins the channel's room in the database, which also reports an optional
///    membership change and the user's role in the channel.
/// 3. Mints LiveKit credentials when a LiveKit client is configured. Guests get
///    a guest token with `can_publish: false`; every other role gets a room
///    token with `can_publish: true`. A token error is logged and results in no
///    LiveKit info rather than a failed join.
/// 4. Replies to the requester, then notifies the user's connections of any
///    membership change and the room's participants of the new room state.
/// 5. Broadcasts the channel's participant list to channel members and
///    refreshes this user's contact entry for their contacts.
///
/// # Errors
/// Propagates failures from leaving the previous room, the database join,
/// sending the response, and updating contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Work up to the room broadcast happens in this block. The `db` and
    // `connection_pool` handles acquired inside are dropped at its end, before
    // `session.connection_pool()` is awaited again for `channel_updated`.
    let joined_room = {
        // Leave any room this session is currently in before joining this one.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Optional LiveKit credentials: `None` when no client is configured or
        // when minting the token fails (`trace_err` logs the error and `?`
        // short-circuits the closure to `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Send the new room state to the room's participants.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2715
2716async fn join_channel_buffer(
2717    request: proto::JoinChannelBuffer,
2718    response: Response<proto::JoinChannelBuffer>,
2719    session: Session,
2720) -> Result<()> {
2721    let db = session.db().await;
2722    let channel_id = ChannelId::from_proto(request.channel_id);
2723
2724    let open_response = db
2725        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2726        .await?;
2727
2728    let collaborators = open_response.collaborators.clone();
2729    response.send(open_response)?;
2730
2731    let update = UpdateChannelBufferCollaborators {
2732        channel_id: channel_id.to_proto(),
2733        collaborators: collaborators.clone(),
2734    };
2735    channel_buffer_updated(
2736        session.connection_id,
2737        collaborators
2738            .iter()
2739            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2740        &update,
2741        &session.peer,
2742    );
2743
2744    Ok(())
2745}
2746
2747async fn update_channel_buffer(
2748    request: proto::UpdateChannelBuffer,
2749    session: Session,
2750) -> Result<()> {
2751    let db = session.db().await;
2752    let channel_id = ChannelId::from_proto(request.channel_id);
2753
2754    let (collaborators, non_collaborators, epoch, version) = db
2755        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2756        .await?;
2757
2758    channel_buffer_updated(
2759        session.connection_id,
2760        collaborators,
2761        &proto::UpdateChannelBuffer {
2762            channel_id: channel_id.to_proto(),
2763            operations: request.operations,
2764        },
2765        &session.peer,
2766    );
2767
2768    let pool = &*session.connection_pool().await;
2769
2770    broadcast(
2771        None,
2772        non_collaborators
2773            .iter()
2774            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2775        |peer_id| {
2776            session.peer.send(
2777                peer_id.into(),
2778                proto::UpdateChannels {
2779                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2780                        channel_id: channel_id.to_proto(),
2781                        epoch: epoch as u64,
2782                        version: version.clone(),
2783                    }],
2784                    ..Default::default()
2785                },
2786            )
2787        },
2788    );
2789
2790    Ok(())
2791}
2792
2793async fn rejoin_channel_buffers(
2794    request: proto::RejoinChannelBuffers,
2795    response: Response<proto::RejoinChannelBuffers>,
2796    session: Session,
2797) -> Result<()> {
2798    let db = session.db().await;
2799    let buffers = db
2800        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2801        .await?;
2802
2803    for rejoined_buffer in &buffers {
2804        let collaborators_to_notify = rejoined_buffer
2805            .buffer
2806            .collaborators
2807            .iter()
2808            .filter_map(|c| Some(c.peer_id?.into()));
2809        channel_buffer_updated(
2810            session.connection_id,
2811            collaborators_to_notify,
2812            &proto::UpdateChannelBufferCollaborators {
2813                channel_id: rejoined_buffer.buffer.channel_id,
2814                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2815            },
2816            &session.peer,
2817        );
2818    }
2819
2820    response.send(proto::RejoinChannelBuffersResponse {
2821        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2822    })?;
2823
2824    Ok(())
2825}
2826
2827async fn leave_channel_buffer(
2828    request: proto::LeaveChannelBuffer,
2829    response: Response<proto::LeaveChannelBuffer>,
2830    session: Session,
2831) -> Result<()> {
2832    let db = session.db().await;
2833    let channel_id = ChannelId::from_proto(request.channel_id);
2834
2835    let left_buffer = db
2836        .leave_channel_buffer(channel_id, session.connection_id)
2837        .await?;
2838
2839    response.send(Ack {})?;
2840
2841    channel_buffer_updated(
2842        session.connection_id,
2843        left_buffer.connections,
2844        &proto::UpdateChannelBufferCollaborators {
2845            channel_id: channel_id.to_proto(),
2846            collaborators: left_buffer.collaborators,
2847        },
2848        &session.peer,
2849    );
2850
2851    Ok(())
2852}
2853
2854fn channel_buffer_updated<T: EnvelopedMessage>(
2855    sender_id: ConnectionId,
2856    collaborators: impl IntoIterator<Item = ConnectionId>,
2857    message: &T,
2858    peer: &Peer,
2859) {
2860    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2861        peer.send(peer_id.into(), message.clone())
2862    });
2863}
2864
2865fn send_notifications(
2866    connection_pool: &ConnectionPool,
2867    peer: &Peer,
2868    notifications: db::NotificationBatch,
2869) {
2870    for (user_id, notification) in notifications {
2871        for connection_id in connection_pool.user_connection_ids(user_id) {
2872            if let Err(error) = peer.send(
2873                connection_id,
2874                proto::AddNotification {
2875                    notification: Some(notification.clone()),
2876                },
2877            ) {
2878                tracing::error!(
2879                    "failed to send notification to {:?} {}",
2880                    connection_id,
2881                    error
2882                );
2883            }
2884        }
2885    }
2886}
2887
2888async fn send_channel_message(
2889    request: proto::SendChannelMessage,
2890    response: Response<proto::SendChannelMessage>,
2891    session: Session,
2892) -> Result<()> {
2893    // Validate the message body.
2894    let body = request.body.trim().to_string();
2895    if body.len() > MAX_MESSAGE_LEN {
2896        return Err(anyhow!("message is too long"))?;
2897    }
2898    if body.is_empty() {
2899        return Err(anyhow!("message can't be blank"))?;
2900    }
2901
2902    // TODO: adjust mentions if body is trimmed
2903
2904    let timestamp = OffsetDateTime::now_utc();
2905    let nonce = request
2906        .nonce
2907        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2908
2909    let channel_id = ChannelId::from_proto(request.channel_id);
2910    let CreatedChannelMessage {
2911        message_id,
2912        participant_connection_ids,
2913        channel_members,
2914        notifications,
2915    } = session
2916        .db()
2917        .await
2918        .create_channel_message(
2919            channel_id,
2920            session.user_id,
2921            &body,
2922            &request.mentions,
2923            timestamp,
2924            nonce.clone().into(),
2925        )
2926        .await?;
2927    let message = proto::ChannelMessage {
2928        sender_id: session.user_id.to_proto(),
2929        id: message_id.to_proto(),
2930        body,
2931        mentions: request.mentions,
2932        timestamp: timestamp.unix_timestamp() as u64,
2933        nonce: Some(nonce),
2934    };
2935    broadcast(
2936        Some(session.connection_id),
2937        participant_connection_ids,
2938        |connection| {
2939            session.peer.send(
2940                connection,
2941                proto::ChannelMessageSent {
2942                    channel_id: channel_id.to_proto(),
2943                    message: Some(message.clone()),
2944                },
2945            )
2946        },
2947    );
2948    response.send(proto::SendChannelMessageResponse {
2949        message: Some(message),
2950    })?;
2951
2952    let pool = &*session.connection_pool().await;
2953    broadcast(
2954        None,
2955        channel_members
2956            .iter()
2957            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2958        |peer_id| {
2959            session.peer.send(
2960                peer_id.into(),
2961                proto::UpdateChannels {
2962                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
2963                        channel_id: channel_id.to_proto(),
2964                        message_id: message_id.to_proto(),
2965                    }],
2966                    ..Default::default()
2967                },
2968            )
2969        },
2970    );
2971    send_notifications(pool, &session.peer, notifications);
2972
2973    Ok(())
2974}
2975
2976async fn remove_channel_message(
2977    request: proto::RemoveChannelMessage,
2978    response: Response<proto::RemoveChannelMessage>,
2979    session: Session,
2980) -> Result<()> {
2981    let channel_id = ChannelId::from_proto(request.channel_id);
2982    let message_id = MessageId::from_proto(request.message_id);
2983    let connection_ids = session
2984        .db()
2985        .await
2986        .remove_channel_message(channel_id, message_id, session.user_id)
2987        .await?;
2988    broadcast(Some(session.connection_id), connection_ids, |connection| {
2989        session.peer.send(connection, request.clone())
2990    });
2991    response.send(proto::Ack {})?;
2992    Ok(())
2993}
2994
2995async fn acknowledge_channel_message(
2996    request: proto::AckChannelMessage,
2997    session: Session,
2998) -> Result<()> {
2999    let channel_id = ChannelId::from_proto(request.channel_id);
3000    let message_id = MessageId::from_proto(request.message_id);
3001    let notifications = session
3002        .db()
3003        .await
3004        .observe_channel_message(channel_id, session.user_id, message_id)
3005        .await?;
3006    send_notifications(
3007        &*session.connection_pool().await,
3008        &session.peer,
3009        notifications,
3010    );
3011    Ok(())
3012}
3013
3014async fn acknowledge_buffer_version(
3015    request: proto::AckBufferOperation,
3016    session: Session,
3017) -> Result<()> {
3018    let buffer_id = BufferId::from_proto(request.buffer_id);
3019    session
3020        .db()
3021        .await
3022        .observe_buffer_version(
3023            buffer_id,
3024            session.user_id,
3025            request.epoch as i32,
3026            &request.version,
3027        )
3028        .await?;
3029    Ok(())
3030}
3031
3032async fn join_channel_chat(
3033    request: proto::JoinChannelChat,
3034    response: Response<proto::JoinChannelChat>,
3035    session: Session,
3036) -> Result<()> {
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038
3039    let db = session.db().await;
3040    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3041        .await?;
3042    let messages = db
3043        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3044        .await?;
3045    response.send(proto::JoinChannelChatResponse {
3046        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3047        messages,
3048    })?;
3049    Ok(())
3050}
3051
3052async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3053    let channel_id = ChannelId::from_proto(request.channel_id);
3054    session
3055        .db()
3056        .await
3057        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3058        .await?;
3059    Ok(())
3060}
3061
3062async fn get_channel_messages(
3063    request: proto::GetChannelMessages,
3064    response: Response<proto::GetChannelMessages>,
3065    session: Session,
3066) -> Result<()> {
3067    let channel_id = ChannelId::from_proto(request.channel_id);
3068    let messages = session
3069        .db()
3070        .await
3071        .get_channel_messages(
3072            channel_id,
3073            session.user_id,
3074            MESSAGE_COUNT_PER_PAGE,
3075            Some(MessageId::from_proto(request.before_message_id)),
3076        )
3077        .await?;
3078    response.send(proto::GetChannelMessagesResponse {
3079        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3080        messages,
3081    })?;
3082    Ok(())
3083}
3084
3085async fn get_channel_messages_by_id(
3086    request: proto::GetChannelMessagesById,
3087    response: Response<proto::GetChannelMessagesById>,
3088    session: Session,
3089) -> Result<()> {
3090    let message_ids = request
3091        .message_ids
3092        .iter()
3093        .map(|id| MessageId::from_proto(*id))
3094        .collect::<Vec<_>>();
3095    let messages = session
3096        .db()
3097        .await
3098        .get_channel_messages_by_id(session.user_id, &message_ids)
3099        .await?;
3100    response.send(proto::GetChannelMessagesResponse {
3101        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3102        messages,
3103    })?;
3104    Ok(())
3105}
3106
3107async fn get_notifications(
3108    request: proto::GetNotifications,
3109    response: Response<proto::GetNotifications>,
3110    session: Session,
3111) -> Result<()> {
3112    let notifications = session
3113        .db()
3114        .await
3115        .get_notifications(
3116            session.user_id,
3117            NOTIFICATION_COUNT_PER_PAGE,
3118            request
3119                .before_id
3120                .map(|id| db::NotificationId::from_proto(id)),
3121        )
3122        .await?;
3123    response.send(proto::GetNotificationsResponse {
3124        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3125        notifications,
3126    })?;
3127    Ok(())
3128}
3129
3130async fn mark_notification_as_read(
3131    request: proto::MarkNotificationRead,
3132    response: Response<proto::MarkNotificationRead>,
3133    session: Session,
3134) -> Result<()> {
3135    let database = &session.db().await;
3136    let notifications = database
3137        .mark_notification_as_read_by_id(
3138            session.user_id,
3139            NotificationId::from_proto(request.notification_id),
3140        )
3141        .await?;
3142    send_notifications(
3143        &*session.connection_pool().await,
3144        &session.peer,
3145        notifications,
3146    );
3147    response.send(proto::Ack {})?;
3148    Ok(())
3149}
3150
3151async fn get_private_user_info(
3152    _request: proto::GetPrivateUserInfo,
3153    response: Response<proto::GetPrivateUserInfo>,
3154    session: Session,
3155) -> Result<()> {
3156    let db = session.db().await;
3157
3158    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3159    let user = db
3160        .get_user_by_id(session.user_id)
3161        .await?
3162        .ok_or_else(|| anyhow!("user not found"))?;
3163    let flags = db.get_user_flags(session.user_id).await?;
3164
3165    response.send(proto::GetPrivateUserInfoResponse {
3166        metrics_id,
3167        staff: user.admin,
3168        flags,
3169    })?;
3170    Ok(())
3171}
3172
3173fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3174    match message {
3175        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3176        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3177        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3178        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3179        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3180            code: frame.code.into(),
3181            reason: frame.reason,
3182        })),
3183    }
3184}
3185
3186fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3187    match message {
3188        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3189        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3190        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3191        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3192        AxumMessage::Close(frame) => {
3193            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3194                code: frame.code.into(),
3195                reason: frame.reason,
3196            }))
3197        }
3198    }
3199}
3200
3201fn notify_membership_updated(
3202    connection_pool: &ConnectionPool,
3203    result: MembershipUpdated,
3204    user_id: UserId,
3205    peer: &Peer,
3206) {
3207    let mut update = build_channels_update(result.new_channels, vec![]);
3208    update.delete_channels = result
3209        .removed_channels
3210        .into_iter()
3211        .map(|id| id.to_proto())
3212        .collect();
3213    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3214
3215    for connection_id in connection_pool.user_connection_ids(user_id) {
3216        peer.send(connection_id, update.clone()).trace_err();
3217    }
3218}
3219
3220fn build_channels_update(
3221    channels: ChannelsForUser,
3222    channel_invites: Vec<db::Channel>,
3223) -> proto::UpdateChannels {
3224    let mut update = proto::UpdateChannels::default();
3225
3226    for channel in channels.channels {
3227        update.channels.push(channel.to_proto());
3228    }
3229
3230    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3231    update.unseen_channel_messages = channels.channel_messages;
3232
3233    for (channel_id, participants) in channels.channel_participants {
3234        update
3235            .channel_participants
3236            .push(proto::ChannelParticipants {
3237                channel_id: channel_id.to_proto(),
3238                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3239            });
3240    }
3241
3242    for channel in channel_invites {
3243        update.channel_invitations.push(channel.to_proto());
3244    }
3245
3246    update
3247}
3248
3249fn build_initial_contacts_update(
3250    contacts: Vec<db::Contact>,
3251    pool: &ConnectionPool,
3252) -> proto::UpdateContacts {
3253    let mut update = proto::UpdateContacts::default();
3254
3255    for contact in contacts {
3256        match contact {
3257            db::Contact::Accepted { user_id, busy } => {
3258                update.contacts.push(contact_for_user(user_id, busy, &pool));
3259            }
3260            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3261            db::Contact::Incoming { user_id } => {
3262                update
3263                    .incoming_requests
3264                    .push(proto::IncomingContactRequest {
3265                        requester_id: user_id.to_proto(),
3266                    })
3267            }
3268        }
3269    }
3270
3271    update
3272}
3273
3274fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3275    proto::Contact {
3276        user_id: user_id.to_proto(),
3277        online: pool.is_user_online(user_id),
3278        busy,
3279    }
3280}
3281
3282fn room_updated(room: &proto::Room, peer: &Peer) {
3283    broadcast(
3284        None,
3285        room.participants
3286            .iter()
3287            .filter_map(|participant| Some(participant.peer_id?.into())),
3288        |peer_id| {
3289            peer.send(
3290                peer_id.into(),
3291                proto::RoomUpdated {
3292                    room: Some(room.clone()),
3293                },
3294            )
3295        },
3296    );
3297}
3298
3299fn channel_updated(
3300    channel_id: ChannelId,
3301    room: &proto::Room,
3302    channel_members: &[UserId],
3303    peer: &Peer,
3304    pool: &ConnectionPool,
3305) {
3306    let participants = room
3307        .participants
3308        .iter()
3309        .map(|p| p.user_id)
3310        .collect::<Vec<_>>();
3311
3312    broadcast(
3313        None,
3314        channel_members
3315            .iter()
3316            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3317        |peer_id| {
3318            peer.send(
3319                peer_id.into(),
3320                proto::UpdateChannels {
3321                    channel_participants: vec![proto::ChannelParticipants {
3322                        channel_id: channel_id.to_proto(),
3323                        participant_user_ids: participants.clone(),
3324                    }],
3325                    ..Default::default()
3326                },
3327            )
3328        },
3329    );
3330}
3331
3332async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3333    let db = session.db().await;
3334
3335    let contacts = db.get_contacts(user_id).await?;
3336    let busy = db.is_user_busy(user_id).await?;
3337
3338    let pool = session.connection_pool().await;
3339    let updated_contact = contact_for_user(user_id, busy, &pool);
3340    for contact in contacts {
3341        if let db::Contact::Accepted {
3342            user_id: contact_user_id,
3343            ..
3344        } = contact
3345        {
3346            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3347                session
3348                    .peer
3349                    .send(
3350                        contact_conn_id,
3351                        proto::UpdateContacts {
3352                            contacts: vec![updated_contact.clone()],
3353                            remove_contacts: Default::default(),
3354                            incoming_requests: Default::default(),
3355                            remove_incoming_requests: Default::default(),
3356                            outgoing_requests: Default::default(),
3357                            remove_outgoing_requests: Default::default(),
3358                        },
3359                    )
3360                    .trace_err();
3361            }
3362        }
3363    }
3364    Ok(())
3365}
3366
/// Removes this session's connection from the room it is in, if any, and
/// propagates the consequences.
///
/// When the database reports that a room was left:
/// - each entry in `left_projects` is handled by `project_left`;
/// - the participants of the returned room are sent the new room state;
/// - if the room belongs to a channel, that channel's members are sent the
///   updated participant list;
/// - users whose pending calls were canceled are sent `CallCanceled`;
/// - this user and every canceled callee have their contact entries refreshed;
/// - if a LiveKit client is configured, the user is removed from the LiveKit
///   room, and the room is also deleted when the database marked it deleted.
///
/// Returns `Ok(())` immediately when the connection was not in a room.
/// LiveKit and per-connection send failures are logged, not returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The early `return` in its `else`
    // branch guarantees they are initialized for the rest of the function.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the handle returned by `session.db().await` is a temporary
    // of this `if let` scrutinee, so it stays alive at least through the body.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out, so the `room` broadcast
        // below carries an emptied `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is dropped before
    // `update_user_contacts`, which calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3443
3444async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3445    let left_channel_buffers = session
3446        .db()
3447        .await
3448        .leave_channel_buffers(session.connection_id)
3449        .await?;
3450
3451    for left_buffer in left_channel_buffers {
3452        channel_buffer_updated(
3453            session.connection_id,
3454            left_buffer.connections,
3455            &proto::UpdateChannelBufferCollaborators {
3456                channel_id: left_buffer.channel_id.to_proto(),
3457                collaborators: left_buffer.collaborators,
3458            },
3459            &session.peer,
3460        );
3461    }
3462
3463    Ok(())
3464}
3465
3466fn project_left(project: &db::LeftProject, session: &Session) {
3467    for connection_id in &project.connection_ids {
3468        if project.host_user_id == session.user_id {
3469            session
3470                .peer
3471                .send(
3472                    *connection_id,
3473                    proto::UnshareProject {
3474                        project_id: project.id.to_proto(),
3475                    },
3476                )
3477                .trace_err();
3478        } else {
3479            session
3480                .peer
3481                .send(
3482                    *connection_id,
3483                    proto::RemoveProjectCollaborator {
3484                        project_id: project.id.to_proto(),
3485                        peer_id: Some(session.connection_id.into()),
3486                    },
3487                )
3488                .trace_err();
3489        }
3490    }
3491}
3492
3493pub trait ResultExt {
3494    type Ok;
3495
3496    fn trace_err(self) -> Option<Self::Ok>;
3497}
3498
3499impl<T, E> ResultExt for Result<T, E>
3500where
3501    E: std::fmt::Debug,
3502{
3503    type Ok = T;
3504
3505    fn trace_err(self) -> Option<T> {
3506        match self {
3507            Ok(value) => Some(value),
3508            Err(error) => {
3509                tracing::error!("{:?}", error);
3510                None
3511            }
3512        }
3513    }
3514}