rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{field, info_span, instrument, Instrument};
  69
  70pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  71pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  72
  73const MESSAGE_COUNT_PER_PAGE: usize = 100;
  74const MAX_MESSAGE_LEN: usize = 1024;
  75const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  76
  77lazy_static! {
  78    static ref METRIC_CONNECTIONS: IntGauge =
  79        register_int_gauge!("connections", "number of connections").unwrap();
  80    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  81        "shared_projects",
  82        "number of open projects with one or more guests"
  83    )
  84    .unwrap();
  85}
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone)]
 105struct Session {
 106    zed_environment: Arc<str>,
 107    user_id: UserId,
 108    connection_id: ConnectionId,
 109    db: Arc<tokio::sync::Mutex<DbHandle>>,
 110    peer: Arc<Peer>,
 111    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 112    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 113    _executor: Executor,
 114}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
 156pub struct Server {
 157    id: parking_lot::Mutex<ServerId>,
 158    peer: Arc<Peer>,
 159    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 160    app_state: Arc<AppState>,
 161    executor: Executor,
 162    handlers: HashMap<TypeId, MessageHandler>,
 163    teardown: watch::Sender<()>,
 164}
 165
 166pub(crate) struct ConnectionPoolGuard<'a> {
 167    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 168    _not_send: PhantomData<Rc<()>>,
 169}
 170
 171#[derive(Serialize)]
 172pub struct ServerSnapshot<'a> {
 173    peer: &'a Peer,
 174    #[serde(serialize_with = "serialize_deref")]
 175    connection_pool: ConnectionPoolGuard<'a>,
 176}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
 188    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 189        let mut server = Self {
 190            id: parking_lot::Mutex::new(id),
 191            peer: Peer::new(id.0 as u32),
 192            app_state,
 193            executor,
 194            connection_pool: Default::default(),
 195            handlers: Default::default(),
 196            teardown: watch::channel(()).0,
 197        };
 198
 199        server
 200            .add_request_handler(ping)
 201            .add_request_handler(create_room)
 202            .add_request_handler(join_room)
 203            .add_request_handler(rejoin_room)
 204            .add_request_handler(leave_room)
 205            .add_request_handler(set_room_participant_role)
 206            .add_request_handler(call)
 207            .add_request_handler(cancel_call)
 208            .add_message_handler(decline_call)
 209            .add_request_handler(update_participant_location)
 210            .add_request_handler(share_project)
 211            .add_message_handler(unshare_project)
 212            .add_request_handler(join_project)
 213            .add_message_handler(leave_project)
 214            .add_request_handler(update_project)
 215            .add_request_handler(update_worktree)
 216            .add_message_handler(start_language_server)
 217            .add_message_handler(update_language_server)
 218            .add_message_handler(update_diagnostic_summary)
 219            .add_message_handler(update_worktree_settings)
 220            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 221            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 222            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 223            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 224            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 225            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 226            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 227            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 228            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 229            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 230            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 231            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 232            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 233            .add_request_handler(
 234                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 235            )
 236            .add_request_handler(
 237                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 238            )
 239            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 240            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 241            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 242            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 243            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 244            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 245            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 246            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 247            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 248            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 249            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 250            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 251            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 252            .add_message_handler(create_buffer_for_peer)
 253            .add_request_handler(update_buffer)
 254            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 255            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 256            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 257            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 258            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 259            .add_request_handler(get_users)
 260            .add_request_handler(fuzzy_search_users)
 261            .add_request_handler(request_contact)
 262            .add_request_handler(remove_contact)
 263            .add_request_handler(respond_to_contact_request)
 264            .add_request_handler(create_channel)
 265            .add_request_handler(delete_channel)
 266            .add_request_handler(invite_channel_member)
 267            .add_request_handler(remove_channel_member)
 268            .add_request_handler(set_channel_member_role)
 269            .add_request_handler(set_channel_visibility)
 270            .add_request_handler(rename_channel)
 271            .add_request_handler(join_channel_buffer)
 272            .add_request_handler(leave_channel_buffer)
 273            .add_message_handler(update_channel_buffer)
 274            .add_request_handler(rejoin_channel_buffers)
 275            .add_request_handler(get_channel_members)
 276            .add_request_handler(respond_to_channel_invite)
 277            .add_request_handler(join_channel)
 278            .add_request_handler(join_channel_chat)
 279            .add_message_handler(leave_channel_chat)
 280            .add_request_handler(send_channel_message)
 281            .add_request_handler(remove_channel_message)
 282            .add_request_handler(get_channel_messages)
 283            .add_request_handler(get_channel_messages_by_id)
 284            .add_request_handler(get_notifications)
 285            .add_request_handler(mark_notification_as_read)
 286            .add_request_handler(move_channel)
 287            .add_request_handler(follow)
 288            .add_message_handler(unfollow)
 289            .add_message_handler(update_followers)
 290            .add_request_handler(get_private_user_info)
 291            .add_message_handler(acknowledge_channel_message)
 292            .add_message_handler(acknowledge_buffer_version);
 293
 294        Arc::new(server)
 295    }
 296
 297    pub async fn start(&self) -> Result<()> {
 298        let server_id = *self.id.lock();
 299        let app_state = self.app_state.clone();
 300        let peer = self.peer.clone();
 301        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 302        let pool = self.connection_pool.clone();
 303        let live_kit_client = self.app_state.live_kit_client.clone();
 304
 305        let span = info_span!("start server");
 306        self.executor.spawn_detached(
 307            async move {
 308                tracing::info!("waiting for cleanup timeout");
 309                timeout.await;
 310                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 311                if let Some((room_ids, channel_ids)) = app_state
 312                    .db
 313                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 314                    .await
 315                    .trace_err()
 316                {
 317                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 318                    tracing::info!(
 319                        stale_channel_buffer_count = channel_ids.len(),
 320                        "retrieved stale channel buffers"
 321                    );
 322
 323                    for channel_id in channel_ids {
 324                        if let Some(refreshed_channel_buffer) = app_state
 325                            .db
 326                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 327                            .await
 328                            .trace_err()
 329                        {
 330                            for connection_id in refreshed_channel_buffer.connection_ids {
 331                                peer.send(
 332                                    connection_id,
 333                                    proto::UpdateChannelBufferCollaborators {
 334                                        channel_id: channel_id.to_proto(),
 335                                        collaborators: refreshed_channel_buffer
 336                                            .collaborators
 337                                            .clone(),
 338                                    },
 339                                )
 340                                .trace_err();
 341                            }
 342                        }
 343                    }
 344
 345                    for room_id in room_ids {
 346                        let mut contacts_to_update = HashSet::default();
 347                        let mut canceled_calls_to_user_ids = Vec::new();
 348                        let mut live_kit_room = String::new();
 349                        let mut delete_live_kit_room = false;
 350
 351                        if let Some(mut refreshed_room) = app_state
 352                            .db
 353                            .clear_stale_room_participants(room_id, server_id)
 354                            .await
 355                            .trace_err()
 356                        {
 357                            tracing::info!(
 358                                room_id = room_id.0,
 359                                new_participant_count = refreshed_room.room.participants.len(),
 360                                "refreshed room"
 361                            );
 362                            room_updated(&refreshed_room.room, &peer);
 363                            if let Some(channel_id) = refreshed_room.channel_id {
 364                                channel_updated(
 365                                    channel_id,
 366                                    &refreshed_room.room,
 367                                    &refreshed_room.channel_members,
 368                                    &peer,
 369                                    &*pool.lock(),
 370                                );
 371                            }
 372                            contacts_to_update
 373                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 374                            contacts_to_update
 375                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 376                            canceled_calls_to_user_ids =
 377                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 378                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 379                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 380                        }
 381
 382                        {
 383                            let pool = pool.lock();
 384                            for canceled_user_id in canceled_calls_to_user_ids {
 385                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 386                                    peer.send(
 387                                        connection_id,
 388                                        proto::CallCanceled {
 389                                            room_id: room_id.to_proto(),
 390                                        },
 391                                    )
 392                                    .trace_err();
 393                                }
 394                            }
 395                        }
 396
 397                        for user_id in contacts_to_update {
 398                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 399                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 400                            if let Some((busy, contacts)) = busy.zip(contacts) {
 401                                let pool = pool.lock();
 402                                let updated_contact = contact_for_user(user_id, busy, &pool);
 403                                for contact in contacts {
 404                                    if let db::Contact::Accepted {
 405                                        user_id: contact_user_id,
 406                                        ..
 407                                    } = contact
 408                                    {
 409                                        for contact_conn_id in
 410                                            pool.user_connection_ids(contact_user_id)
 411                                        {
 412                                            peer.send(
 413                                                contact_conn_id,
 414                                                proto::UpdateContacts {
 415                                                    contacts: vec![updated_contact.clone()],
 416                                                    remove_contacts: Default::default(),
 417                                                    incoming_requests: Default::default(),
 418                                                    remove_incoming_requests: Default::default(),
 419                                                    outgoing_requests: Default::default(),
 420                                                    remove_outgoing_requests: Default::default(),
 421                                                },
 422                                            )
 423                                            .trace_err();
 424                                        }
 425                                    }
 426                                }
 427                            }
 428                        }
 429
 430                        if let Some(live_kit) = live_kit_client.as_ref() {
 431                            if delete_live_kit_room {
 432                                live_kit.delete_room(live_kit_room).await.trace_err();
 433                            }
 434                        }
 435                    }
 436                }
 437
 438                app_state
 439                    .db
 440                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 441                    .await
 442                    .trace_err();
 443            }
 444            .instrument(span),
 445        );
 446        Ok(())
 447    }
 448
 449    pub fn teardown(&self) {
 450        self.peer.teardown();
 451        self.connection_pool.lock().reset();
 452        let _ = self.teardown.send(());
 453    }
 454
 455    #[cfg(test)]
 456    pub fn reset(&self, id: ServerId) {
 457        self.teardown();
 458        *self.id.lock() = id;
 459        self.peer.reset(id.0 as u32);
 460    }
 461
 462    #[cfg(test)]
 463    pub fn id(&self) -> ServerId {
 464        *self.id.lock()
 465    }
 466
 467    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 468    where
 469        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 470        Fut: 'static + Send + Future<Output = Result<()>>,
 471        M: EnvelopedMessage,
 472    {
 473        let prev_handler = self.handlers.insert(
 474            TypeId::of::<M>(),
 475            Box::new(move |envelope, session| {
 476                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 477                let span = info_span!(
 478                    "handle message",
 479                    payload_type = envelope.payload_type_name()
 480                );
 481                span.in_scope(|| {
 482                    tracing::info!(
 483                        payload_type = envelope.payload_type_name(),
 484                        "message received"
 485                    );
 486                });
 487                let start_time = Instant::now();
 488                let future = (handler)(*envelope, session);
 489                async move {
 490                    let result = future.await;
 491                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 492                    match result {
 493                        Err(error) => {
 494                            tracing::error!(%error, ?duration_ms, "error handling message")
 495                        }
 496                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 497                    }
 498                }
 499                .instrument(span)
 500                .boxed()
 501            }),
 502        );
 503        if prev_handler.is_some() {
 504            panic!("registered a handler for the same message twice");
 505        }
 506        self
 507    }
 508
 509    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 512        Fut: 'static + Send + Future<Output = Result<()>>,
 513        M: EnvelopedMessage,
 514    {
 515        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 516        self
 517    }
 518
 519    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 520    where
 521        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 522        Fut: Send + Future<Output = Result<()>>,
 523        M: RequestMessage,
 524    {
 525        let handler = Arc::new(handler);
 526        self.add_handler(move |envelope, session| {
 527            let receipt = envelope.receipt();
 528            let handler = handler.clone();
 529            async move {
 530                let peer = session.peer.clone();
 531                let responded = Arc::new(AtomicBool::default());
 532                let response = Response {
 533                    peer: peer.clone(),
 534                    responded: responded.clone(),
 535                    receipt,
 536                };
 537                match (handler)(envelope.payload, response, session).await {
 538                    Ok(()) => {
 539                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 540                            Ok(())
 541                        } else {
 542                            Err(anyhow!("handler did not send a response"))?
 543                        }
 544                    }
 545                    Err(error) => {
 546                        peer.respond_with_error(
 547                            receipt,
 548                            proto::Error {
 549                                message: error.to_string(),
 550                            },
 551                        )?;
 552                        Err(error)
 553                    }
 554                }
 555            }
 556        })
 557    }
 558
 559    pub fn handle_connection(
 560        self: &Arc<Self>,
 561        connection: Connection,
 562        address: String,
 563        user: User,
 564        impersonator: Option<User>,
 565        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 566        executor: Executor,
 567    ) -> impl Future<Output = Result<()>> {
 568        let this = self.clone();
 569        let user_id = user.id;
 570        let login = user.github_login;
 571        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
 572        if let Some(impersonator) = impersonator {
 573            span.record("impersonator", &impersonator.github_login);
 574        }
 575        let mut teardown = self.teardown.subscribe();
 576        async move {
 577            let (connection_id, handle_io, mut incoming_rx) = this
 578                .peer
 579                .add_connection(connection, {
 580                    let executor = executor.clone();
 581                    move |duration| executor.sleep(duration)
 582                });
 583
 584            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 585            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 586            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 587
 588            if let Some(send_connection_id) = send_connection_id.take() {
 589                let _ = send_connection_id.send(connection_id);
 590            }
 591
 592            if !user.connected_once {
 593                this.peer.send(connection_id, proto::ShowContacts {})?;
 594                this.app_state.db.set_user_connected_once(user_id, true).await?;
 595            }
 596
 597            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 598                this.app_state.db.get_contacts(user_id),
 599                this.app_state.db.get_channels_for_user(user_id),
 600                this.app_state.db.get_channel_invites_for_user(user_id),
 601            ).await?;
 602
 603            {
 604                let mut pool = this.connection_pool.lock();
 605                pool.add_connection(connection_id, user_id, user.admin);
 606                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 607                this.peer.send(connection_id, build_channels_update(
 608                    channels_for_user,
 609                    channel_invites
 610                ))?;
 611            }
 612
 613            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 614                this.peer.send(connection_id, incoming_call)?;
 615            }
 616
 617            let session = Session {
 618                user_id,
 619                connection_id,
 620                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 621                zed_environment: this.app_state.config.zed_environment.clone(),
 622                peer: this.peer.clone(),
 623                connection_pool: this.connection_pool.clone(),
 624                live_kit_client: this.app_state.live_kit_client.clone(),
 625                _executor: executor.clone()
 626            };
 627            update_user_contacts(user_id, &session).await?;
 628
 629            let handle_io = handle_io.fuse();
 630            futures::pin_mut!(handle_io);
 631
 632            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 633            // This prevents deadlocks when e.g., client A performs a request to client B and
 634            // client B performs a request to client A. If both clients stop processing further
 635            // messages until their respective request completes, they won't have a chance to
 636            // respond to the other client's request and cause a deadlock.
 637            //
 638            // This arrangement ensures we will attempt to process earlier messages first, but fall
 639            // back to processing messages arrived later in the spirit of making progress.
 640            let mut foreground_message_handlers = FuturesUnordered::new();
 641            let concurrent_handlers = Arc::new(Semaphore::new(256));
 642            loop {
 643                let next_message = async {
 644                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 645                    let message = incoming_rx.next().await;
 646                    (permit, message)
 647                }.fuse();
 648                futures::pin_mut!(next_message);
 649                futures::select_biased! {
 650                    _ = teardown.changed().fuse() => return Ok(()),
 651                    result = handle_io => {
 652                        if let Err(error) = result {
 653                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 654                        }
 655                        break;
 656                    }
 657                    _ = foreground_message_handlers.next() => {}
 658                    next_message = next_message => {
 659                        let (permit, message) = next_message;
 660                        if let Some(message) = message {
 661                            let type_name = message.payload_type_name();
 662                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 663                            let span_enter = span.enter();
 664                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 665                                let is_background = message.is_background();
 666                                let handle_message = (handler)(message, session.clone());
 667                                drop(span_enter);
 668
 669                                let handle_message = async move {
 670                                    handle_message.await;
 671                                    drop(permit);
 672                                }.instrument(span);
 673                                if is_background {
 674                                    executor.spawn_detached(handle_message);
 675                                } else {
 676                                    foreground_message_handlers.push(handle_message);
 677                                }
 678                            } else {
 679                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 680                            }
 681                        } else {
 682                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 683                            break;
 684                        }
 685                    }
 686                }
 687            }
 688
 689            drop(foreground_message_handlers);
 690            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 691            if let Err(error) = connection_lost(session, teardown, executor).await {
 692                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 693            }
 694
 695            Ok(())
 696        }.instrument(span)
 697    }
 698
 699    pub async fn invite_code_redeemed(
 700        self: &Arc<Self>,
 701        inviter_id: UserId,
 702        invitee_id: UserId,
 703    ) -> Result<()> {
 704        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 705            if let Some(code) = &user.invite_code {
 706                let pool = self.connection_pool.lock();
 707                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 708                for connection_id in pool.user_connection_ids(inviter_id) {
 709                    self.peer.send(
 710                        connection_id,
 711                        proto::UpdateContacts {
 712                            contacts: vec![invitee_contact.clone()],
 713                            ..Default::default()
 714                        },
 715                    )?;
 716                    self.peer.send(
 717                        connection_id,
 718                        proto::UpdateInviteInfo {
 719                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 720                            count: user.invite_count as u32,
 721                        },
 722                    )?;
 723                }
 724            }
 725        }
 726        Ok(())
 727    }
 728
 729    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 730        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 731            if let Some(invite_code) = &user.invite_code {
 732                let pool = self.connection_pool.lock();
 733                for connection_id in pool.user_connection_ids(user_id) {
 734                    self.peer.send(
 735                        connection_id,
 736                        proto::UpdateInviteInfo {
 737                            url: format!(
 738                                "{}{}",
 739                                self.app_state.config.invite_link_prefix, invite_code
 740                            ),
 741                            count: user.invite_count as u32,
 742                        },
 743                    )?;
 744                }
 745            }
 746        }
 747        Ok(())
 748    }
 749
 750    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 751        ServerSnapshot {
 752            connection_pool: ConnectionPoolGuard {
 753                guard: self.connection_pool.lock(),
 754                _not_send: PhantomData,
 755            },
 756            peer: &self.peer,
 757        }
 758    }
 759}
 760
 761impl<'a> Deref for ConnectionPoolGuard<'a> {
 762    type Target = ConnectionPool;
 763
 764    fn deref(&self) -> &Self::Target {
 765        &*self.guard
 766    }
 767}
 768
 769impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 770    fn deref_mut(&mut self) -> &mut Self::Target {
 771        &mut *self.guard
 772    }
 773}
 774
 775impl<'a> Drop for ConnectionPoolGuard<'a> {
 776    fn drop(&mut self) {
 777        #[cfg(test)]
 778        self.check_invariants();
 779    }
 780}
 781
 782fn broadcast<F>(
 783    sender_id: Option<ConnectionId>,
 784    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 785    mut f: F,
 786) where
 787    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 788{
 789    for receiver_id in receiver_ids {
 790        if Some(receiver_id) != sender_id {
 791            if let Err(error) = f(receiver_id) {
 792                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 793            }
 794        }
 795    }
 796}
 797
// Name of the HTTP header in which clients send their RPC protocol version;
// returned by `ProtocolVersion::name`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 801
/// Typed `x-zed-protocol-version` header carrying the client's RPC protocol
/// version as a `u32`.
pub struct ProtocolVersion(u32);
 803
 804impl Header for ProtocolVersion {
 805    fn name() -> &'static HeaderName {
 806        &ZED_PROTOCOL_VERSION
 807    }
 808
 809    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 810    where
 811        Self: Sized,
 812        I: Iterator<Item = &'i axum::http::HeaderValue>,
 813    {
 814        let version = values
 815            .next()
 816            .ok_or_else(axum::headers::Error::invalid)?
 817            .to_str()
 818            .map_err(|_| axum::headers::Error::invalid())?
 819            .parse()
 820            .map_err(|_| axum::headers::Error::invalid())?;
 821        Ok(Self(version))
 822    }
 823
 824    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 825        values.extend([self.0.to_string().parse().unwrap()]);
 826    }
 827}
 828
 829pub fn routes(server: Arc<Server>) -> Router<Body> {
 830    Router::new()
 831        .route("/rpc", get(handle_websocket_request))
 832        .layer(
 833            ServiceBuilder::new()
 834                .layer(Extension(server.app_state.clone()))
 835                .layer(middleware::from_fn(auth::validate_header)),
 836        )
 837        .route("/metrics", get(handle_metrics))
 838        .layer(Extension(server))
 839}
 840
 841pub async fn handle_websocket_request(
 842    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 843    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 844    Extension(server): Extension<Arc<Server>>,
 845    Extension(user): Extension<User>,
 846    Extension(impersonator): Extension<Impersonator>,
 847    ws: WebSocketUpgrade,
 848) -> axum::response::Response {
 849    if protocol_version != rpc::PROTOCOL_VERSION {
 850        return (
 851            StatusCode::UPGRADE_REQUIRED,
 852            "client must be upgraded".to_string(),
 853        )
 854            .into_response();
 855    }
 856    let socket_address = socket_address.to_string();
 857    ws.on_upgrade(move |socket| {
 858        use util::ResultExt;
 859        let socket = socket
 860            .map_ok(to_tungstenite_message)
 861            .err_into()
 862            .with(|message| async move { Ok(to_axum_message(message)) });
 863        let connection = Connection::new(Box::pin(socket));
 864        async move {
 865            server
 866                .handle_connection(
 867                    connection,
 868                    socket_address,
 869                    user,
 870                    impersonator.0,
 871                    None,
 872                    Executor::Production,
 873                )
 874                .await
 875                .log_err();
 876        }
 877    })
 878}
 879
 880pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 881    let connections = server
 882        .connection_pool
 883        .lock()
 884        .connections()
 885        .filter(|connection| !connection.admin)
 886        .count();
 887
 888    METRIC_CONNECTIONS.set(connections as _);
 889
 890    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 891    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 892
 893    let encoder = prometheus::TextEncoder::new();
 894    let metric_families = prometheus::gather();
 895    let encoded_metrics = encoder
 896        .encode_to_string(&metric_families)
 897        .map_err(|err| anyhow!("{}", err))?;
 898    Ok(encoded_metrics)
 899}
 900
/// Cleans up after a client connection is lost.
///
/// The cleanup runs in two phases.
///
/// 1. Immediately:
///    - disconnects the peer;
///    - removes the connection from the connection pool (a failure here is
///      returned as an error);
///    - records the loss in the database (a failure here is only logged).
/// 2. After `RECONNECT_TIMEOUT`, unless `teardown` changes first (in which case
///    nothing more is done):
///    - leaves the session's room and channel buffers (errors are logged);
///    - if `is_user_online` now reports the user offline, calls
///      `decline_call(None, ..)` and broadcasts any resulting room update;
///    - refreshes the user's contacts (a failure here is returned as an error).
///
/// NOTE(review): `teardown` presumably signals server shutdown; confirm
/// against the caller.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` checks branches in the order written. If the timeout and
    // `teardown` are both ready, the cleanup branch runs.
    futures::select_biased! {
        // Presumably the grace period during which the client may reconnect and
        // keep its resources. TODO confirm.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // This connection was already removed from the pool above, so this
            // checks whether the user has any *other* connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 946
 947/// Acknowledges a ping from a client, used to keep the connection alive.
 948async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 949    response.send(proto::Ack {})?;
 950    Ok(())
 951}
 952
 953/// Create a new room for calling (outside of channels)
 954async fn create_room(
 955    _request: proto::CreateRoom,
 956    response: Response<proto::CreateRoom>,
 957    session: Session,
 958) -> Result<()> {
 959    let live_kit_room = nanoid::nanoid!(30);
 960
 961    let live_kit_connection_info = {
 962        let live_kit_room = live_kit_room.clone();
 963        let live_kit = session.live_kit_client.as_ref();
 964
 965        util::async_maybe!({
 966            let live_kit = live_kit?;
 967
 968            let token = live_kit
 969                .room_token(&live_kit_room, &session.user_id.to_string())
 970                .trace_err()?;
 971
 972            Some(proto::LiveKitConnectionInfo {
 973                server_url: live_kit.url().into(),
 974                token,
 975                can_publish: true,
 976            })
 977        })
 978    }
 979    .await;
 980
 981    let room = session
 982        .db()
 983        .await
 984        .create_room(
 985            session.user_id,
 986            session.connection_id,
 987            &live_kit_room,
 988            &session.zed_environment,
 989        )
 990        .await?;
 991
 992    response.send(proto::CreateRoomResponse {
 993        room: Some(room.clone()),
 994        live_kit_connection_info,
 995    })?;
 996
 997    update_user_contacts(session.user_id, &session).await?;
 998    Ok(())
 999}
1000
1001/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1002async fn join_room(
1003    request: proto::JoinRoom,
1004    response: Response<proto::JoinRoom>,
1005    session: Session,
1006) -> Result<()> {
1007    let room_id = RoomId::from_proto(request.id);
1008
1009    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1010
1011    if let Some(channel_id) = channel_id {
1012        return join_channel_internal(channel_id, Box::new(response), session).await;
1013    }
1014
1015    let joined_room = {
1016        let room = session
1017            .db()
1018            .await
1019            .join_room(
1020                room_id,
1021                session.user_id,
1022                session.connection_id,
1023                session.zed_environment.as_ref(),
1024            )
1025            .await?;
1026        room_updated(&room.room, &session.peer);
1027        room.into_inner()
1028    };
1029
1030    for connection_id in session
1031        .connection_pool()
1032        .await
1033        .user_connection_ids(session.user_id)
1034    {
1035        session
1036            .peer
1037            .send(
1038                connection_id,
1039                proto::CallCanceled {
1040                    room_id: room_id.to_proto(),
1041                },
1042            )
1043            .trace_err();
1044    }
1045
1046    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1047        if let Some(token) = live_kit
1048            .room_token(
1049                &joined_room.room.live_kit_room,
1050                &session.user_id.to_string(),
1051            )
1052            .trace_err()
1053        {
1054            Some(proto::LiveKitConnectionInfo {
1055                server_url: live_kit.url().into(),
1056                token,
1057                can_publish: true,
1058            })
1059        } else {
1060            None
1061        }
1062    } else {
1063        None
1064    };
1065
1066    response.send(proto::JoinRoomResponse {
1067        room: Some(joined_room.room),
1068        channel_id: None,
1069        live_kit_connection_info,
1070    })?;
1071
1072    update_user_contacts(session.user_id, &session).await?;
1073    Ok(())
1074}
1075
1076/// Rejoin room is used to reconnect to a room after connection errors.
1077async fn rejoin_room(
1078    request: proto::RejoinRoom,
1079    response: Response<proto::RejoinRoom>,
1080    session: Session,
1081) -> Result<()> {
1082    let room;
1083    let channel_id;
1084    let channel_members;
1085    {
1086        let mut rejoined_room = session
1087            .db()
1088            .await
1089            .rejoin_room(request, session.user_id, session.connection_id)
1090            .await?;
1091
1092        response.send(proto::RejoinRoomResponse {
1093            room: Some(rejoined_room.room.clone()),
1094            reshared_projects: rejoined_room
1095                .reshared_projects
1096                .iter()
1097                .map(|project| proto::ResharedProject {
1098                    id: project.id.to_proto(),
1099                    collaborators: project
1100                        .collaborators
1101                        .iter()
1102                        .map(|collaborator| collaborator.to_proto())
1103                        .collect(),
1104                })
1105                .collect(),
1106            rejoined_projects: rejoined_room
1107                .rejoined_projects
1108                .iter()
1109                .map(|rejoined_project| proto::RejoinedProject {
1110                    id: rejoined_project.id.to_proto(),
1111                    worktrees: rejoined_project
1112                        .worktrees
1113                        .iter()
1114                        .map(|worktree| proto::WorktreeMetadata {
1115                            id: worktree.id,
1116                            root_name: worktree.root_name.clone(),
1117                            visible: worktree.visible,
1118                            abs_path: worktree.abs_path.clone(),
1119                        })
1120                        .collect(),
1121                    collaborators: rejoined_project
1122                        .collaborators
1123                        .iter()
1124                        .map(|collaborator| collaborator.to_proto())
1125                        .collect(),
1126                    language_servers: rejoined_project.language_servers.clone(),
1127                })
1128                .collect(),
1129        })?;
1130        room_updated(&rejoined_room.room, &session.peer);
1131
1132        for project in &rejoined_room.reshared_projects {
1133            for collaborator in &project.collaborators {
1134                session
1135                    .peer
1136                    .send(
1137                        collaborator.connection_id,
1138                        proto::UpdateProjectCollaborator {
1139                            project_id: project.id.to_proto(),
1140                            old_peer_id: Some(project.old_connection_id.into()),
1141                            new_peer_id: Some(session.connection_id.into()),
1142                        },
1143                    )
1144                    .trace_err();
1145            }
1146
1147            broadcast(
1148                Some(session.connection_id),
1149                project
1150                    .collaborators
1151                    .iter()
1152                    .map(|collaborator| collaborator.connection_id),
1153                |connection_id| {
1154                    session.peer.forward_send(
1155                        session.connection_id,
1156                        connection_id,
1157                        proto::UpdateProject {
1158                            project_id: project.id.to_proto(),
1159                            worktrees: project.worktrees.clone(),
1160                        },
1161                    )
1162                },
1163            );
1164        }
1165
1166        for project in &rejoined_room.rejoined_projects {
1167            for collaborator in &project.collaborators {
1168                session
1169                    .peer
1170                    .send(
1171                        collaborator.connection_id,
1172                        proto::UpdateProjectCollaborator {
1173                            project_id: project.id.to_proto(),
1174                            old_peer_id: Some(project.old_connection_id.into()),
1175                            new_peer_id: Some(session.connection_id.into()),
1176                        },
1177                    )
1178                    .trace_err();
1179            }
1180        }
1181
1182        for project in &mut rejoined_room.rejoined_projects {
1183            for worktree in mem::take(&mut project.worktrees) {
1184                #[cfg(any(test, feature = "test-support"))]
1185                const MAX_CHUNK_SIZE: usize = 2;
1186                #[cfg(not(any(test, feature = "test-support")))]
1187                const MAX_CHUNK_SIZE: usize = 256;
1188
1189                // Stream this worktree's entries.
1190                let message = proto::UpdateWorktree {
1191                    project_id: project.id.to_proto(),
1192                    worktree_id: worktree.id,
1193                    abs_path: worktree.abs_path.clone(),
1194                    root_name: worktree.root_name,
1195                    updated_entries: worktree.updated_entries,
1196                    removed_entries: worktree.removed_entries,
1197                    scan_id: worktree.scan_id,
1198                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1199                    updated_repositories: worktree.updated_repositories,
1200                    removed_repositories: worktree.removed_repositories,
1201                };
1202                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1203                    session.peer.send(session.connection_id, update.clone())?;
1204                }
1205
1206                // Stream this worktree's diagnostics.
1207                for summary in worktree.diagnostic_summaries {
1208                    session.peer.send(
1209                        session.connection_id,
1210                        proto::UpdateDiagnosticSummary {
1211                            project_id: project.id.to_proto(),
1212                            worktree_id: worktree.id,
1213                            summary: Some(summary),
1214                        },
1215                    )?;
1216                }
1217
1218                for settings_file in worktree.settings_files {
1219                    session.peer.send(
1220                        session.connection_id,
1221                        proto::UpdateWorktreeSettings {
1222                            project_id: project.id.to_proto(),
1223                            worktree_id: worktree.id,
1224                            path: settings_file.path,
1225                            content: Some(settings_file.content),
1226                        },
1227                    )?;
1228                }
1229            }
1230
1231            for language_server in &project.language_servers {
1232                session.peer.send(
1233                    session.connection_id,
1234                    proto::UpdateLanguageServer {
1235                        project_id: project.id.to_proto(),
1236                        language_server_id: language_server.id,
1237                        variant: Some(
1238                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1239                                proto::LspDiskBasedDiagnosticsUpdated {},
1240                            ),
1241                        ),
1242                    },
1243                )?;
1244            }
1245        }
1246
1247        let rejoined_room = rejoined_room.into_inner();
1248
1249        room = rejoined_room.room;
1250        channel_id = rejoined_room.channel_id;
1251        channel_members = rejoined_room.channel_members;
1252    }
1253
1254    if let Some(channel_id) = channel_id {
1255        channel_updated(
1256            channel_id,
1257            &room,
1258            &channel_members,
1259            &session.peer,
1260            &*session.connection_pool().await,
1261        );
1262    }
1263
1264    update_user_contacts(session.user_id, &session).await?;
1265    Ok(())
1266}
1267
1268/// leave room disonnects from the room.
1269async fn leave_room(
1270    _: proto::LeaveRoom,
1271    response: Response<proto::LeaveRoom>,
1272    session: Session,
1273) -> Result<()> {
1274    leave_room_for_session(&session).await?;
1275    response.send(proto::Ack {})?;
1276    Ok(())
1277}
1278
1279/// Update the permissions of someone else in the room.
1280async fn set_room_participant_role(
1281    request: proto::SetRoomParticipantRole,
1282    response: Response<proto::SetRoomParticipantRole>,
1283    session: Session,
1284) -> Result<()> {
1285    let (live_kit_room, can_publish) = {
1286        let room = session
1287            .db()
1288            .await
1289            .set_room_participant_role(
1290                session.user_id,
1291                RoomId::from_proto(request.room_id),
1292                UserId::from_proto(request.user_id),
1293                ChannelRole::from(request.role()),
1294            )
1295            .await?;
1296
1297        let live_kit_room = room.live_kit_room.clone();
1298        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1299        room_updated(&room, &session.peer);
1300        (live_kit_room, can_publish)
1301    };
1302
1303    if let Some(live_kit) = session.live_kit_client.as_ref() {
1304        live_kit
1305            .update_participant(
1306                live_kit_room.clone(),
1307                request.user_id.to_string(),
1308                live_kit_server::proto::ParticipantPermission {
1309                    can_subscribe: true,
1310                    can_publish,
1311                    can_publish_data: can_publish,
1312                    hidden: false,
1313                    recorder: false,
1314                },
1315            )
1316            .await
1317            .trace_err();
1318    }
1319
1320    response.send(proto::Ack {})?;
1321    Ok(())
1322}
1323
1324/// Call someone else into the current room
1325async fn call(
1326    request: proto::Call,
1327    response: Response<proto::Call>,
1328    session: Session,
1329) -> Result<()> {
1330    let room_id = RoomId::from_proto(request.room_id);
1331    let calling_user_id = session.user_id;
1332    let calling_connection_id = session.connection_id;
1333    let called_user_id = UserId::from_proto(request.called_user_id);
1334    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1335    if !session
1336        .db()
1337        .await
1338        .has_contact(calling_user_id, called_user_id)
1339        .await?
1340    {
1341        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1342    }
1343
1344    let incoming_call = {
1345        let (room, incoming_call) = &mut *session
1346            .db()
1347            .await
1348            .call(
1349                room_id,
1350                calling_user_id,
1351                calling_connection_id,
1352                called_user_id,
1353                initial_project_id,
1354            )
1355            .await?;
1356        room_updated(&room, &session.peer);
1357        mem::take(incoming_call)
1358    };
1359    update_user_contacts(called_user_id, &session).await?;
1360
1361    let mut calls = session
1362        .connection_pool()
1363        .await
1364        .user_connection_ids(called_user_id)
1365        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1366        .collect::<FuturesUnordered<_>>();
1367
1368    while let Some(call_response) = calls.next().await {
1369        match call_response.as_ref() {
1370            Ok(_) => {
1371                response.send(proto::Ack {})?;
1372                return Ok(());
1373            }
1374            Err(_) => {
1375                call_response.trace_err();
1376            }
1377        }
1378    }
1379
1380    {
1381        let room = session
1382            .db()
1383            .await
1384            .call_failed(room_id, called_user_id)
1385            .await?;
1386        room_updated(&room, &session.peer);
1387    }
1388    update_user_contacts(called_user_id, &session).await?;
1389
1390    Err(anyhow!("failed to ring user"))?
1391}
1392
1393/// Cancel an outgoing call.
1394async fn cancel_call(
1395    request: proto::CancelCall,
1396    response: Response<proto::CancelCall>,
1397    session: Session,
1398) -> Result<()> {
1399    let called_user_id = UserId::from_proto(request.called_user_id);
1400    let room_id = RoomId::from_proto(request.room_id);
1401    {
1402        let room = session
1403            .db()
1404            .await
1405            .cancel_call(room_id, session.connection_id, called_user_id)
1406            .await?;
1407        room_updated(&room, &session.peer);
1408    }
1409
1410    for connection_id in session
1411        .connection_pool()
1412        .await
1413        .user_connection_ids(called_user_id)
1414    {
1415        session
1416            .peer
1417            .send(
1418                connection_id,
1419                proto::CallCanceled {
1420                    room_id: room_id.to_proto(),
1421                },
1422            )
1423            .trace_err();
1424    }
1425    response.send(proto::Ack {})?;
1426
1427    update_user_contacts(called_user_id, &session).await?;
1428    Ok(())
1429}
1430
1431/// Decline an incoming call.
1432async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1433    let room_id = RoomId::from_proto(message.room_id);
1434    {
1435        let room = session
1436            .db()
1437            .await
1438            .decline_call(Some(room_id), session.user_id)
1439            .await?
1440            .ok_or_else(|| anyhow!("failed to decline call"))?;
1441        room_updated(&room, &session.peer);
1442    }
1443
1444    for connection_id in session
1445        .connection_pool()
1446        .await
1447        .user_connection_ids(session.user_id)
1448    {
1449        session
1450            .peer
1451            .send(
1452                connection_id,
1453                proto::CallCanceled {
1454                    room_id: room_id.to_proto(),
1455                },
1456            )
1457            .trace_err();
1458    }
1459    update_user_contacts(session.user_id, &session).await?;
1460    Ok(())
1461}
1462
1463/// Update other participants in the room with your current location.
1464async fn update_participant_location(
1465    request: proto::UpdateParticipantLocation,
1466    response: Response<proto::UpdateParticipantLocation>,
1467    session: Session,
1468) -> Result<()> {
1469    let room_id = RoomId::from_proto(request.room_id);
1470    let location = request
1471        .location
1472        .ok_or_else(|| anyhow!("invalid location"))?;
1473
1474    let db = session.db().await;
1475    let room = db
1476        .update_room_participant_location(room_id, session.connection_id, location)
1477        .await?;
1478
1479    room_updated(&room, &session.peer);
1480    response.send(proto::Ack {})?;
1481    Ok(())
1482}
1483
1484/// Share a project into the room.
1485async fn share_project(
1486    request: proto::ShareProject,
1487    response: Response<proto::ShareProject>,
1488    session: Session,
1489) -> Result<()> {
1490    let (project_id, room) = &*session
1491        .db()
1492        .await
1493        .share_project(
1494            RoomId::from_proto(request.room_id),
1495            session.connection_id,
1496            &request.worktrees,
1497        )
1498        .await?;
1499    response.send(proto::ShareProjectResponse {
1500        project_id: project_id.to_proto(),
1501    })?;
1502    room_updated(&room, &session.peer);
1503
1504    Ok(())
1505}
1506
1507/// Unshare a project from the room.
1508async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1509    let project_id = ProjectId::from_proto(message.project_id);
1510
1511    let (room, guest_connection_ids) = &*session
1512        .db()
1513        .await
1514        .unshare_project(project_id, session.connection_id)
1515        .await?;
1516
1517    broadcast(
1518        Some(session.connection_id),
1519        guest_connection_ids.iter().copied(),
1520        |conn_id| session.peer.send(conn_id, message.clone()),
1521    );
1522    room_updated(&room, &session.peer);
1523
1524    Ok(())
1525}
1526
/// Join someone else's shared project.
1528async fn join_project(
1529    request: proto::JoinProject,
1530    response: Response<proto::JoinProject>,
1531    session: Session,
1532) -> Result<()> {
1533    let project_id = ProjectId::from_proto(request.project_id);
1534    let guest_user_id = session.user_id;
1535
1536    tracing::info!(%project_id, "join project");
1537
1538    let (project, replica_id) = &mut *session
1539        .db()
1540        .await
1541        .join_project(project_id, session.connection_id)
1542        .await?;
1543
1544    let collaborators = project
1545        .collaborators
1546        .iter()
1547        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1548        .map(|collaborator| collaborator.to_proto())
1549        .collect::<Vec<_>>();
1550
1551    let worktrees = project
1552        .worktrees
1553        .iter()
1554        .map(|(id, worktree)| proto::WorktreeMetadata {
1555            id: *id,
1556            root_name: worktree.root_name.clone(),
1557            visible: worktree.visible,
1558            abs_path: worktree.abs_path.clone(),
1559        })
1560        .collect::<Vec<_>>();
1561
1562    for collaborator in &collaborators {
1563        session
1564            .peer
1565            .send(
1566                collaborator.peer_id.unwrap().into(),
1567                proto::AddProjectCollaborator {
1568                    project_id: project_id.to_proto(),
1569                    collaborator: Some(proto::Collaborator {
1570                        peer_id: Some(session.connection_id.into()),
1571                        replica_id: replica_id.0 as u32,
1572                        user_id: guest_user_id.to_proto(),
1573                    }),
1574                },
1575            )
1576            .trace_err();
1577    }
1578
1579    // First, we send the metadata associated with each worktree.
1580    response.send(proto::JoinProjectResponse {
1581        worktrees: worktrees.clone(),
1582        replica_id: replica_id.0 as u32,
1583        collaborators: collaborators.clone(),
1584        language_servers: project.language_servers.clone(),
1585    })?;
1586
1587    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1588        #[cfg(any(test, feature = "test-support"))]
1589        const MAX_CHUNK_SIZE: usize = 2;
1590        #[cfg(not(any(test, feature = "test-support")))]
1591        const MAX_CHUNK_SIZE: usize = 256;
1592
1593        // Stream this worktree's entries.
1594        let message = proto::UpdateWorktree {
1595            project_id: project_id.to_proto(),
1596            worktree_id,
1597            abs_path: worktree.abs_path.clone(),
1598            root_name: worktree.root_name,
1599            updated_entries: worktree.entries,
1600            removed_entries: Default::default(),
1601            scan_id: worktree.scan_id,
1602            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1603            updated_repositories: worktree.repository_entries.into_values().collect(),
1604            removed_repositories: Default::default(),
1605        };
1606        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1607            session.peer.send(session.connection_id, update.clone())?;
1608        }
1609
1610        // Stream this worktree's diagnostics.
1611        for summary in worktree.diagnostic_summaries {
1612            session.peer.send(
1613                session.connection_id,
1614                proto::UpdateDiagnosticSummary {
1615                    project_id: project_id.to_proto(),
1616                    worktree_id: worktree.id,
1617                    summary: Some(summary),
1618                },
1619            )?;
1620        }
1621
1622        for settings_file in worktree.settings_files {
1623            session.peer.send(
1624                session.connection_id,
1625                proto::UpdateWorktreeSettings {
1626                    project_id: project_id.to_proto(),
1627                    worktree_id: worktree.id,
1628                    path: settings_file.path,
1629                    content: Some(settings_file.content),
1630                },
1631            )?;
1632        }
1633    }
1634
1635    for language_server in &project.language_servers {
1636        session.peer.send(
1637            session.connection_id,
1638            proto::UpdateLanguageServer {
1639                project_id: project_id.to_proto(),
1640                language_server_id: language_server.id,
1641                variant: Some(
1642                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1643                        proto::LspDiskBasedDiagnosticsUpdated {},
1644                    ),
1645                ),
1646            },
1647        )?;
1648    }
1649
1650    Ok(())
1651}
1652
1653/// Leave someone elses shared project.
1654async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1655    let sender_id = session.connection_id;
1656    let project_id = ProjectId::from_proto(request.project_id);
1657
1658    let (room, project) = &*session
1659        .db()
1660        .await
1661        .leave_project(project_id, sender_id)
1662        .await?;
1663    tracing::info!(
1664        %project_id,
1665        host_user_id = %project.host_user_id,
1666        host_connection_id = %project.host_connection_id,
1667        "leave project"
1668    );
1669
1670    project_left(&project, &session);
1671    room_updated(&room, &session.peer);
1672
1673    Ok(())
1674}
1675
1676/// Update other participants with changes to the project
1677async fn update_project(
1678    request: proto::UpdateProject,
1679    response: Response<proto::UpdateProject>,
1680    session: Session,
1681) -> Result<()> {
1682    let project_id = ProjectId::from_proto(request.project_id);
1683    let (room, guest_connection_ids) = &*session
1684        .db()
1685        .await
1686        .update_project(project_id, session.connection_id, &request.worktrees)
1687        .await?;
1688    broadcast(
1689        Some(session.connection_id),
1690        guest_connection_ids.iter().copied(),
1691        |connection_id| {
1692            session
1693                .peer
1694                .forward_send(session.connection_id, connection_id, request.clone())
1695        },
1696    );
1697    room_updated(&room, &session.peer);
1698    response.send(proto::Ack {})?;
1699
1700    Ok(())
1701}
1702
1703/// Update other participants with changes to the worktree
1704async fn update_worktree(
1705    request: proto::UpdateWorktree,
1706    response: Response<proto::UpdateWorktree>,
1707    session: Session,
1708) -> Result<()> {
1709    let guest_connection_ids = session
1710        .db()
1711        .await
1712        .update_worktree(&request, session.connection_id)
1713        .await?;
1714
1715    broadcast(
1716        Some(session.connection_id),
1717        guest_connection_ids.iter().copied(),
1718        |connection_id| {
1719            session
1720                .peer
1721                .forward_send(session.connection_id, connection_id, request.clone())
1722        },
1723    );
1724    response.send(proto::Ack {})?;
1725    Ok(())
1726}
1727
1728/// Update other participants with changes to the diagnostics
1729async fn update_diagnostic_summary(
1730    message: proto::UpdateDiagnosticSummary,
1731    session: Session,
1732) -> Result<()> {
1733    let guest_connection_ids = session
1734        .db()
1735        .await
1736        .update_diagnostic_summary(&message, session.connection_id)
1737        .await?;
1738
1739    broadcast(
1740        Some(session.connection_id),
1741        guest_connection_ids.iter().copied(),
1742        |connection_id| {
1743            session
1744                .peer
1745                .forward_send(session.connection_id, connection_id, message.clone())
1746        },
1747    );
1748
1749    Ok(())
1750}
1751
1752/// Update other participants with changes to the worktree settings
1753async fn update_worktree_settings(
1754    message: proto::UpdateWorktreeSettings,
1755    session: Session,
1756) -> Result<()> {
1757    let guest_connection_ids = session
1758        .db()
1759        .await
1760        .update_worktree_settings(&message, session.connection_id)
1761        .await?;
1762
1763    broadcast(
1764        Some(session.connection_id),
1765        guest_connection_ids.iter().copied(),
1766        |connection_id| {
1767            session
1768                .peer
1769                .forward_send(session.connection_id, connection_id, message.clone())
1770        },
1771    );
1772
1773    Ok(())
1774}
1775
1776/// Notify other participants that a  language server has started.
1777async fn start_language_server(
1778    request: proto::StartLanguageServer,
1779    session: Session,
1780) -> Result<()> {
1781    let guest_connection_ids = session
1782        .db()
1783        .await
1784        .start_language_server(&request, session.connection_id)
1785        .await?;
1786
1787    broadcast(
1788        Some(session.connection_id),
1789        guest_connection_ids.iter().copied(),
1790        |connection_id| {
1791            session
1792                .peer
1793                .forward_send(session.connection_id, connection_id, request.clone())
1794        },
1795    );
1796    Ok(())
1797}
1798
1799/// Notify other participants that a language server has changed.
1800async fn update_language_server(
1801    request: proto::UpdateLanguageServer,
1802    session: Session,
1803) -> Result<()> {
1804    let project_id = ProjectId::from_proto(request.project_id);
1805    let project_connection_ids = session
1806        .db()
1807        .await
1808        .project_connection_ids(project_id, session.connection_id)
1809        .await?;
1810    broadcast(
1811        Some(session.connection_id),
1812        project_connection_ids.iter().copied(),
1813        |connection_id| {
1814            session
1815                .peer
1816                .forward_send(session.connection_id, connection_id, request.clone())
1817        },
1818    );
1819    Ok(())
1820}
1821
1822/// forward a project request to the host. These requests should be read only
1823/// as guests are allowed to send them.
1824async fn forward_read_only_project_request<T>(
1825    request: T,
1826    response: Response<T>,
1827    session: Session,
1828) -> Result<()>
1829where
1830    T: EntityMessage + RequestMessage,
1831{
1832    let project_id = ProjectId::from_proto(request.remote_entity_id());
1833    let host_connection_id = session
1834        .db()
1835        .await
1836        .host_for_read_only_project_request(project_id, session.connection_id)
1837        .await?;
1838    let payload = session
1839        .peer
1840        .forward_request(session.connection_id, host_connection_id, request)
1841        .await?;
1842    response.send(payload)?;
1843    Ok(())
1844}
1845
1846/// forward a project request to the host. These requests are disallowed
1847/// for guests.
1848async fn forward_mutating_project_request<T>(
1849    request: T,
1850    response: Response<T>,
1851    session: Session,
1852) -> Result<()>
1853where
1854    T: EntityMessage + RequestMessage,
1855{
1856    let project_id = ProjectId::from_proto(request.remote_entity_id());
1857    let host_connection_id = session
1858        .db()
1859        .await
1860        .host_for_mutating_project_request(project_id, session.connection_id)
1861        .await?;
1862    let payload = session
1863        .peer
1864        .forward_request(session.connection_id, host_connection_id, request)
1865        .await?;
1866    response.send(payload)?;
1867    Ok(())
1868}
1869
1870/// Notify other participants that a new buffer has been created
1871async fn create_buffer_for_peer(
1872    request: proto::CreateBufferForPeer,
1873    session: Session,
1874) -> Result<()> {
1875    session
1876        .db()
1877        .await
1878        .check_user_is_project_host(
1879            ProjectId::from_proto(request.project_id),
1880            session.connection_id,
1881        )
1882        .await?;
1883    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1884    session
1885        .peer
1886        .forward_send(session.connection_id, peer_id.into(), request)?;
1887    Ok(())
1888}
1889
1890/// Notify other participants that a buffer has been updated. This is
1891/// allowed for guests as long as the update is limited to selections.
1892async fn update_buffer(
1893    request: proto::UpdateBuffer,
1894    response: Response<proto::UpdateBuffer>,
1895    session: Session,
1896) -> Result<()> {
1897    let project_id = ProjectId::from_proto(request.project_id);
1898    let mut guest_connection_ids;
1899    let mut host_connection_id = None;
1900
1901    let mut requires_write_permission = false;
1902
1903    for op in request.operations.iter() {
1904        match op.variant {
1905            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1906            Some(_) => requires_write_permission = true,
1907        }
1908    }
1909
1910    {
1911        let collaborators = session
1912            .db()
1913            .await
1914            .project_collaborators_for_buffer_update(
1915                project_id,
1916                session.connection_id,
1917                requires_write_permission,
1918            )
1919            .await?;
1920        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1921        for collaborator in collaborators.iter() {
1922            if collaborator.is_host {
1923                host_connection_id = Some(collaborator.connection_id);
1924            } else {
1925                guest_connection_ids.push(collaborator.connection_id);
1926            }
1927        }
1928    }
1929    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1930
1931    broadcast(
1932        Some(session.connection_id),
1933        guest_connection_ids,
1934        |connection_id| {
1935            session
1936                .peer
1937                .forward_send(session.connection_id, connection_id, request.clone())
1938        },
1939    );
1940    if host_connection_id != session.connection_id {
1941        session
1942            .peer
1943            .forward_request(session.connection_id, host_connection_id, request.clone())
1944            .await?;
1945    }
1946
1947    response.send(proto::Ack {})?;
1948    Ok(())
1949}
1950
1951/// Notify other participants that a project has been updated.
1952async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1953    request: T,
1954    session: Session,
1955) -> Result<()> {
1956    let project_id = ProjectId::from_proto(request.remote_entity_id());
1957    let project_connection_ids = session
1958        .db()
1959        .await
1960        .project_connection_ids(project_id, session.connection_id)
1961        .await?;
1962
1963    broadcast(
1964        Some(session.connection_id),
1965        project_connection_ids.iter().copied(),
1966        |connection_id| {
1967            session
1968                .peer
1969                .forward_send(session.connection_id, connection_id, request.clone())
1970        },
1971    );
1972    Ok(())
1973}
1974
1975/// Start following another user in a call.
1976async fn follow(
1977    request: proto::Follow,
1978    response: Response<proto::Follow>,
1979    session: Session,
1980) -> Result<()> {
1981    let room_id = RoomId::from_proto(request.room_id);
1982    let project_id = request.project_id.map(ProjectId::from_proto);
1983    let leader_id = request
1984        .leader_id
1985        .ok_or_else(|| anyhow!("invalid leader id"))?
1986        .into();
1987    let follower_id = session.connection_id;
1988
1989    session
1990        .db()
1991        .await
1992        .check_room_participants(room_id, leader_id, session.connection_id)
1993        .await?;
1994
1995    let response_payload = session
1996        .peer
1997        .forward_request(session.connection_id, leader_id, request)
1998        .await?;
1999    response.send(response_payload)?;
2000
2001    if let Some(project_id) = project_id {
2002        let room = session
2003            .db()
2004            .await
2005            .follow(room_id, project_id, leader_id, follower_id)
2006            .await?;
2007        room_updated(&room, &session.peer);
2008    }
2009
2010    Ok(())
2011}
2012
2013/// Stop following another user in a call.
2014async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2015    let room_id = RoomId::from_proto(request.room_id);
2016    let project_id = request.project_id.map(ProjectId::from_proto);
2017    let leader_id = request
2018        .leader_id
2019        .ok_or_else(|| anyhow!("invalid leader id"))?
2020        .into();
2021    let follower_id = session.connection_id;
2022
2023    session
2024        .db()
2025        .await
2026        .check_room_participants(room_id, leader_id, session.connection_id)
2027        .await?;
2028
2029    session
2030        .peer
2031        .forward_send(session.connection_id, leader_id, request)?;
2032
2033    if let Some(project_id) = project_id {
2034        let room = session
2035            .db()
2036            .await
2037            .unfollow(room_id, project_id, leader_id, follower_id)
2038            .await?;
2039        room_updated(&room, &session.peer);
2040    }
2041
2042    Ok(())
2043}
2044
2045/// Notify everyone following you of your current location.
2046async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2047    let room_id = RoomId::from_proto(request.room_id);
2048    let database = session.db.lock().await;
2049
2050    let connection_ids = if let Some(project_id) = request.project_id {
2051        let project_id = ProjectId::from_proto(project_id);
2052        database
2053            .project_connection_ids(project_id, session.connection_id)
2054            .await?
2055    } else {
2056        database
2057            .room_connection_ids(room_id, session.connection_id)
2058            .await?
2059    };
2060
2061    // For now, don't send view update messages back to that view's current leader.
2062    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2063        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2064        _ => None,
2065    });
2066
2067    for follower_peer_id in request.follower_ids.iter().copied() {
2068        let follower_connection_id = follower_peer_id.into();
2069        if Some(follower_peer_id) != connection_id_to_omit
2070            && connection_ids.contains(&follower_connection_id)
2071        {
2072            session.peer.forward_send(
2073                session.connection_id,
2074                follower_connection_id,
2075                request.clone(),
2076            )?;
2077        }
2078    }
2079    Ok(())
2080}
2081
2082/// Get public data about users.
2083async fn get_users(
2084    request: proto::GetUsers,
2085    response: Response<proto::GetUsers>,
2086    session: Session,
2087) -> Result<()> {
2088    let user_ids = request
2089        .user_ids
2090        .into_iter()
2091        .map(UserId::from_proto)
2092        .collect();
2093    let users = session
2094        .db()
2095        .await
2096        .get_users_by_ids(user_ids)
2097        .await?
2098        .into_iter()
2099        .map(|user| proto::User {
2100            id: user.id.to_proto(),
2101            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2102            github_login: user.github_login,
2103        })
2104        .collect();
2105    response.send(proto::UsersResponse { users })?;
2106    Ok(())
2107}
2108
2109/// Search for users (to invite) buy Github login
2110async fn fuzzy_search_users(
2111    request: proto::FuzzySearchUsers,
2112    response: Response<proto::FuzzySearchUsers>,
2113    session: Session,
2114) -> Result<()> {
2115    let query = request.query;
2116    let users = match query.len() {
2117        0 => vec![],
2118        1 | 2 => session
2119            .db()
2120            .await
2121            .get_user_by_github_login(&query)
2122            .await?
2123            .into_iter()
2124            .collect(),
2125        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2126    };
2127    let users = users
2128        .into_iter()
2129        .filter(|user| user.id != session.user_id)
2130        .map(|user| proto::User {
2131            id: user.id.to_proto(),
2132            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2133            github_login: user.github_login,
2134        })
2135        .collect();
2136    response.send(proto::UsersResponse { users })?;
2137    Ok(())
2138}
2139
2140/// Send a contact request to another user.
2141async fn request_contact(
2142    request: proto::RequestContact,
2143    response: Response<proto::RequestContact>,
2144    session: Session,
2145) -> Result<()> {
2146    let requester_id = session.user_id;
2147    let responder_id = UserId::from_proto(request.responder_id);
2148    if requester_id == responder_id {
2149        return Err(anyhow!("cannot add yourself as a contact"))?;
2150    }
2151
2152    let notifications = session
2153        .db()
2154        .await
2155        .send_contact_request(requester_id, responder_id)
2156        .await?;
2157
2158    // Update outgoing contact requests of requester
2159    let mut update = proto::UpdateContacts::default();
2160    update.outgoing_requests.push(responder_id.to_proto());
2161    for connection_id in session
2162        .connection_pool()
2163        .await
2164        .user_connection_ids(requester_id)
2165    {
2166        session.peer.send(connection_id, update.clone())?;
2167    }
2168
2169    // Update incoming contact requests of responder
2170    let mut update = proto::UpdateContacts::default();
2171    update
2172        .incoming_requests
2173        .push(proto::IncomingContactRequest {
2174            requester_id: requester_id.to_proto(),
2175        });
2176    let connection_pool = session.connection_pool().await;
2177    for connection_id in connection_pool.user_connection_ids(responder_id) {
2178        session.peer.send(connection_id, update.clone())?;
2179    }
2180
2181    send_notifications(&*connection_pool, &session.peer, notifications);
2182
2183    response.send(proto::Ack {})?;
2184    Ok(())
2185}
2186
2187/// Accept or decline a contact request
2188async fn respond_to_contact_request(
2189    request: proto::RespondToContactRequest,
2190    response: Response<proto::RespondToContactRequest>,
2191    session: Session,
2192) -> Result<()> {
2193    let responder_id = session.user_id;
2194    let requester_id = UserId::from_proto(request.requester_id);
2195    let db = session.db().await;
2196    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2197        db.dismiss_contact_notification(responder_id, requester_id)
2198            .await?;
2199    } else {
2200        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2201
2202        let notifications = db
2203            .respond_to_contact_request(responder_id, requester_id, accept)
2204            .await?;
2205        let requester_busy = db.is_user_busy(requester_id).await?;
2206        let responder_busy = db.is_user_busy(responder_id).await?;
2207
2208        let pool = session.connection_pool().await;
2209        // Update responder with new contact
2210        let mut update = proto::UpdateContacts::default();
2211        if accept {
2212            update
2213                .contacts
2214                .push(contact_for_user(requester_id, requester_busy, &pool));
2215        }
2216        update
2217            .remove_incoming_requests
2218            .push(requester_id.to_proto());
2219        for connection_id in pool.user_connection_ids(responder_id) {
2220            session.peer.send(connection_id, update.clone())?;
2221        }
2222
2223        // Update requester with new contact
2224        let mut update = proto::UpdateContacts::default();
2225        if accept {
2226            update
2227                .contacts
2228                .push(contact_for_user(responder_id, responder_busy, &pool));
2229        }
2230        update
2231            .remove_outgoing_requests
2232            .push(responder_id.to_proto());
2233
2234        for connection_id in pool.user_connection_ids(requester_id) {
2235            session.peer.send(connection_id, update.clone())?;
2236        }
2237
2238        send_notifications(&*pool, &session.peer, notifications);
2239    }
2240
2241    response.send(proto::Ack {})?;
2242    Ok(())
2243}
2244
2245/// Remove a contact.
2246async fn remove_contact(
2247    request: proto::RemoveContact,
2248    response: Response<proto::RemoveContact>,
2249    session: Session,
2250) -> Result<()> {
2251    let requester_id = session.user_id;
2252    let responder_id = UserId::from_proto(request.user_id);
2253    let db = session.db().await;
2254    let (contact_accepted, deleted_notification_id) =
2255        db.remove_contact(requester_id, responder_id).await?;
2256
2257    let pool = session.connection_pool().await;
2258    // Update outgoing contact requests of requester
2259    let mut update = proto::UpdateContacts::default();
2260    if contact_accepted {
2261        update.remove_contacts.push(responder_id.to_proto());
2262    } else {
2263        update
2264            .remove_outgoing_requests
2265            .push(responder_id.to_proto());
2266    }
2267    for connection_id in pool.user_connection_ids(requester_id) {
2268        session.peer.send(connection_id, update.clone())?;
2269    }
2270
2271    // Update incoming contact requests of responder
2272    let mut update = proto::UpdateContacts::default();
2273    if contact_accepted {
2274        update.remove_contacts.push(requester_id.to_proto());
2275    } else {
2276        update
2277            .remove_incoming_requests
2278            .push(requester_id.to_proto());
2279    }
2280    for connection_id in pool.user_connection_ids(responder_id) {
2281        session.peer.send(connection_id, update.clone())?;
2282        if let Some(notification_id) = deleted_notification_id {
2283            session.peer.send(
2284                connection_id,
2285                proto::DeleteNotification {
2286                    notification_id: notification_id.to_proto(),
2287                },
2288            )?;
2289        }
2290    }
2291
2292    response.send(proto::Ack {})?;
2293    Ok(())
2294}
2295
2296/// Create a new channel.
2297async fn create_channel(
2298    request: proto::CreateChannel,
2299    response: Response<proto::CreateChannel>,
2300    session: Session,
2301) -> Result<()> {
2302    let db = session.db().await;
2303
2304    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2305    let CreateChannelResult {
2306        channel,
2307        participants_to_update,
2308    } = db
2309        .create_channel(&request.name, parent_id, session.user_id)
2310        .await?;
2311
2312    response.send(proto::CreateChannelResponse {
2313        channel: Some(channel.to_proto()),
2314        parent_id: request.parent_id,
2315    })?;
2316
2317    let connection_pool = session.connection_pool().await;
2318    for (user_id, channels) in participants_to_update {
2319        let update = build_channels_update(channels, vec![]);
2320        for connection_id in connection_pool.user_connection_ids(user_id) {
2321            if user_id == session.user_id {
2322                continue;
2323            }
2324            session.peer.send(connection_id, update.clone())?;
2325        }
2326    }
2327
2328    Ok(())
2329}
2330
2331/// Delete a channel
2332async fn delete_channel(
2333    request: proto::DeleteChannel,
2334    response: Response<proto::DeleteChannel>,
2335    session: Session,
2336) -> Result<()> {
2337    let db = session.db().await;
2338
2339    let channel_id = request.channel_id;
2340    let (removed_channels, member_ids) = db
2341        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2342        .await?;
2343    response.send(proto::Ack {})?;
2344
2345    // Notify members of removed channels
2346    let mut update = proto::UpdateChannels::default();
2347    update
2348        .delete_channels
2349        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2350
2351    let connection_pool = session.connection_pool().await;
2352    for member_id in member_ids {
2353        for connection_id in connection_pool.user_connection_ids(member_id) {
2354            session.peer.send(connection_id, update.clone())?;
2355        }
2356    }
2357
2358    Ok(())
2359}
2360
2361/// Invite someone to join a channel.
2362async fn invite_channel_member(
2363    request: proto::InviteChannelMember,
2364    response: Response<proto::InviteChannelMember>,
2365    session: Session,
2366) -> Result<()> {
2367    let db = session.db().await;
2368    let channel_id = ChannelId::from_proto(request.channel_id);
2369    let invitee_id = UserId::from_proto(request.user_id);
2370    let InviteMemberResult {
2371        channel,
2372        notifications,
2373    } = db
2374        .invite_channel_member(
2375            channel_id,
2376            invitee_id,
2377            session.user_id,
2378            request.role().into(),
2379        )
2380        .await?;
2381
2382    let update = proto::UpdateChannels {
2383        channel_invitations: vec![channel.to_proto()],
2384        ..Default::default()
2385    };
2386
2387    let connection_pool = session.connection_pool().await;
2388    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2389        session.peer.send(connection_id, update.clone())?;
2390    }
2391
2392    send_notifications(&*connection_pool, &session.peer, notifications);
2393
2394    response.send(proto::Ack {})?;
2395    Ok(())
2396}
2397
2398/// remove someone from a channel
2399async fn remove_channel_member(
2400    request: proto::RemoveChannelMember,
2401    response: Response<proto::RemoveChannelMember>,
2402    session: Session,
2403) -> Result<()> {
2404    let db = session.db().await;
2405    let channel_id = ChannelId::from_proto(request.channel_id);
2406    let member_id = UserId::from_proto(request.user_id);
2407
2408    let RemoveChannelMemberResult {
2409        membership_update,
2410        notification_id,
2411    } = db
2412        .remove_channel_member(channel_id, member_id, session.user_id)
2413        .await?;
2414
2415    let connection_pool = &session.connection_pool().await;
2416    notify_membership_updated(
2417        &connection_pool,
2418        membership_update,
2419        member_id,
2420        &session.peer,
2421    );
2422    for connection_id in connection_pool.user_connection_ids(member_id) {
2423        if let Some(notification_id) = notification_id {
2424            session
2425                .peer
2426                .send(
2427                    connection_id,
2428                    proto::DeleteNotification {
2429                        notification_id: notification_id.to_proto(),
2430                    },
2431                )
2432                .trace_err();
2433        }
2434    }
2435
2436    response.send(proto::Ack {})?;
2437    Ok(())
2438}
2439
2440/// Toggle the channel between public and private
2441async fn set_channel_visibility(
2442    request: proto::SetChannelVisibility,
2443    response: Response<proto::SetChannelVisibility>,
2444    session: Session,
2445) -> Result<()> {
2446    let db = session.db().await;
2447    let channel_id = ChannelId::from_proto(request.channel_id);
2448    let visibility = request.visibility().into();
2449
2450    let SetChannelVisibilityResult {
2451        participants_to_update,
2452        participants_to_remove,
2453        channels_to_remove,
2454    } = db
2455        .set_channel_visibility(channel_id, visibility, session.user_id)
2456        .await?;
2457
2458    let connection_pool = session.connection_pool().await;
2459    for (user_id, channels) in participants_to_update {
2460        let update = build_channels_update(channels, vec![]);
2461        for connection_id in connection_pool.user_connection_ids(user_id) {
2462            session.peer.send(connection_id, update.clone())?;
2463        }
2464    }
2465    for user_id in participants_to_remove {
2466        let update = proto::UpdateChannels {
2467            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2468            ..Default::default()
2469        };
2470        for connection_id in connection_pool.user_connection_ids(user_id) {
2471            session.peer.send(connection_id, update.clone())?;
2472        }
2473    }
2474
2475    response.send(proto::Ack {})?;
2476    Ok(())
2477}
2478
2479/// Alter the role for a user in the channel
2480async fn set_channel_member_role(
2481    request: proto::SetChannelMemberRole,
2482    response: Response<proto::SetChannelMemberRole>,
2483    session: Session,
2484) -> Result<()> {
2485    let db = session.db().await;
2486    let channel_id = ChannelId::from_proto(request.channel_id);
2487    let member_id = UserId::from_proto(request.user_id);
2488    let result = db
2489        .set_channel_member_role(
2490            channel_id,
2491            session.user_id,
2492            member_id,
2493            request.role().into(),
2494        )
2495        .await?;
2496
2497    match result {
2498        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2499            let connection_pool = session.connection_pool().await;
2500            notify_membership_updated(
2501                &connection_pool,
2502                membership_update,
2503                member_id,
2504                &session.peer,
2505            )
2506        }
2507        db::SetMemberRoleResult::InviteUpdated(channel) => {
2508            let update = proto::UpdateChannels {
2509                channel_invitations: vec![channel.to_proto()],
2510                ..Default::default()
2511            };
2512
2513            for connection_id in session
2514                .connection_pool()
2515                .await
2516                .user_connection_ids(member_id)
2517            {
2518                session.peer.send(connection_id, update.clone())?;
2519            }
2520        }
2521    }
2522
2523    response.send(proto::Ack {})?;
2524    Ok(())
2525}
2526
2527/// Change the name of a channel
2528async fn rename_channel(
2529    request: proto::RenameChannel,
2530    response: Response<proto::RenameChannel>,
2531    session: Session,
2532) -> Result<()> {
2533    let db = session.db().await;
2534    let channel_id = ChannelId::from_proto(request.channel_id);
2535    let RenameChannelResult {
2536        channel,
2537        participants_to_update,
2538    } = db
2539        .rename_channel(channel_id, session.user_id, &request.name)
2540        .await?;
2541
2542    response.send(proto::RenameChannelResponse {
2543        channel: Some(channel.to_proto()),
2544    })?;
2545
2546    let connection_pool = session.connection_pool().await;
2547    for (user_id, channel) in participants_to_update {
2548        for connection_id in connection_pool.user_connection_ids(user_id) {
2549            let update = proto::UpdateChannels {
2550                channels: vec![channel.to_proto()],
2551                ..Default::default()
2552            };
2553
2554            session.peer.send(connection_id, update.clone())?;
2555        }
2556    }
2557
2558    Ok(())
2559}
2560
2561/// Move a channel to a new parent.
2562async fn move_channel(
2563    request: proto::MoveChannel,
2564    response: Response<proto::MoveChannel>,
2565    session: Session,
2566) -> Result<()> {
2567    let channel_id = ChannelId::from_proto(request.channel_id);
2568    let to = request.to.map(ChannelId::from_proto);
2569
2570    let result = session
2571        .db()
2572        .await
2573        .move_channel(channel_id, to, session.user_id)
2574        .await?;
2575
2576    notify_channel_moved(result, session).await?;
2577
2578    response.send(Ack {})?;
2579    Ok(())
2580}
2581
2582async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2583    let Some(MoveChannelResult {
2584        participants_to_remove,
2585        participants_to_update,
2586        moved_channels,
2587    }) = result
2588    else {
2589        return Ok(());
2590    };
2591    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2592
2593    let connection_pool = session.connection_pool().await;
2594    for (user_id, channels) in participants_to_update {
2595        let mut update = build_channels_update(channels, vec![]);
2596        update.delete_channels = moved_channels.clone();
2597        for connection_id in connection_pool.user_connection_ids(user_id) {
2598            session.peer.send(connection_id, update.clone())?;
2599        }
2600    }
2601
2602    for user_id in participants_to_remove {
2603        let update = proto::UpdateChannels {
2604            delete_channels: moved_channels.clone(),
2605            ..Default::default()
2606        };
2607        for connection_id in connection_pool.user_connection_ids(user_id) {
2608            session.peer.send(connection_id, update.clone())?;
2609        }
2610    }
2611    Ok(())
2612}
2613
2614/// Get the list of channel members
2615async fn get_channel_members(
2616    request: proto::GetChannelMembers,
2617    response: Response<proto::GetChannelMembers>,
2618    session: Session,
2619) -> Result<()> {
2620    let db = session.db().await;
2621    let channel_id = ChannelId::from_proto(request.channel_id);
2622    let members = db
2623        .get_channel_participant_details(channel_id, session.user_id)
2624        .await?;
2625    response.send(proto::GetChannelMembersResponse { members })?;
2626    Ok(())
2627}
2628
2629/// Accept or decline a channel invitation.
2630async fn respond_to_channel_invite(
2631    request: proto::RespondToChannelInvite,
2632    response: Response<proto::RespondToChannelInvite>,
2633    session: Session,
2634) -> Result<()> {
2635    let db = session.db().await;
2636    let channel_id = ChannelId::from_proto(request.channel_id);
2637    let RespondToChannelInvite {
2638        membership_update,
2639        notifications,
2640    } = db
2641        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2642        .await?;
2643
2644    let connection_pool = session.connection_pool().await;
2645    if let Some(membership_update) = membership_update {
2646        notify_membership_updated(
2647            &connection_pool,
2648            membership_update,
2649            session.user_id,
2650            &session.peer,
2651        );
2652    } else {
2653        let update = proto::UpdateChannels {
2654            remove_channel_invitations: vec![channel_id.to_proto()],
2655            ..Default::default()
2656        };
2657
2658        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2659            session.peer.send(connection_id, update.clone())?;
2660        }
2661    };
2662
2663    send_notifications(&*connection_pool, &session.peer, notifications);
2664
2665    response.send(proto::Ack {})?;
2666
2667    Ok(())
2668}
2669
2670/// Join the channels' room
2671async fn join_channel(
2672    request: proto::JoinChannel,
2673    response: Response<proto::JoinChannel>,
2674    session: Session,
2675) -> Result<()> {
2676    let channel_id = ChannelId::from_proto(request.channel_id);
2677    join_channel_internal(channel_id, Box::new(response), session).await
2678}
2679
/// A response handle that `join_channel_internal` can reply through.
///
/// Implemented for the `Response` handles of both `proto::JoinChannel` and
/// `proto::JoinRoom`. Both reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the request this handle belongs to.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls below forward to `Response`'s inherent `send`. Path-call syntax
// resolves inherent methods before trait methods, so these calls do not recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2693
/// Joins the room that belongs to `channel_id` on behalf of `session` and
/// replies through `response`.
///
/// In order, this:
/// 1. leaves whatever room the session is currently in;
/// 2. records the join in the database, which also returns the caller's
///    `ChannelRole` and an optional membership change;
/// 3. when a LiveKit client is configured, creates a token. `ChannelRole::Guest`
///    gets a guest token with `can_publish: false`; every other role gets a
///    room token with `can_publish: true`. If token creation fails, `trace_err`
///    turns the error into `None` (presumably after logging it), and the reply
///    carries no LiveKit connection info;
/// 4. replies with the joined room, pushes the membership change (if any) to
///    the caller's connections, and broadcasts the new room state to the
///    room's participants;
/// 5. sends the room's current participants to every channel member, and
///    pushes the caller's refreshed contact entry to their accepted contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Locals bound inside this block (`db`, `connection_pool`) are dropped when
    // it ends, before the connection pool is requested again further down.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and are told they cannot publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell every channel member who is now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2774
2775/// Start editing the channel notes
2776async fn join_channel_buffer(
2777    request: proto::JoinChannelBuffer,
2778    response: Response<proto::JoinChannelBuffer>,
2779    session: Session,
2780) -> Result<()> {
2781    let db = session.db().await;
2782    let channel_id = ChannelId::from_proto(request.channel_id);
2783
2784    let open_response = db
2785        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2786        .await?;
2787
2788    let collaborators = open_response.collaborators.clone();
2789    response.send(open_response)?;
2790
2791    let update = UpdateChannelBufferCollaborators {
2792        channel_id: channel_id.to_proto(),
2793        collaborators: collaborators.clone(),
2794    };
2795    channel_buffer_updated(
2796        session.connection_id,
2797        collaborators
2798            .iter()
2799            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2800        &update,
2801        &session.peer,
2802    );
2803
2804    Ok(())
2805}
2806
2807/// Edit the channel notes
2808async fn update_channel_buffer(
2809    request: proto::UpdateChannelBuffer,
2810    session: Session,
2811) -> Result<()> {
2812    let db = session.db().await;
2813    let channel_id = ChannelId::from_proto(request.channel_id);
2814
2815    let (collaborators, non_collaborators, epoch, version) = db
2816        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2817        .await?;
2818
2819    channel_buffer_updated(
2820        session.connection_id,
2821        collaborators,
2822        &proto::UpdateChannelBuffer {
2823            channel_id: channel_id.to_proto(),
2824            operations: request.operations,
2825        },
2826        &session.peer,
2827    );
2828
2829    let pool = &*session.connection_pool().await;
2830
2831    broadcast(
2832        None,
2833        non_collaborators
2834            .iter()
2835            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2836        |peer_id| {
2837            session.peer.send(
2838                peer_id.into(),
2839                proto::UpdateChannels {
2840                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2841                        channel_id: channel_id.to_proto(),
2842                        epoch: epoch as u64,
2843                        version: version.clone(),
2844                    }],
2845                    ..Default::default()
2846                },
2847            )
2848        },
2849    );
2850
2851    Ok(())
2852}
2853
2854/// Rejoin the channel notes after a connection blip
2855async fn rejoin_channel_buffers(
2856    request: proto::RejoinChannelBuffers,
2857    response: Response<proto::RejoinChannelBuffers>,
2858    session: Session,
2859) -> Result<()> {
2860    let db = session.db().await;
2861    let buffers = db
2862        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2863        .await?;
2864
2865    for rejoined_buffer in &buffers {
2866        let collaborators_to_notify = rejoined_buffer
2867            .buffer
2868            .collaborators
2869            .iter()
2870            .filter_map(|c| Some(c.peer_id?.into()));
2871        channel_buffer_updated(
2872            session.connection_id,
2873            collaborators_to_notify,
2874            &proto::UpdateChannelBufferCollaborators {
2875                channel_id: rejoined_buffer.buffer.channel_id,
2876                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2877            },
2878            &session.peer,
2879        );
2880    }
2881
2882    response.send(proto::RejoinChannelBuffersResponse {
2883        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2884    })?;
2885
2886    Ok(())
2887}
2888
2889/// Stop editing the channel notes
2890async fn leave_channel_buffer(
2891    request: proto::LeaveChannelBuffer,
2892    response: Response<proto::LeaveChannelBuffer>,
2893    session: Session,
2894) -> Result<()> {
2895    let db = session.db().await;
2896    let channel_id = ChannelId::from_proto(request.channel_id);
2897
2898    let left_buffer = db
2899        .leave_channel_buffer(channel_id, session.connection_id)
2900        .await?;
2901
2902    response.send(Ack {})?;
2903
2904    channel_buffer_updated(
2905        session.connection_id,
2906        left_buffer.connections,
2907        &proto::UpdateChannelBufferCollaborators {
2908            channel_id: channel_id.to_proto(),
2909            collaborators: left_buffer.collaborators,
2910        },
2911        &session.peer,
2912    );
2913
2914    Ok(())
2915}
2916
2917fn channel_buffer_updated<T: EnvelopedMessage>(
2918    sender_id: ConnectionId,
2919    collaborators: impl IntoIterator<Item = ConnectionId>,
2920    message: &T,
2921    peer: &Peer,
2922) {
2923    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2924        peer.send(peer_id.into(), message.clone())
2925    });
2926}
2927
2928fn send_notifications(
2929    connection_pool: &ConnectionPool,
2930    peer: &Peer,
2931    notifications: db::NotificationBatch,
2932) {
2933    for (user_id, notification) in notifications {
2934        for connection_id in connection_pool.user_connection_ids(user_id) {
2935            if let Err(error) = peer.send(
2936                connection_id,
2937                proto::AddNotification {
2938                    notification: Some(notification.clone()),
2939                },
2940            ) {
2941                tracing::error!(
2942                    "failed to send notification to {:?} {}",
2943                    connection_id,
2944                    error
2945                );
2946            }
2947        }
2948    }
2949}
2950
/// Send a message to the channel
///
/// The body is trimmed and must be non-empty and at most `MAX_MESSAGE_LEN`
/// bytes, and the request must carry a nonce. Once the database has stored the
/// message, it is:
/// - sent as `proto::ChannelMessageSent` to the participant connections the
///   database returned (`broadcast` is given the sender's connection id,
///   presumably so the sender is skipped);
/// - returned to the sender in the response;
/// - announced as an unseen message to every connection of each user in
///   `channel_members`.
///
/// Any notifications produced by the database are delivered last.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `len()` counts UTF-8 bytes, not characters.
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and forwarded unchanged. If the
    // mentions refer to positions in the untrimmed body, they may be off after
    // leading whitespace is removed. Confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        channel_members,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    // Deliver the full message to the participant connections.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel members only receive the channel and message id, as an
    // `unseen_channel_messages` entry.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        channel_members
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3039
3040/// Delete a channel message
3041async fn remove_channel_message(
3042    request: proto::RemoveChannelMessage,
3043    response: Response<proto::RemoveChannelMessage>,
3044    session: Session,
3045) -> Result<()> {
3046    let channel_id = ChannelId::from_proto(request.channel_id);
3047    let message_id = MessageId::from_proto(request.message_id);
3048    let connection_ids = session
3049        .db()
3050        .await
3051        .remove_channel_message(channel_id, message_id, session.user_id)
3052        .await?;
3053    broadcast(Some(session.connection_id), connection_ids, |connection| {
3054        session.peer.send(connection, request.clone())
3055    });
3056    response.send(proto::Ack {})?;
3057    Ok(())
3058}
3059
3060/// Mark a channel message as read
3061async fn acknowledge_channel_message(
3062    request: proto::AckChannelMessage,
3063    session: Session,
3064) -> Result<()> {
3065    let channel_id = ChannelId::from_proto(request.channel_id);
3066    let message_id = MessageId::from_proto(request.message_id);
3067    let notifications = session
3068        .db()
3069        .await
3070        .observe_channel_message(channel_id, session.user_id, message_id)
3071        .await?;
3072    send_notifications(
3073        &*session.connection_pool().await,
3074        &session.peer,
3075        notifications,
3076    );
3077    Ok(())
3078}
3079
3080/// Mark a buffer version as synced
3081async fn acknowledge_buffer_version(
3082    request: proto::AckBufferOperation,
3083    session: Session,
3084) -> Result<()> {
3085    let buffer_id = BufferId::from_proto(request.buffer_id);
3086    session
3087        .db()
3088        .await
3089        .observe_buffer_version(
3090            buffer_id,
3091            session.user_id,
3092            request.epoch as i32,
3093            &request.version,
3094        )
3095        .await?;
3096    Ok(())
3097}
3098
3099/// Start receiving chat updates for a channel
3100async fn join_channel_chat(
3101    request: proto::JoinChannelChat,
3102    response: Response<proto::JoinChannelChat>,
3103    session: Session,
3104) -> Result<()> {
3105    let channel_id = ChannelId::from_proto(request.channel_id);
3106
3107    let db = session.db().await;
3108    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3109        .await?;
3110    let messages = db
3111        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3112        .await?;
3113    response.send(proto::JoinChannelChatResponse {
3114        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3115        messages,
3116    })?;
3117    Ok(())
3118}
3119
3120/// Stop receiving chat updates for a channel
3121async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3122    let channel_id = ChannelId::from_proto(request.channel_id);
3123    session
3124        .db()
3125        .await
3126        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3127        .await?;
3128    Ok(())
3129}
3130
3131/// Retrieve the chat history for a channel
3132async fn get_channel_messages(
3133    request: proto::GetChannelMessages,
3134    response: Response<proto::GetChannelMessages>,
3135    session: Session,
3136) -> Result<()> {
3137    let channel_id = ChannelId::from_proto(request.channel_id);
3138    let messages = session
3139        .db()
3140        .await
3141        .get_channel_messages(
3142            channel_id,
3143            session.user_id,
3144            MESSAGE_COUNT_PER_PAGE,
3145            Some(MessageId::from_proto(request.before_message_id)),
3146        )
3147        .await?;
3148    response.send(proto::GetChannelMessagesResponse {
3149        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3150        messages,
3151    })?;
3152    Ok(())
3153}
3154
3155/// Retrieve specific chat messages
3156async fn get_channel_messages_by_id(
3157    request: proto::GetChannelMessagesById,
3158    response: Response<proto::GetChannelMessagesById>,
3159    session: Session,
3160) -> Result<()> {
3161    let message_ids = request
3162        .message_ids
3163        .iter()
3164        .map(|id| MessageId::from_proto(*id))
3165        .collect::<Vec<_>>();
3166    let messages = session
3167        .db()
3168        .await
3169        .get_channel_messages_by_id(session.user_id, &message_ids)
3170        .await?;
3171    response.send(proto::GetChannelMessagesResponse {
3172        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3173        messages,
3174    })?;
3175    Ok(())
3176}
3177
3178/// Retrieve the current users notifications
3179async fn get_notifications(
3180    request: proto::GetNotifications,
3181    response: Response<proto::GetNotifications>,
3182    session: Session,
3183) -> Result<()> {
3184    let notifications = session
3185        .db()
3186        .await
3187        .get_notifications(
3188            session.user_id,
3189            NOTIFICATION_COUNT_PER_PAGE,
3190            request
3191                .before_id
3192                .map(|id| db::NotificationId::from_proto(id)),
3193        )
3194        .await?;
3195    response.send(proto::GetNotificationsResponse {
3196        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3197        notifications,
3198    })?;
3199    Ok(())
3200}
3201
3202/// Mark notifications as read
3203async fn mark_notification_as_read(
3204    request: proto::MarkNotificationRead,
3205    response: Response<proto::MarkNotificationRead>,
3206    session: Session,
3207) -> Result<()> {
3208    let database = &session.db().await;
3209    let notifications = database
3210        .mark_notification_as_read_by_id(
3211            session.user_id,
3212            NotificationId::from_proto(request.notification_id),
3213        )
3214        .await?;
3215    send_notifications(
3216        &*session.connection_pool().await,
3217        &session.peer,
3218        notifications,
3219    );
3220    response.send(proto::Ack {})?;
3221    Ok(())
3222}
3223
3224/// Get the current users information
3225async fn get_private_user_info(
3226    _request: proto::GetPrivateUserInfo,
3227    response: Response<proto::GetPrivateUserInfo>,
3228    session: Session,
3229) -> Result<()> {
3230    let db = session.db().await;
3231
3232    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3233    let user = db
3234        .get_user_by_id(session.user_id)
3235        .await?
3236        .ok_or_else(|| anyhow!("user not found"))?;
3237    let flags = db.get_user_flags(session.user_id).await?;
3238
3239    response.send(proto::GetPrivateUserInfoResponse {
3240        metrics_id,
3241        staff: user.admin,
3242        flags,
3243    })?;
3244    Ok(())
3245}
3246
3247fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3248    match message {
3249        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3250        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3251        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3252        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3253        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3254            code: frame.code.into(),
3255            reason: frame.reason,
3256        })),
3257    }
3258}
3259
3260fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3261    match message {
3262        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3263        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3264        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3265        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3266        AxumMessage::Close(frame) => {
3267            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3268                code: frame.code.into(),
3269                reason: frame.reason,
3270            }))
3271        }
3272    }
3273}
3274
3275fn notify_membership_updated(
3276    connection_pool: &ConnectionPool,
3277    result: MembershipUpdated,
3278    user_id: UserId,
3279    peer: &Peer,
3280) {
3281    let mut update = build_channels_update(result.new_channels, vec![]);
3282    update.delete_channels = result
3283        .removed_channels
3284        .into_iter()
3285        .map(|id| id.to_proto())
3286        .collect();
3287    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3288
3289    for connection_id in connection_pool.user_connection_ids(user_id) {
3290        peer.send(connection_id, update.clone()).trace_err();
3291    }
3292}
3293
3294fn build_channels_update(
3295    channels: ChannelsForUser,
3296    channel_invites: Vec<db::Channel>,
3297) -> proto::UpdateChannels {
3298    let mut update = proto::UpdateChannels::default();
3299
3300    for channel in channels.channels {
3301        update.channels.push(channel.to_proto());
3302    }
3303
3304    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3305    update.unseen_channel_messages = channels.channel_messages;
3306
3307    for (channel_id, participants) in channels.channel_participants {
3308        update
3309            .channel_participants
3310            .push(proto::ChannelParticipants {
3311                channel_id: channel_id.to_proto(),
3312                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3313            });
3314    }
3315
3316    for channel in channel_invites {
3317        update.channel_invitations.push(channel.to_proto());
3318    }
3319
3320    update
3321}
3322
3323fn build_initial_contacts_update(
3324    contacts: Vec<db::Contact>,
3325    pool: &ConnectionPool,
3326) -> proto::UpdateContacts {
3327    let mut update = proto::UpdateContacts::default();
3328
3329    for contact in contacts {
3330        match contact {
3331            db::Contact::Accepted { user_id, busy } => {
3332                update.contacts.push(contact_for_user(user_id, busy, &pool));
3333            }
3334            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3335            db::Contact::Incoming { user_id } => {
3336                update
3337                    .incoming_requests
3338                    .push(proto::IncomingContactRequest {
3339                        requester_id: user_id.to_proto(),
3340                    })
3341            }
3342        }
3343    }
3344
3345    update
3346}
3347
3348fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3349    proto::Contact {
3350        user_id: user_id.to_proto(),
3351        online: pool.is_user_online(user_id),
3352        busy,
3353    }
3354}
3355
3356fn room_updated(room: &proto::Room, peer: &Peer) {
3357    broadcast(
3358        None,
3359        room.participants
3360            .iter()
3361            .filter_map(|participant| Some(participant.peer_id?.into())),
3362        |peer_id| {
3363            peer.send(
3364                peer_id.into(),
3365                proto::RoomUpdated {
3366                    room: Some(room.clone()),
3367                },
3368            )
3369        },
3370    );
3371}
3372
3373fn channel_updated(
3374    channel_id: ChannelId,
3375    room: &proto::Room,
3376    channel_members: &[UserId],
3377    peer: &Peer,
3378    pool: &ConnectionPool,
3379) {
3380    let participants = room
3381        .participants
3382        .iter()
3383        .map(|p| p.user_id)
3384        .collect::<Vec<_>>();
3385
3386    broadcast(
3387        None,
3388        channel_members
3389            .iter()
3390            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3391        |peer_id| {
3392            peer.send(
3393                peer_id.into(),
3394                proto::UpdateChannels {
3395                    channel_participants: vec![proto::ChannelParticipants {
3396                        channel_id: channel_id.to_proto(),
3397                        participant_user_ids: participants.clone(),
3398                    }],
3399                    ..Default::default()
3400                },
3401            )
3402        },
3403    );
3404}
3405
3406async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3407    let db = session.db().await;
3408
3409    let contacts = db.get_contacts(user_id).await?;
3410    let busy = db.is_user_busy(user_id).await?;
3411
3412    let pool = session.connection_pool().await;
3413    let updated_contact = contact_for_user(user_id, busy, &pool);
3414    for contact in contacts {
3415        if let db::Contact::Accepted {
3416            user_id: contact_user_id,
3417            ..
3418        } = contact
3419        {
3420            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3421                session
3422                    .peer
3423                    .send(
3424                        contact_conn_id,
3425                        proto::UpdateContacts {
3426                            contacts: vec![updated_contact.clone()],
3427                            remove_contacts: Default::default(),
3428                            incoming_requests: Default::default(),
3429                            remove_incoming_requests: Default::default(),
3430                            outgoing_requests: Default::default(),
3431                            remove_outgoing_requests: Default::default(),
3432                        },
3433                    )
3434                    .trace_err();
3435            }
3436        }
3437    }
3438    Ok(())
3439}
3440
3441async fn leave_room_for_session(session: &Session) -> Result<()> {
3442    let mut contacts_to_update = HashSet::default();
3443
3444    let room_id;
3445    let canceled_calls_to_user_ids;
3446    let live_kit_room;
3447    let delete_live_kit_room;
3448    let room;
3449    let channel_members;
3450    let channel_id;
3451
3452    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3453        contacts_to_update.insert(session.user_id);
3454
3455        for project in left_room.left_projects.values() {
3456            project_left(project, session);
3457        }
3458
3459        room_id = RoomId::from_proto(left_room.room.id);
3460        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3461        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3462        delete_live_kit_room = left_room.deleted;
3463        room = mem::take(&mut left_room.room);
3464        channel_members = mem::take(&mut left_room.channel_members);
3465        channel_id = left_room.channel_id;
3466
3467        room_updated(&room, &session.peer);
3468    } else {
3469        return Ok(());
3470    }
3471
3472    if let Some(channel_id) = channel_id {
3473        channel_updated(
3474            channel_id,
3475            &room,
3476            &channel_members,
3477            &session.peer,
3478            &*session.connection_pool().await,
3479        );
3480    }
3481
3482    {
3483        let pool = session.connection_pool().await;
3484        for canceled_user_id in canceled_calls_to_user_ids {
3485            for connection_id in pool.user_connection_ids(canceled_user_id) {
3486                session
3487                    .peer
3488                    .send(
3489                        connection_id,
3490                        proto::CallCanceled {
3491                            room_id: room_id.to_proto(),
3492                        },
3493                    )
3494                    .trace_err();
3495            }
3496            contacts_to_update.insert(canceled_user_id);
3497        }
3498    }
3499
3500    for contact_user_id in contacts_to_update {
3501        update_user_contacts(contact_user_id, &session).await?;
3502    }
3503
3504    if let Some(live_kit) = session.live_kit_client.as_ref() {
3505        live_kit
3506            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3507            .await
3508            .trace_err();
3509
3510        if delete_live_kit_room {
3511            live_kit.delete_room(live_kit_room).await.trace_err();
3512        }
3513    }
3514
3515    Ok(())
3516}
3517
3518async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3519    let left_channel_buffers = session
3520        .db()
3521        .await
3522        .leave_channel_buffers(session.connection_id)
3523        .await?;
3524
3525    for left_buffer in left_channel_buffers {
3526        channel_buffer_updated(
3527            session.connection_id,
3528            left_buffer.connections,
3529            &proto::UpdateChannelBufferCollaborators {
3530                channel_id: left_buffer.channel_id.to_proto(),
3531                collaborators: left_buffer.collaborators,
3532            },
3533            &session.peer,
3534        );
3535    }
3536
3537    Ok(())
3538}
3539
3540fn project_left(project: &db::LeftProject, session: &Session) {
3541    for connection_id in &project.connection_ids {
3542        if project.host_user_id == session.user_id {
3543            session
3544                .peer
3545                .send(
3546                    *connection_id,
3547                    proto::UnshareProject {
3548                        project_id: project.id.to_proto(),
3549                    },
3550                )
3551                .trace_err();
3552        } else {
3553            session
3554                .peer
3555                .send(
3556                    *connection_id,
3557                    proto::RemoveProjectCollaborator {
3558                        project_id: project.id.to_proto(),
3559                        peer_id: Some(session.connection_id.into()),
3560                    },
3561                )
3562                .trace_err();
3563        }
3564    }
3565}
3566
3567pub trait ResultExt {
3568    type Ok;
3569
3570    fn trace_err(self) -> Option<Self::Ok>;
3571}
3572
3573impl<T, E> ResultExt for Result<T, E>
3574where
3575    E: std::fmt::Debug,
3576{
3577    type Ok = T;
3578
3579    fn trace_err(self) -> Option<T> {
3580        match self {
3581            Ok(value) => Some(value),
3582            Err(error) => {
3583                tracing::error!("{:?}", error);
3584                None
3585            }
3586        }
3587    }
3588}