rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
/// Presumably the grace period a disconnected client has to reconnect before its
/// session state is discarded. NOTE(review): not referenced in this chunk — confirm at
/// the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the task spawned by `Server::start` sleeps before it begins clearing
/// resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the limits below are not referenced in this chunk; their meanings are
// inferred from their names only — confirm at the use sites.
/// Page size for channel-message queries (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted channel-message length (presumably; unit — bytes vs. chars — unconfirmed).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

lazy_static! {
    // Prometheus gauge named "connections". It is registered lazily on first access,
    // and the `unwrap` panics if registration returns an error.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge named "shared_projects". Same lazy registration and panic-on-error
    // behavior as above.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
 103#[derive(Clone)]
 104struct Session {
 105    user_id: UserId,
 106    connection_id: ConnectionId,
 107    db: Arc<tokio::sync::Mutex<DbHandle>>,
 108    peer: Arc<Peer>,
 109    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 110    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 111    _executor: Executor,
 112}
 113
 114impl Session {
 115    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        let guard = self.db.lock().await;
 119        #[cfg(test)]
 120        tokio::task::yield_now().await;
 121        guard
 122    }
 123
 124    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 125        #[cfg(test)]
 126        tokio::task::yield_now().await;
 127        let guard = self.connection_pool.lock();
 128        ConnectionPoolGuard {
 129            guard,
 130            _not_send: PhantomData,
 131        }
 132    }
 133}
 134
 135impl fmt::Debug for Session {
 136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 137        f.debug_struct("Session")
 138            .field("user_id", &self.user_id)
 139            .field("connection_id", &self.connection_id)
 140            .finish()
 141    }
 142}
 143
 144struct DbHandle(Arc<Database>);
 145
 146impl Deref for DbHandle {
 147    type Target = Database;
 148
 149    fn deref(&self) -> &Self::Target {
 150        self.0.as_ref()
 151    }
 152}
 153
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// registered message handlers.
pub struct Server {
    /// This instance's id. It sits behind a mutex because the test-only `reset`
    /// overwrites it; `start` reads it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to send messages and responses to connected clients.
    peer: Arc<Peer>,
    /// Registry of live connections and the users that own them, shared with every session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signals teardown to every `handle_connection` loop. Each loop subscribes to it and
    /// returns when the value changes.
    teardown: watch::Sender<()>,
}
 163
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Any future that holds it
/// across an `.await` is therefore `!Send` as well.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server: the peer plus the connection pool. The pool stays
/// locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; it is serialized through its deref target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 175
 176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 177where
 178    S: Serializer,
 179    T: Deref<Target = U>,
 180    U: Serialize,
 181{
 182    Serialize::serialize(value.deref(), serializer)
 183}
 184
 185impl Server {
    /// Builds a server with the given `id` and registers every RPC handler.
    ///
    /// Handlers added with `add_request_handler` receive a `Response` and must reply to
    /// the requester. Handlers added with `add_message_handler` receive only the payload
    /// and no responder. Registering two handlers for the same message type panics
    /// (see `add_handler`).
    ///
    /// This does not schedule stale-resource cleanup; call `start` for that.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates silently if the id's inner value can exceed
            // the u32 range — confirm `ServerId` values stay small.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects, worktrees, and language servers.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the `forward_*_project_request` helpers
            // (split into read-only and mutating variants by helper name).
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffer synchronization and broadcasts from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            // Channel reorganization.
            .add_request_handler(move_channel)
            // Following, user info, and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 294
    /// Schedules cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. The task first sleeps for
    /// `CLEANUP_TIMEOUT`. It then fetches the stale room and channel-buffer ids (as
    /// reported by `stale_server_resource_ids` for this environment and `server_id`) and:
    /// - removes stale collaborators from each channel buffer and notifies the remaining
    ///   connections;
    /// - removes stale participants from each room, broadcasts the room/channel update,
    ///   notifies users whose calls were canceled, and pushes refreshed busy status to
    ///   their accepted contacts;
    /// - deletes the LiveKit room of any room left with no participants;
    /// - finally calls `delete_stale_servers`. This last step runs even if fetching the
    ///   stale ids failed.
    ///
    /// Each fallible step is passed through `trace_err` and skipped on failure, so one
    /// error does not abort the remaining cleanup. NOTE(review): `trace_err` presumably
    /// logs the error — confirm.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the connections that remain.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below. If refreshing fails, they
                        // stay empty/false and the later steps do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and callees of canceled calls need
                            // their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of users whose incoming call was canceled.
                        // The scope releases the pool lock before the database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (including busy
                        // status) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once the room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs whether or not the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 446
 447    pub fn teardown(&self) {
 448        self.peer.teardown();
 449        self.connection_pool.lock().reset();
 450        let _ = self.teardown.send(());
 451    }
 452
 453    #[cfg(test)]
 454    pub fn reset(&self, id: ServerId) {
 455        self.teardown();
 456        *self.id.lock() = id;
 457        self.peer.reset(id.0 as u32);
 458    }
 459
 460    #[cfg(test)]
 461    pub fn id(&self) -> ServerId {
 462        *self.id.lock()
 463    }
 464
 465    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 466    where
 467        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 468        Fut: 'static + Send + Future<Output = Result<()>>,
 469        M: EnvelopedMessage,
 470    {
 471        let prev_handler = self.handlers.insert(
 472            TypeId::of::<M>(),
 473            Box::new(move |envelope, session| {
 474                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 475                let span = info_span!(
 476                    "handle message",
 477                    payload_type = envelope.payload_type_name()
 478                );
 479                span.in_scope(|| {
 480                    tracing::info!(
 481                        payload_type = envelope.payload_type_name(),
 482                        "message received"
 483                    );
 484                });
 485                let start_time = Instant::now();
 486                let future = (handler)(*envelope, session);
 487                async move {
 488                    let result = future.await;
 489                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 490                    match result {
 491                        Err(error) => {
 492                            tracing::error!(%error, ?duration_ms, "error handling message")
 493                        }
 494                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 495                    }
 496                }
 497                .instrument(span)
 498                .boxed()
 499            }),
 500        );
 501        if prev_handler.is_some() {
 502            panic!("registered a handler for the same message twice");
 503        }
 504        self
 505    }
 506
 507    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 510        Fut: 'static + Send + Future<Output = Result<()>>,
 511        M: EnvelopedMessage,
 512    {
 513        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 514        self
 515    }
 516
 517    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 518    where
 519        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 520        Fut: Send + Future<Output = Result<()>>,
 521        M: RequestMessage,
 522    {
 523        let handler = Arc::new(handler);
 524        self.add_handler(move |envelope, session| {
 525            let receipt = envelope.receipt();
 526            let handler = handler.clone();
 527            async move {
 528                let peer = session.peer.clone();
 529                let responded = Arc::new(AtomicBool::default());
 530                let response = Response {
 531                    peer: peer.clone(),
 532                    responded: responded.clone(),
 533                    receipt,
 534                };
 535                match (handler)(envelope.payload, response, session).await {
 536                    Ok(()) => {
 537                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 538                            Ok(())
 539                        } else {
 540                            Err(anyhow!("handler did not send a response"))?
 541                        }
 542                    }
 543                    Err(error) => {
 544                        let proto_err = match &error {
 545                            Error::Internal(err) => err.to_proto(),
 546                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 547                        };
 548                        peer.respond_with_error(receipt, proto_err)?;
 549                        Err(error)
 550                    }
 551                }
 552            }
 553        })
 554    }
 555
    /// Drives one client connection until it closes or the server is torn down.
    ///
    /// Steps:
    /// 1. Registers the connection with the peer and sends `Hello`. On a user's first
    ///    connection it also sends `ShowContacts` and records that they have connected.
    /// 2. If `send_connection_id` is provided, sends the assigned `ConnectionId` through it.
    /// 3. Adds the connection to the pool and sends the initial contacts and channels
    ///    state, followed by any pending incoming call.
    /// 4. Dispatches incoming messages to the registered handlers:
    ///    - background messages are spawned on `executor`;
    ///    - foreground ones run concurrently inside this task;
    ///    - at most 256 handlers are in flight at once.
    ///
    /// When the I/O future ends or the stream closes, in-flight foreground handlers are
    /// dropped and `connection_lost` runs for the session. On server teardown the function
    /// returns immediately, without running `connection_lost`.
    ///
    /// `impersonator`, if present, is recorded on the tracing span only.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribe now so a teardown that happens before the future is polled is still seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Load the initial state concurrently; any failure aborts the connection.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial state under a single pool lock.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection. A permit
            // is taken before each message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in order: teardown, then I/O, then progress on
                // existing handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is stored; it is re-attached via
                                // `instrument` so it is active whenever the handler is polled.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing the user out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 695
 696    pub async fn invite_code_redeemed(
 697        self: &Arc<Self>,
 698        inviter_id: UserId,
 699        invitee_id: UserId,
 700    ) -> Result<()> {
 701        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 702            if let Some(code) = &user.invite_code {
 703                let pool = self.connection_pool.lock();
 704                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 705                for connection_id in pool.user_connection_ids(inviter_id) {
 706                    self.peer.send(
 707                        connection_id,
 708                        proto::UpdateContacts {
 709                            contacts: vec![invitee_contact.clone()],
 710                            ..Default::default()
 711                        },
 712                    )?;
 713                    self.peer.send(
 714                        connection_id,
 715                        proto::UpdateInviteInfo {
 716                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 717                            count: user.invite_count as u32,
 718                        },
 719                    )?;
 720                }
 721            }
 722        }
 723        Ok(())
 724    }
 725
 726    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 727        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 728            if let Some(invite_code) = &user.invite_code {
 729                let pool = self.connection_pool.lock();
 730                for connection_id in pool.user_connection_ids(user_id) {
 731                    self.peer.send(
 732                        connection_id,
 733                        proto::UpdateInviteInfo {
 734                            url: format!(
 735                                "{}{}",
 736                                self.app_state.config.invite_link_prefix, invite_code
 737                            ),
 738                            count: user.invite_count as u32,
 739                        },
 740                    )?;
 741                }
 742            }
 743        }
 744        Ok(())
 745    }
 746
 747    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 748        ServerSnapshot {
 749            connection_pool: ConnectionPoolGuard {
 750                guard: self.connection_pool.lock(),
 751                _not_send: PhantomData,
 752            },
 753            peer: &self.peer,
 754        }
 755    }
 756}
 757
 758impl<'a> Deref for ConnectionPoolGuard<'a> {
 759    type Target = ConnectionPool;
 760
 761    fn deref(&self) -> &Self::Target {
 762        &*self.guard
 763    }
 764}
 765
 766impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 767    fn deref_mut(&mut self) -> &mut Self::Target {
 768        &mut *self.guard
 769    }
 770}
 771
 772impl<'a> Drop for ConnectionPoolGuard<'a> {
 773    fn drop(&mut self) {
 774        #[cfg(test)]
 775        self.check_invariants();
 776    }
 777}
 778
 779fn broadcast<F>(
 780    sender_id: Option<ConnectionId>,
 781    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 782    mut f: F,
 783) where
 784    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 785{
 786    for receiver_id in receiver_ids {
 787        if Some(receiver_id) != sender_id {
 788            if let Err(error) = f(receiver_id) {
 789                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 790            }
 791        }
 792    }
 793}
 794
 795lazy_static! {
 796    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 797    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
 798}
 799
 800pub struct ProtocolVersion(u32);
 801
 802impl Header for ProtocolVersion {
 803    fn name() -> &'static HeaderName {
 804        &ZED_PROTOCOL_VERSION
 805    }
 806
 807    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 808    where
 809        Self: Sized,
 810        I: Iterator<Item = &'i axum::http::HeaderValue>,
 811    {
 812        let version = values
 813            .next()
 814            .ok_or_else(axum::headers::Error::invalid)?
 815            .to_str()
 816            .map_err(|_| axum::headers::Error::invalid())?
 817            .parse()
 818            .map_err(|_| axum::headers::Error::invalid())?;
 819        Ok(Self(version))
 820    }
 821
 822    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 823        values.extend([self.0.to_string().parse().unwrap()]);
 824    }
 825}
 826
 827pub struct AppVersionHeader(SemanticVersion);
 828impl Header for AppVersionHeader {
 829    fn name() -> &'static HeaderName {
 830        &ZED_APP_VERSION
 831    }
 832
 833    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 834    where
 835        Self: Sized,
 836        I: Iterator<Item = &'i axum::http::HeaderValue>,
 837    {
 838        let version = values
 839            .next()
 840            .ok_or_else(axum::headers::Error::invalid)?
 841            .to_str()
 842            .map_err(|_| axum::headers::Error::invalid())?
 843            .parse()
 844            .map_err(|_| axum::headers::Error::invalid())?;
 845        Ok(Self(version))
 846    }
 847
 848    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 849        values.extend([self.0.to_string().parse().unwrap()]);
 850    }
 851}
 852
 853pub fn routes(server: Arc<Server>) -> Router<Body> {
 854    Router::new()
 855        .route("/rpc", get(handle_websocket_request))
 856        .layer(
 857            ServiceBuilder::new()
 858                .layer(Extension(server.app_state.clone()))
 859                .layer(middleware::from_fn(auth::validate_header)),
 860        )
 861        .route("/metrics", get(handle_metrics))
 862        .layer(Extension(server))
 863}
 864
 865pub async fn handle_websocket_request(
 866    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 867    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 868    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 869    Extension(server): Extension<Arc<Server>>,
 870    Extension(user): Extension<User>,
 871    Extension(impersonator): Extension<Impersonator>,
 872    ws: WebSocketUpgrade,
 873) -> axum::response::Response {
 874    if protocol_version != rpc::PROTOCOL_VERSION {
 875        return (
 876            StatusCode::UPGRADE_REQUIRED,
 877            "client must be upgraded".to_string(),
 878        )
 879            .into_response();
 880    }
 881
 882    // the first version of zed that sent this header was 0.121.x
 883    if let Some(version) = app_version_header.map(|header| header.0 .0) {
 884        // 0.123.0 was a nightly version with incompatible collab changes
 885        // that were reverted.
 886        if version == "0.123.0".parse().unwrap() {
 887            return (
 888                StatusCode::UPGRADE_REQUIRED,
 889                "client must be upgraded".to_string(),
 890            )
 891                .into_response();
 892        }
 893    }
 894
 895    let socket_address = socket_address.to_string();
 896    ws.on_upgrade(move |socket| {
 897        use util::ResultExt;
 898        let socket = socket
 899            .map_ok(to_tungstenite_message)
 900            .err_into()
 901            .with(|message| async move { Ok(to_axum_message(message)) });
 902        let connection = Connection::new(Box::pin(socket));
 903        async move {
 904            server
 905                .handle_connection(
 906                    connection,
 907                    socket_address,
 908                    user,
 909                    impersonator.0,
 910                    None,
 911                    Executor::Production,
 912                )
 913                .await
 914                .log_err();
 915        }
 916    })
 917}
 918
 919pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 920    let connections = server
 921        .connection_pool
 922        .lock()
 923        .connections()
 924        .filter(|connection| !connection.admin)
 925        .count();
 926
 927    METRIC_CONNECTIONS.set(connections as _);
 928
 929    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 930    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 931
 932    let encoder = prometheus::TextEncoder::new();
 933    let metric_families = prometheus::gather();
 934    let encoded_metrics = encoder
 935        .encode_to_string(&metric_families)
 936        .map_err(|err| anyhow!("{}", err))?;
 937    Ok(encoded_metrics)
 938}
 939
/// Cleans up after a client connection has gone away.
///
/// First, synchronously: the peer is disconnected, the connection is removed
/// from the pool (an error there is returned), and the database is told the
/// connection was lost (an error there is only logged).
///
/// Then it waits for `RECONNECT_TIMEOUT`, presumably a grace period in which
/// the client can reconnect and rejoin. If `teardown` changes first, nothing
/// further happens. Otherwise, this session leaves its room and channel
/// buffers. If the user has no other connection online, `decline_call(None, ..)`
/// is run for them and any returned room is re-broadcast. Finally,
/// `update_user_contacts` runs for the user.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order: the timeout branch first,
    // then `teardown`.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard in this condition is a temporary, dropped before the
            // body runs and calls `session.db()`.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 985
 986/// Acknowledges a ping from a client, used to keep the connection alive.
 987async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 988    response.send(proto::Ack {})?;
 989    Ok(())
 990}
 991
 992/// Creates a new room for calling (outside of channels)
 993async fn create_room(
 994    _request: proto::CreateRoom,
 995    response: Response<proto::CreateRoom>,
 996    session: Session,
 997) -> Result<()> {
 998    let live_kit_room = nanoid::nanoid!(30);
 999
1000    let live_kit_connection_info = {
1001        let live_kit_room = live_kit_room.clone();
1002        let live_kit = session.live_kit_client.as_ref();
1003
1004        util::async_maybe!({
1005            let live_kit = live_kit?;
1006
1007            let token = live_kit
1008                .room_token(&live_kit_room, &session.user_id.to_string())
1009                .trace_err()?;
1010
1011            Some(proto::LiveKitConnectionInfo {
1012                server_url: live_kit.url().into(),
1013                token,
1014                can_publish: true,
1015            })
1016        })
1017    }
1018    .await;
1019
1020    let room = session
1021        .db()
1022        .await
1023        .create_room(session.user_id, session.connection_id, &live_kit_room)
1024        .await?;
1025
1026    response.send(proto::CreateRoomResponse {
1027        room: Some(room.clone()),
1028        live_kit_connection_info,
1029    })?;
1030
1031    update_user_contacts(session.user_id, &session).await?;
1032    Ok(())
1033}
1034
1035/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1036async fn join_room(
1037    request: proto::JoinRoom,
1038    response: Response<proto::JoinRoom>,
1039    session: Session,
1040) -> Result<()> {
1041    let room_id = RoomId::from_proto(request.id);
1042
1043    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1044
1045    if let Some(channel_id) = channel_id {
1046        return join_channel_internal(channel_id, Box::new(response), session).await;
1047    }
1048
1049    let joined_room = {
1050        let room = session
1051            .db()
1052            .await
1053            .join_room(room_id, session.user_id, session.connection_id)
1054            .await?;
1055        room_updated(&room.room, &session.peer);
1056        room.into_inner()
1057    };
1058
1059    for connection_id in session
1060        .connection_pool()
1061        .await
1062        .user_connection_ids(session.user_id)
1063    {
1064        session
1065            .peer
1066            .send(
1067                connection_id,
1068                proto::CallCanceled {
1069                    room_id: room_id.to_proto(),
1070                },
1071            )
1072            .trace_err();
1073    }
1074
1075    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1076        if let Some(token) = live_kit
1077            .room_token(
1078                &joined_room.room.live_kit_room,
1079                &session.user_id.to_string(),
1080            )
1081            .trace_err()
1082        {
1083            Some(proto::LiveKitConnectionInfo {
1084                server_url: live_kit.url().into(),
1085                token,
1086                can_publish: true,
1087            })
1088        } else {
1089            None
1090        }
1091    } else {
1092        None
1093    };
1094
1095    response.send(proto::JoinRoomResponse {
1096        room: Some(joined_room.room),
1097        channel_id: None,
1098        live_kit_connection_info,
1099    })?;
1100
1101    update_user_contacts(session.user_id, &session).await?;
1102    Ok(())
1103}
1104
/// Rejoins a room on a new connection after connection errors.
///
/// Sequence:
/// 1. The database records the rejoin, and the response (room, reshared
///    projects, rejoined projects) is sent to this connection.
/// 2. The room is broadcast.
/// 3. Collaborators on each affected project learn this user's new peer id.
///    For reshared projects they also receive an `UpdateProject`.
/// 4. Every rejoined project's worktrees (entries in chunks, diagnostic
///    summaries, settings files) and language-server state are streamed back
///    to this connection.
/// 5. If the room belongs to a channel, the channel update is sent out.
/// 6. The user's contacts are refreshed.
///
/// Any failed send to this connection aborts with an error. Failed sends to
/// other collaborators are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below, after the database result has been
    // unwrapped with `into_inner()`, and used afterwards.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first with the room and per-project metadata. Worktree
        // contents are streamed separately further down.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects (presumably the ones this user hosts): tell each
        // collaborator the peer id changed, then forward the current worktree
        // list to everyone except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: only the peer-id change needs announcing.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's worktree state back to this
        // connection. The worktrees are moved out with `mem::take` rather than
        // cloned.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in tests/test-support builds, presumably so that
                // chunked delivery is exercised.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` event for every language
            // server in the project.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1296
1297/// leave room disconnects from the room.
1298async fn leave_room(
1299    _: proto::LeaveRoom,
1300    response: Response<proto::LeaveRoom>,
1301    session: Session,
1302) -> Result<()> {
1303    leave_room_for_session(&session).await?;
1304    response.send(proto::Ack {})?;
1305    Ok(())
1306}
1307
1308/// Updates the permissions of someone else in the room.
1309async fn set_room_participant_role(
1310    request: proto::SetRoomParticipantRole,
1311    response: Response<proto::SetRoomParticipantRole>,
1312    session: Session,
1313) -> Result<()> {
1314    let (live_kit_room, can_publish) = {
1315        let room = session
1316            .db()
1317            .await
1318            .set_room_participant_role(
1319                session.user_id,
1320                RoomId::from_proto(request.room_id),
1321                UserId::from_proto(request.user_id),
1322                ChannelRole::from(request.role()),
1323            )
1324            .await?;
1325
1326        let live_kit_room = room.live_kit_room.clone();
1327        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1328        room_updated(&room, &session.peer);
1329        (live_kit_room, can_publish)
1330    };
1331
1332    if let Some(live_kit) = session.live_kit_client.as_ref() {
1333        live_kit
1334            .update_participant(
1335                live_kit_room.clone(),
1336                request.user_id.to_string(),
1337                live_kit_server::proto::ParticipantPermission {
1338                    can_subscribe: true,
1339                    can_publish,
1340                    can_publish_data: can_publish,
1341                    hidden: false,
1342                    recorder: false,
1343                },
1344            )
1345            .await
1346            .trace_err();
1347    }
1348
1349    response.send(proto::Ack {})?;
1350    Ok(())
1351}
1352
1353/// Call someone else into the current room
1354async fn call(
1355    request: proto::Call,
1356    response: Response<proto::Call>,
1357    session: Session,
1358) -> Result<()> {
1359    let room_id = RoomId::from_proto(request.room_id);
1360    let calling_user_id = session.user_id;
1361    let calling_connection_id = session.connection_id;
1362    let called_user_id = UserId::from_proto(request.called_user_id);
1363    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1364    if !session
1365        .db()
1366        .await
1367        .has_contact(calling_user_id, called_user_id)
1368        .await?
1369    {
1370        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1371    }
1372
1373    let incoming_call = {
1374        let (room, incoming_call) = &mut *session
1375            .db()
1376            .await
1377            .call(
1378                room_id,
1379                calling_user_id,
1380                calling_connection_id,
1381                called_user_id,
1382                initial_project_id,
1383            )
1384            .await?;
1385        room_updated(&room, &session.peer);
1386        mem::take(incoming_call)
1387    };
1388    update_user_contacts(called_user_id, &session).await?;
1389
1390    let mut calls = session
1391        .connection_pool()
1392        .await
1393        .user_connection_ids(called_user_id)
1394        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1395        .collect::<FuturesUnordered<_>>();
1396
1397    while let Some(call_response) = calls.next().await {
1398        match call_response.as_ref() {
1399            Ok(_) => {
1400                response.send(proto::Ack {})?;
1401                return Ok(());
1402            }
1403            Err(_) => {
1404                call_response.trace_err();
1405            }
1406        }
1407    }
1408
1409    {
1410        let room = session
1411            .db()
1412            .await
1413            .call_failed(room_id, called_user_id)
1414            .await?;
1415        room_updated(&room, &session.peer);
1416    }
1417    update_user_contacts(called_user_id, &session).await?;
1418
1419    Err(anyhow!("failed to ring user"))?
1420}
1421
1422/// Cancel an outgoing call.
1423async fn cancel_call(
1424    request: proto::CancelCall,
1425    response: Response<proto::CancelCall>,
1426    session: Session,
1427) -> Result<()> {
1428    let called_user_id = UserId::from_proto(request.called_user_id);
1429    let room_id = RoomId::from_proto(request.room_id);
1430    {
1431        let room = session
1432            .db()
1433            .await
1434            .cancel_call(room_id, session.connection_id, called_user_id)
1435            .await?;
1436        room_updated(&room, &session.peer);
1437    }
1438
1439    for connection_id in session
1440        .connection_pool()
1441        .await
1442        .user_connection_ids(called_user_id)
1443    {
1444        session
1445            .peer
1446            .send(
1447                connection_id,
1448                proto::CallCanceled {
1449                    room_id: room_id.to_proto(),
1450                },
1451            )
1452            .trace_err();
1453    }
1454    response.send(proto::Ack {})?;
1455
1456    update_user_contacts(called_user_id, &session).await?;
1457    Ok(())
1458}
1459
1460/// Decline an incoming call.
1461async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1462    let room_id = RoomId::from_proto(message.room_id);
1463    {
1464        let room = session
1465            .db()
1466            .await
1467            .decline_call(Some(room_id), session.user_id)
1468            .await?
1469            .ok_or_else(|| anyhow!("failed to decline call"))?;
1470        room_updated(&room, &session.peer);
1471    }
1472
1473    for connection_id in session
1474        .connection_pool()
1475        .await
1476        .user_connection_ids(session.user_id)
1477    {
1478        session
1479            .peer
1480            .send(
1481                connection_id,
1482                proto::CallCanceled {
1483                    room_id: room_id.to_proto(),
1484                },
1485            )
1486            .trace_err();
1487    }
1488    update_user_contacts(session.user_id, &session).await?;
1489    Ok(())
1490}
1491
1492/// Updates other participants in the room with your current location.
1493async fn update_participant_location(
1494    request: proto::UpdateParticipantLocation,
1495    response: Response<proto::UpdateParticipantLocation>,
1496    session: Session,
1497) -> Result<()> {
1498    let room_id = RoomId::from_proto(request.room_id);
1499    let location = request
1500        .location
1501        .ok_or_else(|| anyhow!("invalid location"))?;
1502
1503    let db = session.db().await;
1504    let room = db
1505        .update_room_participant_location(room_id, session.connection_id, location)
1506        .await?;
1507
1508    room_updated(&room, &session.peer);
1509    response.send(proto::Ack {})?;
1510    Ok(())
1511}
1512
1513/// Share a project into the room.
1514async fn share_project(
1515    request: proto::ShareProject,
1516    response: Response<proto::ShareProject>,
1517    session: Session,
1518) -> Result<()> {
1519    let (project_id, room) = &*session
1520        .db()
1521        .await
1522        .share_project(
1523            RoomId::from_proto(request.room_id),
1524            session.connection_id,
1525            &request.worktrees,
1526        )
1527        .await?;
1528    response.send(proto::ShareProjectResponse {
1529        project_id: project_id.to_proto(),
1530    })?;
1531    room_updated(&room, &session.peer);
1532
1533    Ok(())
1534}
1535
1536/// Unshare a project from the room.
1537async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1538    let project_id = ProjectId::from_proto(message.project_id);
1539
1540    let (room, guest_connection_ids) = &*session
1541        .db()
1542        .await
1543        .unshare_project(project_id, session.connection_id)
1544        .await?;
1545
1546    broadcast(
1547        Some(session.connection_id),
1548        guest_connection_ids.iter().copied(),
1549        |conn_id| session.peer.send(conn_id, message.clone()),
1550    );
1551    room_updated(&room, &session.peer);
1552
1553    Ok(())
1554}
1555
/// Join someone else's shared project.
1557async fn join_project(
1558    request: proto::JoinProject,
1559    response: Response<proto::JoinProject>,
1560    session: Session,
1561) -> Result<()> {
1562    let project_id = ProjectId::from_proto(request.project_id);
1563    let guest_user_id = session.user_id;
1564
1565    tracing::info!(%project_id, "join project");
1566
1567    let (project, replica_id) = &mut *session
1568        .db()
1569        .await
1570        .join_project(project_id, session.connection_id)
1571        .await?;
1572
1573    let collaborators = project
1574        .collaborators
1575        .iter()
1576        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1577        .map(|collaborator| collaborator.to_proto())
1578        .collect::<Vec<_>>();
1579
1580    let worktrees = project
1581        .worktrees
1582        .iter()
1583        .map(|(id, worktree)| proto::WorktreeMetadata {
1584            id: *id,
1585            root_name: worktree.root_name.clone(),
1586            visible: worktree.visible,
1587            abs_path: worktree.abs_path.clone(),
1588        })
1589        .collect::<Vec<_>>();
1590
1591    for collaborator in &collaborators {
1592        session
1593            .peer
1594            .send(
1595                collaborator.peer_id.unwrap().into(),
1596                proto::AddProjectCollaborator {
1597                    project_id: project_id.to_proto(),
1598                    collaborator: Some(proto::Collaborator {
1599                        peer_id: Some(session.connection_id.into()),
1600                        replica_id: replica_id.0 as u32,
1601                        user_id: guest_user_id.to_proto(),
1602                    }),
1603                },
1604            )
1605            .trace_err();
1606    }
1607
1608    // First, we send the metadata associated with each worktree.
1609    response.send(proto::JoinProjectResponse {
1610        worktrees: worktrees.clone(),
1611        replica_id: replica_id.0 as u32,
1612        collaborators: collaborators.clone(),
1613        language_servers: project.language_servers.clone(),
1614    })?;
1615
1616    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1617        #[cfg(any(test, feature = "test-support"))]
1618        const MAX_CHUNK_SIZE: usize = 2;
1619        #[cfg(not(any(test, feature = "test-support")))]
1620        const MAX_CHUNK_SIZE: usize = 256;
1621
1622        // Stream this worktree's entries.
1623        let message = proto::UpdateWorktree {
1624            project_id: project_id.to_proto(),
1625            worktree_id,
1626            abs_path: worktree.abs_path.clone(),
1627            root_name: worktree.root_name,
1628            updated_entries: worktree.entries,
1629            removed_entries: Default::default(),
1630            scan_id: worktree.scan_id,
1631            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1632            updated_repositories: worktree.repository_entries.into_values().collect(),
1633            removed_repositories: Default::default(),
1634        };
1635        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1636            session.peer.send(session.connection_id, update.clone())?;
1637        }
1638
1639        // Stream this worktree's diagnostics.
1640        for summary in worktree.diagnostic_summaries {
1641            session.peer.send(
1642                session.connection_id,
1643                proto::UpdateDiagnosticSummary {
1644                    project_id: project_id.to_proto(),
1645                    worktree_id: worktree.id,
1646                    summary: Some(summary),
1647                },
1648            )?;
1649        }
1650
1651        for settings_file in worktree.settings_files {
1652            session.peer.send(
1653                session.connection_id,
1654                proto::UpdateWorktreeSettings {
1655                    project_id: project_id.to_proto(),
1656                    worktree_id: worktree.id,
1657                    path: settings_file.path,
1658                    content: Some(settings_file.content),
1659                },
1660            )?;
1661        }
1662    }
1663
1664    for language_server in &project.language_servers {
1665        session.peer.send(
1666            session.connection_id,
1667            proto::UpdateLanguageServer {
1668                project_id: project_id.to_proto(),
1669                language_server_id: language_server.id,
1670                variant: Some(
1671                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1672                        proto::LspDiskBasedDiagnosticsUpdated {},
1673                    ),
1674                ),
1675            },
1676        )?;
1677    }
1678
1679    Ok(())
1680}
1681
1682/// Leave someone elses shared project.
1683async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1684    let sender_id = session.connection_id;
1685    let project_id = ProjectId::from_proto(request.project_id);
1686
1687    let (room, project) = &*session
1688        .db()
1689        .await
1690        .leave_project(project_id, sender_id)
1691        .await?;
1692    tracing::info!(
1693        %project_id,
1694        host_user_id = %project.host_user_id,
1695        host_connection_id = ?project.host_connection_id,
1696        "leave project"
1697    );
1698
1699    project_left(&project, &session);
1700    room_updated(&room, &session.peer);
1701
1702    Ok(())
1703}
1704
1705/// Updates other participants with changes to the project
1706async fn update_project(
1707    request: proto::UpdateProject,
1708    response: Response<proto::UpdateProject>,
1709    session: Session,
1710) -> Result<()> {
1711    let project_id = ProjectId::from_proto(request.project_id);
1712    let (room, guest_connection_ids) = &*session
1713        .db()
1714        .await
1715        .update_project(project_id, session.connection_id, &request.worktrees)
1716        .await?;
1717    broadcast(
1718        Some(session.connection_id),
1719        guest_connection_ids.iter().copied(),
1720        |connection_id| {
1721            session
1722                .peer
1723                .forward_send(session.connection_id, connection_id, request.clone())
1724        },
1725    );
1726    room_updated(&room, &session.peer);
1727    response.send(proto::Ack {})?;
1728
1729    Ok(())
1730}
1731
1732/// Updates other participants with changes to the worktree
1733async fn update_worktree(
1734    request: proto::UpdateWorktree,
1735    response: Response<proto::UpdateWorktree>,
1736    session: Session,
1737) -> Result<()> {
1738    let guest_connection_ids = session
1739        .db()
1740        .await
1741        .update_worktree(&request, session.connection_id)
1742        .await?;
1743
1744    broadcast(
1745        Some(session.connection_id),
1746        guest_connection_ids.iter().copied(),
1747        |connection_id| {
1748            session
1749                .peer
1750                .forward_send(session.connection_id, connection_id, request.clone())
1751        },
1752    );
1753    response.send(proto::Ack {})?;
1754    Ok(())
1755}
1756
1757/// Updates other participants with changes to the diagnostics
1758async fn update_diagnostic_summary(
1759    message: proto::UpdateDiagnosticSummary,
1760    session: Session,
1761) -> Result<()> {
1762    let guest_connection_ids = session
1763        .db()
1764        .await
1765        .update_diagnostic_summary(&message, session.connection_id)
1766        .await?;
1767
1768    broadcast(
1769        Some(session.connection_id),
1770        guest_connection_ids.iter().copied(),
1771        |connection_id| {
1772            session
1773                .peer
1774                .forward_send(session.connection_id, connection_id, message.clone())
1775        },
1776    );
1777
1778    Ok(())
1779}
1780
1781/// Updates other participants with changes to the worktree settings
1782async fn update_worktree_settings(
1783    message: proto::UpdateWorktreeSettings,
1784    session: Session,
1785) -> Result<()> {
1786    let guest_connection_ids = session
1787        .db()
1788        .await
1789        .update_worktree_settings(&message, session.connection_id)
1790        .await?;
1791
1792    broadcast(
1793        Some(session.connection_id),
1794        guest_connection_ids.iter().copied(),
1795        |connection_id| {
1796            session
1797                .peer
1798                .forward_send(session.connection_id, connection_id, message.clone())
1799        },
1800    );
1801
1802    Ok(())
1803}
1804
1805/// Notify other participants that a  language server has started.
1806async fn start_language_server(
1807    request: proto::StartLanguageServer,
1808    session: Session,
1809) -> Result<()> {
1810    let guest_connection_ids = session
1811        .db()
1812        .await
1813        .start_language_server(&request, session.connection_id)
1814        .await?;
1815
1816    broadcast(
1817        Some(session.connection_id),
1818        guest_connection_ids.iter().copied(),
1819        |connection_id| {
1820            session
1821                .peer
1822                .forward_send(session.connection_id, connection_id, request.clone())
1823        },
1824    );
1825    Ok(())
1826}
1827
1828/// Notify other participants that a language server has changed.
1829async fn update_language_server(
1830    request: proto::UpdateLanguageServer,
1831    session: Session,
1832) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839    broadcast(
1840        Some(session.connection_id),
1841        project_connection_ids.iter().copied(),
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    Ok(())
1849}
1850
1851/// forward a project request to the host. These requests should be read only
1852/// as guests are allowed to send them.
1853async fn forward_read_only_project_request<T>(
1854    request: T,
1855    response: Response<T>,
1856    session: Session,
1857) -> Result<()>
1858where
1859    T: EntityMessage + RequestMessage,
1860{
1861    let project_id = ProjectId::from_proto(request.remote_entity_id());
1862    let host_connection_id = session
1863        .db()
1864        .await
1865        .host_for_read_only_project_request(project_id, session.connection_id)
1866        .await?;
1867    let payload = session
1868        .peer
1869        .forward_request(session.connection_id, host_connection_id, request)
1870        .await?;
1871    response.send(payload)?;
1872    Ok(())
1873}
1874
1875/// forward a project request to the host. These requests are disallowed
1876/// for guests.
1877async fn forward_mutating_project_request<T>(
1878    request: T,
1879    response: Response<T>,
1880    session: Session,
1881) -> Result<()>
1882where
1883    T: EntityMessage + RequestMessage,
1884{
1885    let project_id = ProjectId::from_proto(request.remote_entity_id());
1886    let host_connection_id = session
1887        .db()
1888        .await
1889        .host_for_mutating_project_request(project_id, session.connection_id)
1890        .await?;
1891    let payload = session
1892        .peer
1893        .forward_request(session.connection_id, host_connection_id, request)
1894        .await?;
1895    response.send(payload)?;
1896    Ok(())
1897}
1898
1899/// Notify other participants that a new buffer has been created
1900async fn create_buffer_for_peer(
1901    request: proto::CreateBufferForPeer,
1902    session: Session,
1903) -> Result<()> {
1904    session
1905        .db()
1906        .await
1907        .check_user_is_project_host(
1908            ProjectId::from_proto(request.project_id),
1909            session.connection_id,
1910        )
1911        .await?;
1912    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1913    session
1914        .peer
1915        .forward_send(session.connection_id, peer_id.into(), request)?;
1916    Ok(())
1917}
1918
1919/// Notify other participants that a buffer has been updated. This is
1920/// allowed for guests as long as the update is limited to selections.
1921async fn update_buffer(
1922    request: proto::UpdateBuffer,
1923    response: Response<proto::UpdateBuffer>,
1924    session: Session,
1925) -> Result<()> {
1926    let project_id = ProjectId::from_proto(request.project_id);
1927    let mut guest_connection_ids;
1928    let mut host_connection_id = None;
1929
1930    let mut requires_write_permission = false;
1931
1932    for op in request.operations.iter() {
1933        match op.variant {
1934            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1935            Some(_) => requires_write_permission = true,
1936        }
1937    }
1938
1939    {
1940        let collaborators = session
1941            .db()
1942            .await
1943            .project_collaborators_for_buffer_update(
1944                project_id,
1945                session.connection_id,
1946                requires_write_permission,
1947            )
1948            .await?;
1949        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1950        for collaborator in collaborators.iter() {
1951            if collaborator.is_host {
1952                host_connection_id = Some(collaborator.connection_id);
1953            } else {
1954                guest_connection_ids.push(collaborator.connection_id);
1955            }
1956        }
1957    }
1958    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1959
1960    broadcast(
1961        Some(session.connection_id),
1962        guest_connection_ids,
1963        |connection_id| {
1964            session
1965                .peer
1966                .forward_send(session.connection_id, connection_id, request.clone())
1967        },
1968    );
1969    if host_connection_id != session.connection_id {
1970        session
1971            .peer
1972            .forward_request(session.connection_id, host_connection_id, request.clone())
1973            .await?;
1974    }
1975
1976    response.send(proto::Ack {})?;
1977    Ok(())
1978}
1979
1980/// Notify other participants that a project has been updated.
1981async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1982    request: T,
1983    session: Session,
1984) -> Result<()> {
1985    let project_id = ProjectId::from_proto(request.remote_entity_id());
1986    let project_connection_ids = session
1987        .db()
1988        .await
1989        .project_connection_ids(project_id, session.connection_id)
1990        .await?;
1991
1992    broadcast(
1993        Some(session.connection_id),
1994        project_connection_ids.iter().copied(),
1995        |connection_id| {
1996            session
1997                .peer
1998                .forward_send(session.connection_id, connection_id, request.clone())
1999        },
2000    );
2001    Ok(())
2002}
2003
2004/// Start following another user in a call.
2005async fn follow(
2006    request: proto::Follow,
2007    response: Response<proto::Follow>,
2008    session: Session,
2009) -> Result<()> {
2010    let room_id = RoomId::from_proto(request.room_id);
2011    let project_id = request.project_id.map(ProjectId::from_proto);
2012    let leader_id = request
2013        .leader_id
2014        .ok_or_else(|| anyhow!("invalid leader id"))?
2015        .into();
2016    let follower_id = session.connection_id;
2017
2018    session
2019        .db()
2020        .await
2021        .check_room_participants(room_id, leader_id, session.connection_id)
2022        .await?;
2023
2024    let response_payload = session
2025        .peer
2026        .forward_request(session.connection_id, leader_id, request)
2027        .await?;
2028    response.send(response_payload)?;
2029
2030    if let Some(project_id) = project_id {
2031        let room = session
2032            .db()
2033            .await
2034            .follow(room_id, project_id, leader_id, follower_id)
2035            .await?;
2036        room_updated(&room, &session.peer);
2037    }
2038
2039    Ok(())
2040}
2041
2042/// Stop following another user in a call.
2043async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2044    let room_id = RoomId::from_proto(request.room_id);
2045    let project_id = request.project_id.map(ProjectId::from_proto);
2046    let leader_id = request
2047        .leader_id
2048        .ok_or_else(|| anyhow!("invalid leader id"))?
2049        .into();
2050    let follower_id = session.connection_id;
2051
2052    session
2053        .db()
2054        .await
2055        .check_room_participants(room_id, leader_id, session.connection_id)
2056        .await?;
2057
2058    session
2059        .peer
2060        .forward_send(session.connection_id, leader_id, request)?;
2061
2062    if let Some(project_id) = project_id {
2063        let room = session
2064            .db()
2065            .await
2066            .unfollow(room_id, project_id, leader_id, follower_id)
2067            .await?;
2068        room_updated(&room, &session.peer);
2069    }
2070
2071    Ok(())
2072}
2073
2074/// Notify everyone following you of your current location.
2075async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2076    let room_id = RoomId::from_proto(request.room_id);
2077    let database = session.db.lock().await;
2078
2079    let connection_ids = if let Some(project_id) = request.project_id {
2080        let project_id = ProjectId::from_proto(project_id);
2081        database
2082            .project_connection_ids(project_id, session.connection_id)
2083            .await?
2084    } else {
2085        database
2086            .room_connection_ids(room_id, session.connection_id)
2087            .await?
2088    };
2089
2090    // For now, don't send view update messages back to that view's current leader.
2091    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2092        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2093        _ => None,
2094    });
2095
2096    for connection_id in connection_ids.iter().cloned() {
2097        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, request.clone())?;
2101        }
2102    }
2103    Ok(())
2104}
2105
2106/// Get public data about users.
2107async fn get_users(
2108    request: proto::GetUsers,
2109    response: Response<proto::GetUsers>,
2110    session: Session,
2111) -> Result<()> {
2112    let user_ids = request
2113        .user_ids
2114        .into_iter()
2115        .map(UserId::from_proto)
2116        .collect();
2117    let users = session
2118        .db()
2119        .await
2120        .get_users_by_ids(user_ids)
2121        .await?
2122        .into_iter()
2123        .map(|user| proto::User {
2124            id: user.id.to_proto(),
2125            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2126            github_login: user.github_login,
2127        })
2128        .collect();
2129    response.send(proto::UsersResponse { users })?;
2130    Ok(())
2131}
2132
2133/// Search for users (to invite) buy Github login
2134async fn fuzzy_search_users(
2135    request: proto::FuzzySearchUsers,
2136    response: Response<proto::FuzzySearchUsers>,
2137    session: Session,
2138) -> Result<()> {
2139    let query = request.query;
2140    let users = match query.len() {
2141        0 => vec![],
2142        1 | 2 => session
2143            .db()
2144            .await
2145            .get_user_by_github_login(&query)
2146            .await?
2147            .into_iter()
2148            .collect(),
2149        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2150    };
2151    let users = users
2152        .into_iter()
2153        .filter(|user| user.id != session.user_id)
2154        .map(|user| proto::User {
2155            id: user.id.to_proto(),
2156            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2157            github_login: user.github_login,
2158        })
2159        .collect();
2160    response.send(proto::UsersResponse { users })?;
2161    Ok(())
2162}
2163
2164/// Send a contact request to another user.
2165async fn request_contact(
2166    request: proto::RequestContact,
2167    response: Response<proto::RequestContact>,
2168    session: Session,
2169) -> Result<()> {
2170    let requester_id = session.user_id;
2171    let responder_id = UserId::from_proto(request.responder_id);
2172    if requester_id == responder_id {
2173        return Err(anyhow!("cannot add yourself as a contact"))?;
2174    }
2175
2176    let notifications = session
2177        .db()
2178        .await
2179        .send_contact_request(requester_id, responder_id)
2180        .await?;
2181
2182    // Update outgoing contact requests of requester
2183    let mut update = proto::UpdateContacts::default();
2184    update.outgoing_requests.push(responder_id.to_proto());
2185    for connection_id in session
2186        .connection_pool()
2187        .await
2188        .user_connection_ids(requester_id)
2189    {
2190        session.peer.send(connection_id, update.clone())?;
2191    }
2192
2193    // Update incoming contact requests of responder
2194    let mut update = proto::UpdateContacts::default();
2195    update
2196        .incoming_requests
2197        .push(proto::IncomingContactRequest {
2198            requester_id: requester_id.to_proto(),
2199        });
2200    let connection_pool = session.connection_pool().await;
2201    for connection_id in connection_pool.user_connection_ids(responder_id) {
2202        session.peer.send(connection_id, update.clone())?;
2203    }
2204
2205    send_notifications(&*connection_pool, &session.peer, notifications);
2206
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
/// Accept or decline a contact request
///
/// `request.response` selects the action:
/// - `Dismiss`: only dismisses the responder's contact notification; no
///   updates are pushed to either user.
/// - `Accept`: resolves the request in the database with `accept = true`.
/// - Any other value: resolves the request with `accept = false`.
///
/// When the request is resolved:
/// - the responder's connections have it removed from their incoming list;
/// - the requester's connections have it removed from their outgoing list;
/// - on acceptance, each side also receives the other as a contact entry,
///   which includes the other user's busy status;
/// - notifications returned by the database are then delivered.
///
/// The caller is acknowledged in every case. The database guard is held for
/// the whole handler.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id;
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than `Accept` is treated as a rejection.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder: add the requester as a contact (if accepted) and
        // clear the incoming request.
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester: add the responder as a contact (if accepted) and
        // clear the outgoing request.
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&*pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2268
/// Remove a contact.
///
/// The database reports whether the relationship had been accepted:
/// - If accepted, both users have each other removed from their contact
///   lists.
/// - Otherwise, the pending request is cancelled. It is removed from the
///   caller's outgoing list and from the other user's incoming list.
///
/// If the database also reports a deleted notification, each of the other
/// user's connections is told to delete it, right after receiving its
/// contacts update.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id;
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact, or the outgoing request if it
    // was never accepted.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the responder: drop the contact, or the incoming request if it
    // was never accepted.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2319
2320/// Creates a new channel.
2321async fn create_channel(
2322    request: proto::CreateChannel,
2323    response: Response<proto::CreateChannel>,
2324    session: Session,
2325) -> Result<()> {
2326    let db = session.db().await;
2327
2328    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2329    let (channel, owner, channel_members) = db
2330        .create_channel(&request.name, parent_id, session.user_id)
2331        .await?;
2332
2333    response.send(proto::CreateChannelResponse {
2334        channel: Some(channel.to_proto()),
2335        parent_id: request.parent_id,
2336    })?;
2337
2338    let connection_pool = session.connection_pool().await;
2339    if let Some(owner) = owner {
2340        let update = proto::UpdateUserChannels {
2341            channel_memberships: vec![proto::ChannelMembership {
2342                channel_id: owner.channel_id.to_proto(),
2343                role: owner.role.into(),
2344            }],
2345            ..Default::default()
2346        };
2347        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2348            session.peer.send(connection_id, update.clone())?;
2349        }
2350    }
2351
2352    for channel_member in channel_members {
2353        if !channel_member.role.can_see_channel(channel.visibility) {
2354            continue;
2355        }
2356
2357        let update = proto::UpdateChannels {
2358            channels: vec![channel.to_proto()],
2359            ..Default::default()
2360        };
2361        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2362            session.peer.send(connection_id, update.clone())?;
2363        }
2364    }
2365
2366    Ok(())
2367}
2368
2369/// Delete a channel
2370async fn delete_channel(
2371    request: proto::DeleteChannel,
2372    response: Response<proto::DeleteChannel>,
2373    session: Session,
2374) -> Result<()> {
2375    let db = session.db().await;
2376
2377    let channel_id = request.channel_id;
2378    let (removed_channels, member_ids) = db
2379        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2380        .await?;
2381    response.send(proto::Ack {})?;
2382
2383    // Notify members of removed channels
2384    let mut update = proto::UpdateChannels::default();
2385    update
2386        .delete_channels
2387        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2388
2389    let connection_pool = session.connection_pool().await;
2390    for member_id in member_ids {
2391        for connection_id in connection_pool.user_connection_ids(member_id) {
2392            session.peer.send(connection_id, update.clone())?;
2393        }
2394    }
2395
2396    Ok(())
2397}
2398
2399/// Invite someone to join a channel.
2400async fn invite_channel_member(
2401    request: proto::InviteChannelMember,
2402    response: Response<proto::InviteChannelMember>,
2403    session: Session,
2404) -> Result<()> {
2405    let db = session.db().await;
2406    let channel_id = ChannelId::from_proto(request.channel_id);
2407    let invitee_id = UserId::from_proto(request.user_id);
2408    let InviteMemberResult {
2409        channel,
2410        notifications,
2411    } = db
2412        .invite_channel_member(
2413            channel_id,
2414            invitee_id,
2415            session.user_id,
2416            request.role().into(),
2417        )
2418        .await?;
2419
2420    let update = proto::UpdateChannels {
2421        channel_invitations: vec![channel.to_proto()],
2422        ..Default::default()
2423    };
2424
2425    let connection_pool = session.connection_pool().await;
2426    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2427        session.peer.send(connection_id, update.clone())?;
2428    }
2429
2430    send_notifications(&*connection_pool, &session.peer, notifications);
2431
2432    response.send(proto::Ack {})?;
2433    Ok(())
2434}
2435
2436/// remove someone from a channel
2437async fn remove_channel_member(
2438    request: proto::RemoveChannelMember,
2439    response: Response<proto::RemoveChannelMember>,
2440    session: Session,
2441) -> Result<()> {
2442    let db = session.db().await;
2443    let channel_id = ChannelId::from_proto(request.channel_id);
2444    let member_id = UserId::from_proto(request.user_id);
2445
2446    let RemoveChannelMemberResult {
2447        membership_update,
2448        notification_id,
2449    } = db
2450        .remove_channel_member(channel_id, member_id, session.user_id)
2451        .await?;
2452
2453    let connection_pool = &session.connection_pool().await;
2454    notify_membership_updated(
2455        &connection_pool,
2456        membership_update,
2457        member_id,
2458        &session.peer,
2459    );
2460    for connection_id in connection_pool.user_connection_ids(member_id) {
2461        if let Some(notification_id) = notification_id {
2462            session
2463                .peer
2464                .send(
2465                    connection_id,
2466                    proto::DeleteNotification {
2467                        notification_id: notification_id.to_proto(),
2468                    },
2469                )
2470                .trace_err();
2471        }
2472    }
2473
2474    response.send(proto::Ack {})?;
2475    Ok(())
2476}
2477
2478/// Toggle the channel between public and private.
2479/// Care is taken to maintain the invariant that public channels only descend from public channels,
2480/// (though members-only channels can appear at any point in the hierarchy).
2481async fn set_channel_visibility(
2482    request: proto::SetChannelVisibility,
2483    response: Response<proto::SetChannelVisibility>,
2484    session: Session,
2485) -> Result<()> {
2486    let db = session.db().await;
2487    let channel_id = ChannelId::from_proto(request.channel_id);
2488    let visibility = request.visibility().into();
2489
2490    let (channel, channel_members) = db
2491        .set_channel_visibility(channel_id, visibility, session.user_id)
2492        .await?;
2493
2494    let connection_pool = session.connection_pool().await;
2495    for member in channel_members {
2496        let update = if member.role.can_see_channel(channel.visibility) {
2497            proto::UpdateChannels {
2498                channels: vec![channel.to_proto()],
2499                ..Default::default()
2500            }
2501        } else {
2502            proto::UpdateChannels {
2503                delete_channels: vec![channel.id.to_proto()],
2504                ..Default::default()
2505            }
2506        };
2507
2508        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2509            session.peer.send(connection_id, update.clone())?;
2510        }
2511    }
2512
2513    response.send(proto::Ack {})?;
2514    Ok(())
2515}
2516
2517/// Alter the role for a user in the channel.
2518async fn set_channel_member_role(
2519    request: proto::SetChannelMemberRole,
2520    response: Response<proto::SetChannelMemberRole>,
2521    session: Session,
2522) -> Result<()> {
2523    let db = session.db().await;
2524    let channel_id = ChannelId::from_proto(request.channel_id);
2525    let member_id = UserId::from_proto(request.user_id);
2526    let result = db
2527        .set_channel_member_role(
2528            channel_id,
2529            session.user_id,
2530            member_id,
2531            request.role().into(),
2532        )
2533        .await?;
2534
2535    match result {
2536        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2537            let connection_pool = session.connection_pool().await;
2538            notify_membership_updated(
2539                &connection_pool,
2540                membership_update,
2541                member_id,
2542                &session.peer,
2543            )
2544        }
2545        db::SetMemberRoleResult::InviteUpdated(channel) => {
2546            let update = proto::UpdateChannels {
2547                channel_invitations: vec![channel.to_proto()],
2548                ..Default::default()
2549            };
2550
2551            for connection_id in session
2552                .connection_pool()
2553                .await
2554                .user_connection_ids(member_id)
2555            {
2556                session.peer.send(connection_id, update.clone())?;
2557            }
2558        }
2559    }
2560
2561    response.send(proto::Ack {})?;
2562    Ok(())
2563}
2564
2565/// Change the name of a channel
2566async fn rename_channel(
2567    request: proto::RenameChannel,
2568    response: Response<proto::RenameChannel>,
2569    session: Session,
2570) -> Result<()> {
2571    let db = session.db().await;
2572    let channel_id = ChannelId::from_proto(request.channel_id);
2573    let (channel, channel_members) = db
2574        .rename_channel(channel_id, session.user_id, &request.name)
2575        .await?;
2576
2577    response.send(proto::RenameChannelResponse {
2578        channel: Some(channel.to_proto()),
2579    })?;
2580
2581    let connection_pool = session.connection_pool().await;
2582    for channel_member in channel_members {
2583        if !channel_member.role.can_see_channel(channel.visibility) {
2584            continue;
2585        }
2586        let update = proto::UpdateChannels {
2587            channels: vec![channel.to_proto()],
2588            ..Default::default()
2589        };
2590        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2591            session.peer.send(connection_id, update.clone())?;
2592        }
2593    }
2594
2595    Ok(())
2596}
2597
2598/// Move a channel to a new parent.
2599async fn move_channel(
2600    request: proto::MoveChannel,
2601    response: Response<proto::MoveChannel>,
2602    session: Session,
2603) -> Result<()> {
2604    let channel_id = ChannelId::from_proto(request.channel_id);
2605    let to = ChannelId::from_proto(request.to);
2606
2607    let (channels, channel_members) = session
2608        .db()
2609        .await
2610        .move_channel(channel_id, to, session.user_id)
2611        .await?;
2612
2613    let connection_pool = session.connection_pool().await;
2614    for member in channel_members {
2615        let channels = channels
2616            .iter()
2617            .filter_map(|channel| {
2618                if member.role.can_see_channel(channel.visibility) {
2619                    Some(channel.to_proto())
2620                } else {
2621                    None
2622                }
2623            })
2624            .collect::<Vec<_>>();
2625        if channels.is_empty() {
2626            continue;
2627        }
2628
2629        let update = proto::UpdateChannels {
2630            channels,
2631            ..Default::default()
2632        };
2633
2634        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2635            session.peer.send(connection_id, update.clone())?;
2636        }
2637    }
2638
2639    response.send(Ack {})?;
2640    Ok(())
2641}
2642
2643/// Get the list of channel members
2644async fn get_channel_members(
2645    request: proto::GetChannelMembers,
2646    response: Response<proto::GetChannelMembers>,
2647    session: Session,
2648) -> Result<()> {
2649    let db = session.db().await;
2650    let channel_id = ChannelId::from_proto(request.channel_id);
2651    let members = db
2652        .get_channel_participant_details(channel_id, session.user_id)
2653        .await?;
2654    response.send(proto::GetChannelMembersResponse { members })?;
2655    Ok(())
2656}
2657
2658/// Accept or decline a channel invitation.
2659async fn respond_to_channel_invite(
2660    request: proto::RespondToChannelInvite,
2661    response: Response<proto::RespondToChannelInvite>,
2662    session: Session,
2663) -> Result<()> {
2664    let db = session.db().await;
2665    let channel_id = ChannelId::from_proto(request.channel_id);
2666    let RespondToChannelInvite {
2667        membership_update,
2668        notifications,
2669    } = db
2670        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2671        .await?;
2672
2673    let connection_pool = session.connection_pool().await;
2674    if let Some(membership_update) = membership_update {
2675        notify_membership_updated(
2676            &connection_pool,
2677            membership_update,
2678            session.user_id,
2679            &session.peer,
2680        );
2681    } else {
2682        let update = proto::UpdateChannels {
2683            remove_channel_invitations: vec![channel_id.to_proto()],
2684            ..Default::default()
2685        };
2686
2687        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2688            session.peer.send(connection_id, update.clone())?;
2689        }
2690    };
2691
2692    send_notifications(&*connection_pool, &session.peer, notifications);
2693
2694    response.send(proto::Ack {})?;
2695
2696    Ok(())
2697}
2698
2699/// Join the channels' room
2700async fn join_channel(
2701    request: proto::JoinChannel,
2702    response: Response<proto::JoinChannel>,
2703    session: Session,
2704) -> Result<()> {
2705    let channel_id = ChannelId::from_proto(request.channel_id);
2706    join_channel_internal(channel_id, Box::new(response), session).await
2707}
2708
2709trait JoinChannelInternalResponse {
2710    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2711}
2712impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2713    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2714        Response::<proto::JoinChannel>::send(self, result)
2715    }
2716}
2717impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2718    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2719        Response::<proto::JoinRoom>::send(self, result)
2720    }
2721}
2722
/// Join the room belonging to `channel_id` for the current session and reply
/// through `response`.
///
/// `response` is any [`JoinChannelInternalResponse`], so the same logic can
/// answer both `JoinChannel` and `JoinRoom` requests.
///
/// Errors are handled as follows:
/// - Failures while leaving the previous room, joining in the database,
///   sending the reply, or refreshing contacts are propagated with `?`.
/// - A failure to mint a LiveKit token is not fatal. It goes through
///   `trace_err()`, and the reply is sent with `live_kit_connection_info`
///   set to `None`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` handle and the connection-pool guard
    // (presumably both lock guards) so they are released before
    // `channel_updated` re-acquires the pool and `update_user_contacts` calls
    // `session.db()` again.
    let joined_room = {
        // NOTE(review): this runs before `db` is bound below, presumably
        // because `leave_room_for_session` takes the database itself — confirm
        // before reordering.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // LiveKit credentials are only produced when a LiveKit client is
        // configured. Guests get a guest token with `can_publish = false`;
        // every other role gets a room token with `can_publish = true`. A
        // token error makes the closure yield `None` via `trace_err()?`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before fanning out updates to others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        // Only present when the join changed the user's channel membership.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Tell everyone already in the room about the new participant list.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the channel's current participants to all channel members.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2798
2799/// Start editing the channel notes
2800async fn join_channel_buffer(
2801    request: proto::JoinChannelBuffer,
2802    response: Response<proto::JoinChannelBuffer>,
2803    session: Session,
2804) -> Result<()> {
2805    let db = session.db().await;
2806    let channel_id = ChannelId::from_proto(request.channel_id);
2807
2808    let open_response = db
2809        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2810        .await?;
2811
2812    let collaborators = open_response.collaborators.clone();
2813    response.send(open_response)?;
2814
2815    let update = UpdateChannelBufferCollaborators {
2816        channel_id: channel_id.to_proto(),
2817        collaborators: collaborators.clone(),
2818    };
2819    channel_buffer_updated(
2820        session.connection_id,
2821        collaborators
2822            .iter()
2823            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2824        &update,
2825        &session.peer,
2826    );
2827
2828    Ok(())
2829}
2830
2831/// Edit the channel notes
2832async fn update_channel_buffer(
2833    request: proto::UpdateChannelBuffer,
2834    session: Session,
2835) -> Result<()> {
2836    let db = session.db().await;
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838
2839    let (collaborators, non_collaborators, epoch, version) = db
2840        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2841        .await?;
2842
2843    channel_buffer_updated(
2844        session.connection_id,
2845        collaborators,
2846        &proto::UpdateChannelBuffer {
2847            channel_id: channel_id.to_proto(),
2848            operations: request.operations,
2849        },
2850        &session.peer,
2851    );
2852
2853    let pool = &*session.connection_pool().await;
2854
2855    broadcast(
2856        None,
2857        non_collaborators
2858            .iter()
2859            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2860        |peer_id| {
2861            session.peer.send(
2862                peer_id.into(),
2863                proto::UpdateChannels {
2864                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2865                        channel_id: channel_id.to_proto(),
2866                        epoch: epoch as u64,
2867                        version: version.clone(),
2868                    }],
2869                    ..Default::default()
2870                },
2871            )
2872        },
2873    );
2874
2875    Ok(())
2876}
2877
2878/// Rejoin the channel notes after a connection blip
2879async fn rejoin_channel_buffers(
2880    request: proto::RejoinChannelBuffers,
2881    response: Response<proto::RejoinChannelBuffers>,
2882    session: Session,
2883) -> Result<()> {
2884    let db = session.db().await;
2885    let buffers = db
2886        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2887        .await?;
2888
2889    for rejoined_buffer in &buffers {
2890        let collaborators_to_notify = rejoined_buffer
2891            .buffer
2892            .collaborators
2893            .iter()
2894            .filter_map(|c| Some(c.peer_id?.into()));
2895        channel_buffer_updated(
2896            session.connection_id,
2897            collaborators_to_notify,
2898            &proto::UpdateChannelBufferCollaborators {
2899                channel_id: rejoined_buffer.buffer.channel_id,
2900                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2901            },
2902            &session.peer,
2903        );
2904    }
2905
2906    response.send(proto::RejoinChannelBuffersResponse {
2907        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2908    })?;
2909
2910    Ok(())
2911}
2912
2913/// Stop editing the channel notes
2914async fn leave_channel_buffer(
2915    request: proto::LeaveChannelBuffer,
2916    response: Response<proto::LeaveChannelBuffer>,
2917    session: Session,
2918) -> Result<()> {
2919    let db = session.db().await;
2920    let channel_id = ChannelId::from_proto(request.channel_id);
2921
2922    let left_buffer = db
2923        .leave_channel_buffer(channel_id, session.connection_id)
2924        .await?;
2925
2926    response.send(Ack {})?;
2927
2928    channel_buffer_updated(
2929        session.connection_id,
2930        left_buffer.connections,
2931        &proto::UpdateChannelBufferCollaborators {
2932            channel_id: channel_id.to_proto(),
2933            collaborators: left_buffer.collaborators,
2934        },
2935        &session.peer,
2936    );
2937
2938    Ok(())
2939}
2940
2941fn channel_buffer_updated<T: EnvelopedMessage>(
2942    sender_id: ConnectionId,
2943    collaborators: impl IntoIterator<Item = ConnectionId>,
2944    message: &T,
2945    peer: &Peer,
2946) {
2947    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2948        peer.send(peer_id.into(), message.clone())
2949    });
2950}
2951
2952fn send_notifications(
2953    connection_pool: &ConnectionPool,
2954    peer: &Peer,
2955    notifications: db::NotificationBatch,
2956) {
2957    for (user_id, notification) in notifications {
2958        for connection_id in connection_pool.user_connection_ids(user_id) {
2959            if let Err(error) = peer.send(
2960                connection_id,
2961                proto::AddNotification {
2962                    notification: Some(notification.clone()),
2963                },
2964            ) {
2965                tracing::error!(
2966                    "failed to send notification to {:?} {}",
2967                    connection_id,
2968                    error
2969                );
2970            }
2971        }
2972    }
2973}
2974
2975/// Send a message to the channel
2976async fn send_channel_message(
2977    request: proto::SendChannelMessage,
2978    response: Response<proto::SendChannelMessage>,
2979    session: Session,
2980) -> Result<()> {
2981    // Validate the message body.
2982    let body = request.body.trim().to_string();
2983    if body.len() > MAX_MESSAGE_LEN {
2984        return Err(anyhow!("message is too long"))?;
2985    }
2986    if body.is_empty() {
2987        return Err(anyhow!("message can't be blank"))?;
2988    }
2989
2990    // TODO: adjust mentions if body is trimmed
2991
2992    let timestamp = OffsetDateTime::now_utc();
2993    let nonce = request
2994        .nonce
2995        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2996
2997    let channel_id = ChannelId::from_proto(request.channel_id);
2998    let CreatedChannelMessage {
2999        message_id,
3000        participant_connection_ids,
3001        channel_members,
3002        notifications,
3003    } = session
3004        .db()
3005        .await
3006        .create_channel_message(
3007            channel_id,
3008            session.user_id,
3009            &body,
3010            &request.mentions,
3011            timestamp,
3012            nonce.clone().into(),
3013            match request.reply_to_message_id {
3014                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3015                None => None,
3016            },
3017        )
3018        .await?;
3019    let message = proto::ChannelMessage {
3020        sender_id: session.user_id.to_proto(),
3021        id: message_id.to_proto(),
3022        body,
3023        mentions: request.mentions,
3024        timestamp: timestamp.unix_timestamp() as u64,
3025        nonce: Some(nonce),
3026        reply_to_message_id: request.reply_to_message_id,
3027    };
3028    broadcast(
3029        Some(session.connection_id),
3030        participant_connection_ids,
3031        |connection| {
3032            session.peer.send(
3033                connection,
3034                proto::ChannelMessageSent {
3035                    channel_id: channel_id.to_proto(),
3036                    message: Some(message.clone()),
3037                },
3038            )
3039        },
3040    );
3041    response.send(proto::SendChannelMessageResponse {
3042        message: Some(message),
3043    })?;
3044
3045    let pool = &*session.connection_pool().await;
3046    broadcast(
3047        None,
3048        channel_members
3049            .iter()
3050            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3051        |peer_id| {
3052            session.peer.send(
3053                peer_id.into(),
3054                proto::UpdateChannels {
3055                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3056                        channel_id: channel_id.to_proto(),
3057                        message_id: message_id.to_proto(),
3058                    }],
3059                    ..Default::default()
3060                },
3061            )
3062        },
3063    );
3064    send_notifications(pool, &session.peer, notifications);
3065
3066    Ok(())
3067}
3068
3069/// Delete a channel message
3070async fn remove_channel_message(
3071    request: proto::RemoveChannelMessage,
3072    response: Response<proto::RemoveChannelMessage>,
3073    session: Session,
3074) -> Result<()> {
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    let message_id = MessageId::from_proto(request.message_id);
3077    let connection_ids = session
3078        .db()
3079        .await
3080        .remove_channel_message(channel_id, message_id, session.user_id)
3081        .await?;
3082    broadcast(Some(session.connection_id), connection_ids, |connection| {
3083        session.peer.send(connection, request.clone())
3084    });
3085    response.send(proto::Ack {})?;
3086    Ok(())
3087}
3088
3089/// Mark a channel message as read
3090async fn acknowledge_channel_message(
3091    request: proto::AckChannelMessage,
3092    session: Session,
3093) -> Result<()> {
3094    let channel_id = ChannelId::from_proto(request.channel_id);
3095    let message_id = MessageId::from_proto(request.message_id);
3096    let notifications = session
3097        .db()
3098        .await
3099        .observe_channel_message(channel_id, session.user_id, message_id)
3100        .await?;
3101    send_notifications(
3102        &*session.connection_pool().await,
3103        &session.peer,
3104        notifications,
3105    );
3106    Ok(())
3107}
3108
3109/// Mark a buffer version as synced
3110async fn acknowledge_buffer_version(
3111    request: proto::AckBufferOperation,
3112    session: Session,
3113) -> Result<()> {
3114    let buffer_id = BufferId::from_proto(request.buffer_id);
3115    session
3116        .db()
3117        .await
3118        .observe_buffer_version(
3119            buffer_id,
3120            session.user_id,
3121            request.epoch as i32,
3122            &request.version,
3123        )
3124        .await?;
3125    Ok(())
3126}
3127
3128/// Start receiving chat updates for a channel
3129async fn join_channel_chat(
3130    request: proto::JoinChannelChat,
3131    response: Response<proto::JoinChannelChat>,
3132    session: Session,
3133) -> Result<()> {
3134    let channel_id = ChannelId::from_proto(request.channel_id);
3135
3136    let db = session.db().await;
3137    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3138        .await?;
3139    let messages = db
3140        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3141        .await?;
3142    response.send(proto::JoinChannelChatResponse {
3143        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3144        messages,
3145    })?;
3146    Ok(())
3147}
3148
3149/// Stop receiving chat updates for a channel
3150async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3151    let channel_id = ChannelId::from_proto(request.channel_id);
3152    session
3153        .db()
3154        .await
3155        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3156        .await?;
3157    Ok(())
3158}
3159
3160/// Retrieve the chat history for a channel
3161async fn get_channel_messages(
3162    request: proto::GetChannelMessages,
3163    response: Response<proto::GetChannelMessages>,
3164    session: Session,
3165) -> Result<()> {
3166    let channel_id = ChannelId::from_proto(request.channel_id);
3167    let messages = session
3168        .db()
3169        .await
3170        .get_channel_messages(
3171            channel_id,
3172            session.user_id,
3173            MESSAGE_COUNT_PER_PAGE,
3174            Some(MessageId::from_proto(request.before_message_id)),
3175        )
3176        .await?;
3177    response.send(proto::GetChannelMessagesResponse {
3178        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3179        messages,
3180    })?;
3181    Ok(())
3182}
3183
3184/// Retrieve specific chat messages
3185async fn get_channel_messages_by_id(
3186    request: proto::GetChannelMessagesById,
3187    response: Response<proto::GetChannelMessagesById>,
3188    session: Session,
3189) -> Result<()> {
3190    let message_ids = request
3191        .message_ids
3192        .iter()
3193        .map(|id| MessageId::from_proto(*id))
3194        .collect::<Vec<_>>();
3195    let messages = session
3196        .db()
3197        .await
3198        .get_channel_messages_by_id(session.user_id, &message_ids)
3199        .await?;
3200    response.send(proto::GetChannelMessagesResponse {
3201        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3202        messages,
3203    })?;
3204    Ok(())
3205}
3206
3207/// Retrieve the current users notifications
3208async fn get_notifications(
3209    request: proto::GetNotifications,
3210    response: Response<proto::GetNotifications>,
3211    session: Session,
3212) -> Result<()> {
3213    let notifications = session
3214        .db()
3215        .await
3216        .get_notifications(
3217            session.user_id,
3218            NOTIFICATION_COUNT_PER_PAGE,
3219            request
3220                .before_id
3221                .map(|id| db::NotificationId::from_proto(id)),
3222        )
3223        .await?;
3224    response.send(proto::GetNotificationsResponse {
3225        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3226        notifications,
3227    })?;
3228    Ok(())
3229}
3230
3231/// Mark notifications as read
3232async fn mark_notification_as_read(
3233    request: proto::MarkNotificationRead,
3234    response: Response<proto::MarkNotificationRead>,
3235    session: Session,
3236) -> Result<()> {
3237    let database = &session.db().await;
3238    let notifications = database
3239        .mark_notification_as_read_by_id(
3240            session.user_id,
3241            NotificationId::from_proto(request.notification_id),
3242        )
3243        .await?;
3244    send_notifications(
3245        &*session.connection_pool().await,
3246        &session.peer,
3247        notifications,
3248    );
3249    response.send(proto::Ack {})?;
3250    Ok(())
3251}
3252
3253/// Get the current users information
3254async fn get_private_user_info(
3255    _request: proto::GetPrivateUserInfo,
3256    response: Response<proto::GetPrivateUserInfo>,
3257    session: Session,
3258) -> Result<()> {
3259    let db = session.db().await;
3260
3261    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3262    let user = db
3263        .get_user_by_id(session.user_id)
3264        .await?
3265        .ok_or_else(|| anyhow!("user not found"))?;
3266    let flags = db.get_user_flags(session.user_id).await?;
3267
3268    response.send(proto::GetPrivateUserInfoResponse {
3269        metrics_id,
3270        staff: user.admin,
3271        flags,
3272    })?;
3273    Ok(())
3274}
3275
3276fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3277    match message {
3278        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3279        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3280        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3281        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3282        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3283            code: frame.code.into(),
3284            reason: frame.reason,
3285        })),
3286    }
3287}
3288
3289fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3290    match message {
3291        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3292        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3293        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3294        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3295        AxumMessage::Close(frame) => {
3296            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3297                code: frame.code.into(),
3298                reason: frame.reason,
3299            }))
3300        }
3301    }
3302}
3303
3304fn notify_membership_updated(
3305    connection_pool: &ConnectionPool,
3306    result: MembershipUpdated,
3307    user_id: UserId,
3308    peer: &Peer,
3309) {
3310    let user_channels_update = proto::UpdateUserChannels {
3311        channel_memberships: result
3312            .new_channels
3313            .channel_memberships
3314            .iter()
3315            .map(|cm| proto::ChannelMembership {
3316                channel_id: cm.channel_id.to_proto(),
3317                role: cm.role.into(),
3318            })
3319            .collect(),
3320        ..Default::default()
3321    };
3322    let mut update = build_channels_update(result.new_channels, vec![]);
3323    update.delete_channels = result
3324        .removed_channels
3325        .into_iter()
3326        .map(|id| id.to_proto())
3327        .collect();
3328    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3329
3330    for connection_id in connection_pool.user_connection_ids(user_id) {
3331        peer.send(connection_id, user_channels_update.clone())
3332            .trace_err();
3333        peer.send(connection_id, update.clone()).trace_err();
3334    }
3335}
3336
3337fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3338    proto::UpdateUserChannels {
3339        channel_memberships: channels
3340            .channel_memberships
3341            .iter()
3342            .map(|m| proto::ChannelMembership {
3343                channel_id: m.channel_id.to_proto(),
3344                role: m.role.into(),
3345            })
3346            .collect(),
3347        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3348        observed_channel_message_id: channels.observed_channel_messages.clone(),
3349        ..Default::default()
3350    }
3351}
3352
3353fn build_channels_update(
3354    channels: ChannelsForUser,
3355    channel_invites: Vec<db::Channel>,
3356) -> proto::UpdateChannels {
3357    let mut update = proto::UpdateChannels::default();
3358
3359    for channel in channels.channels {
3360        update.channels.push(channel.to_proto());
3361    }
3362
3363    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3364    update.latest_channel_message_ids = channels.latest_channel_messages;
3365
3366    for (channel_id, participants) in channels.channel_participants {
3367        update
3368            .channel_participants
3369            .push(proto::ChannelParticipants {
3370                channel_id: channel_id.to_proto(),
3371                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3372            });
3373    }
3374
3375    for channel in channel_invites {
3376        update.channel_invitations.push(channel.to_proto());
3377    }
3378
3379    update
3380}
3381
3382fn build_initial_contacts_update(
3383    contacts: Vec<db::Contact>,
3384    pool: &ConnectionPool,
3385) -> proto::UpdateContacts {
3386    let mut update = proto::UpdateContacts::default();
3387
3388    for contact in contacts {
3389        match contact {
3390            db::Contact::Accepted { user_id, busy } => {
3391                update.contacts.push(contact_for_user(user_id, busy, &pool));
3392            }
3393            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3394            db::Contact::Incoming { user_id } => {
3395                update
3396                    .incoming_requests
3397                    .push(proto::IncomingContactRequest {
3398                        requester_id: user_id.to_proto(),
3399                    })
3400            }
3401        }
3402    }
3403
3404    update
3405}
3406
3407fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3408    proto::Contact {
3409        user_id: user_id.to_proto(),
3410        online: pool.is_user_online(user_id),
3411        busy,
3412    }
3413}
3414
3415fn room_updated(room: &proto::Room, peer: &Peer) {
3416    broadcast(
3417        None,
3418        room.participants
3419            .iter()
3420            .filter_map(|participant| Some(participant.peer_id?.into())),
3421        |peer_id| {
3422            peer.send(
3423                peer_id.into(),
3424                proto::RoomUpdated {
3425                    room: Some(room.clone()),
3426                },
3427            )
3428        },
3429    );
3430}
3431
3432fn channel_updated(
3433    channel_id: ChannelId,
3434    room: &proto::Room,
3435    channel_members: &[UserId],
3436    peer: &Peer,
3437    pool: &ConnectionPool,
3438) {
3439    let participants = room
3440        .participants
3441        .iter()
3442        .map(|p| p.user_id)
3443        .collect::<Vec<_>>();
3444
3445    broadcast(
3446        None,
3447        channel_members
3448            .iter()
3449            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3450        |peer_id| {
3451            peer.send(
3452                peer_id.into(),
3453                proto::UpdateChannels {
3454                    channel_participants: vec![proto::ChannelParticipants {
3455                        channel_id: channel_id.to_proto(),
3456                        participant_user_ids: participants.clone(),
3457                    }],
3458                    ..Default::default()
3459                },
3460            )
3461        },
3462    );
3463}
3464
3465async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3466    let db = session.db().await;
3467
3468    let contacts = db.get_contacts(user_id).await?;
3469    let busy = db.is_user_busy(user_id).await?;
3470
3471    let pool = session.connection_pool().await;
3472    let updated_contact = contact_for_user(user_id, busy, &pool);
3473    for contact in contacts {
3474        if let db::Contact::Accepted {
3475            user_id: contact_user_id,
3476            ..
3477        } = contact
3478        {
3479            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3480                session
3481                    .peer
3482                    .send(
3483                        contact_conn_id,
3484                        proto::UpdateContacts {
3485                            contacts: vec![updated_contact.clone()],
3486                            remove_contacts: Default::default(),
3487                            incoming_requests: Default::default(),
3488                            remove_incoming_requests: Default::default(),
3489                            outgoing_requests: Default::default(),
3490                            remove_outgoing_requests: Default::default(),
3491                        },
3492                    )
3493                    .trace_err();
3494            }
3495        }
3496    }
3497    Ok(())
3498}
3499
3500async fn leave_room_for_session(session: &Session) -> Result<()> {
3501    let mut contacts_to_update = HashSet::default();
3502
3503    let room_id;
3504    let canceled_calls_to_user_ids;
3505    let live_kit_room;
3506    let delete_live_kit_room;
3507    let room;
3508    let channel_members;
3509    let channel_id;
3510
3511    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3512        contacts_to_update.insert(session.user_id);
3513
3514        for project in left_room.left_projects.values() {
3515            project_left(project, session);
3516        }
3517
3518        room_id = RoomId::from_proto(left_room.room.id);
3519        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3520        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3521        delete_live_kit_room = left_room.deleted;
3522        room = mem::take(&mut left_room.room);
3523        channel_members = mem::take(&mut left_room.channel_members);
3524        channel_id = left_room.channel_id;
3525
3526        room_updated(&room, &session.peer);
3527    } else {
3528        return Ok(());
3529    }
3530
3531    if let Some(channel_id) = channel_id {
3532        channel_updated(
3533            channel_id,
3534            &room,
3535            &channel_members,
3536            &session.peer,
3537            &*session.connection_pool().await,
3538        );
3539    }
3540
3541    {
3542        let pool = session.connection_pool().await;
3543        for canceled_user_id in canceled_calls_to_user_ids {
3544            for connection_id in pool.user_connection_ids(canceled_user_id) {
3545                session
3546                    .peer
3547                    .send(
3548                        connection_id,
3549                        proto::CallCanceled {
3550                            room_id: room_id.to_proto(),
3551                        },
3552                    )
3553                    .trace_err();
3554            }
3555            contacts_to_update.insert(canceled_user_id);
3556        }
3557    }
3558
3559    for contact_user_id in contacts_to_update {
3560        update_user_contacts(contact_user_id, &session).await?;
3561    }
3562
3563    if let Some(live_kit) = session.live_kit_client.as_ref() {
3564        live_kit
3565            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3566            .await
3567            .trace_err();
3568
3569        if delete_live_kit_room {
3570            live_kit.delete_room(live_kit_room).await.trace_err();
3571        }
3572    }
3573
3574    Ok(())
3575}
3576
3577async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3578    let left_channel_buffers = session
3579        .db()
3580        .await
3581        .leave_channel_buffers(session.connection_id)
3582        .await?;
3583
3584    for left_buffer in left_channel_buffers {
3585        channel_buffer_updated(
3586            session.connection_id,
3587            left_buffer.connections,
3588            &proto::UpdateChannelBufferCollaborators {
3589                channel_id: left_buffer.channel_id.to_proto(),
3590                collaborators: left_buffer.collaborators,
3591            },
3592            &session.peer,
3593        );
3594    }
3595
3596    Ok(())
3597}
3598
3599fn project_left(project: &db::LeftProject, session: &Session) {
3600    for connection_id in &project.connection_ids {
3601        if project.host_user_id == session.user_id {
3602            session
3603                .peer
3604                .send(
3605                    *connection_id,
3606                    proto::UnshareProject {
3607                        project_id: project.id.to_proto(),
3608                    },
3609                )
3610                .trace_err();
3611        } else {
3612            session
3613                .peer
3614                .send(
3615                    *connection_id,
3616                    proto::RemoveProjectCollaborator {
3617                        project_id: project.id.to_proto(),
3618                        peer_id: Some(session.connection_id.into()),
3619                    },
3620                )
3621                .trace_err();
3622        }
3623    }
3624}
3625
3626pub trait ResultExt {
3627    type Ok;
3628
3629    fn trace_err(self) -> Option<Self::Ok>;
3630}
3631
3632impl<T, E> ResultExt for Result<T, E>
3633where
3634    E: std::fmt::Debug,
3635{
3636    type Ok = T;
3637
3638    fn trace_err(self) -> Option<T> {
3639        match self {
3640            Ok(value) => Some(value),
3641            Err(error) => {
3642                tracing::error!("{:?}", error);
3643                None
3644            }
3645        }
3646    }
3647}