rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
  76lazy_static! {
  77    static ref METRIC_CONNECTIONS: IntGauge =
  78        register_int_gauge!("connections", "number of connections").unwrap();
  79    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  80        "shared_projects",
  81        "number of open projects with one or more guests"
  82    )
  83    .unwrap();
  84}
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
 103#[derive(Clone)]
 104struct Session {
 105    user_id: UserId,
 106    connection_id: ConnectionId,
 107    db: Arc<tokio::sync::Mutex<DbHandle>>,
 108    peer: Arc<Peer>,
 109    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 110    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 111    _executor: Executor,
 112}
 113
 114impl Session {
 115    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        let guard = self.db.lock().await;
 119        #[cfg(test)]
 120        tokio::task::yield_now().await;
 121        guard
 122    }
 123
 124    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 125        #[cfg(test)]
 126        tokio::task::yield_now().await;
 127        let guard = self.connection_pool.lock();
 128        ConnectionPoolGuard {
 129            guard,
 130            _not_send: PhantomData,
 131        }
 132    }
 133}
 134
 135impl fmt::Debug for Session {
 136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 137        f.debug_struct("Session")
 138            .field("user_id", &self.user_id)
 139            .field("connection_id", &self.connection_id)
 140            .finish()
 141    }
 142}
 143
 144struct DbHandle(Arc<Database>);
 145
 146impl Deref for DbHandle {
 147    type Target = Database;
 148
 149    fn deref(&self) -> &Self::Target {
 150        self.0.as_ref()
 151    }
 152}
 153
 154pub struct Server {
 155    id: parking_lot::Mutex<ServerId>,
 156    peer: Arc<Peer>,
 157    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 158    app_state: Arc<AppState>,
 159    executor: Executor,
 160    handlers: HashMap<TypeId, MessageHandler>,
 161    teardown: watch::Sender<()>,
 162}
 163
 164pub(crate) struct ConnectionPoolGuard<'a> {
 165    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 166    _not_send: PhantomData<Rc<()>>,
 167}
 168
 169#[derive(Serialize)]
 170pub struct ServerSnapshot<'a> {
 171    peer: &'a Peer,
 172    #[serde(serialize_with = "serialize_deref")]
 173    connection_pool: ConnectionPoolGuard<'a>,
 174}
 175
 176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 177where
 178    S: Serializer,
 179    T: Deref<Target = U>,
 180    U: Serialize,
 181{
 182    Serialize::serialize(value.deref(), serializer)
 183}
 184
 185impl Server {
 186    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 187        let mut server = Self {
 188            id: parking_lot::Mutex::new(id),
 189            peer: Peer::new(id.0 as u32),
 190            app_state,
 191            executor,
 192            connection_pool: Default::default(),
 193            handlers: Default::default(),
 194            teardown: watch::channel(()).0,
 195        };
 196
 197        server
 198            .add_request_handler(ping)
 199            .add_request_handler(create_room)
 200            .add_request_handler(join_room)
 201            .add_request_handler(rejoin_room)
 202            .add_request_handler(leave_room)
 203            .add_request_handler(set_room_participant_role)
 204            .add_request_handler(call)
 205            .add_request_handler(cancel_call)
 206            .add_message_handler(decline_call)
 207            .add_request_handler(update_participant_location)
 208            .add_request_handler(share_project)
 209            .add_message_handler(unshare_project)
 210            .add_request_handler(join_project)
 211            .add_message_handler(leave_project)
 212            .add_request_handler(update_project)
 213            .add_request_handler(update_worktree)
 214            .add_message_handler(start_language_server)
 215            .add_message_handler(update_language_server)
 216            .add_message_handler(update_diagnostic_summary)
 217            .add_message_handler(update_worktree_settings)
 218            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 219            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 220            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 221            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 222            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 223            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 224            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 225            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 226            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 227            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 228            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 229            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 230            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 231            .add_request_handler(
 232                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 233            )
 234            .add_request_handler(
 235                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 236            )
 237            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 238            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 239            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 240            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 241            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 242            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 243            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 244            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 245            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 246            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 247            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 248            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 249            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 250            .add_message_handler(create_buffer_for_peer)
 251            .add_request_handler(update_buffer)
 252            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 253            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 254            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 255            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 256            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 257            .add_request_handler(get_users)
 258            .add_request_handler(fuzzy_search_users)
 259            .add_request_handler(request_contact)
 260            .add_request_handler(remove_contact)
 261            .add_request_handler(respond_to_contact_request)
 262            .add_request_handler(create_channel)
 263            .add_request_handler(delete_channel)
 264            .add_request_handler(invite_channel_member)
 265            .add_request_handler(remove_channel_member)
 266            .add_request_handler(set_channel_member_role)
 267            .add_request_handler(set_channel_visibility)
 268            .add_request_handler(rename_channel)
 269            .add_request_handler(join_channel_buffer)
 270            .add_request_handler(leave_channel_buffer)
 271            .add_message_handler(update_channel_buffer)
 272            .add_request_handler(rejoin_channel_buffers)
 273            .add_request_handler(get_channel_members)
 274            .add_request_handler(respond_to_channel_invite)
 275            .add_request_handler(join_channel)
 276            .add_request_handler(join_channel_chat)
 277            .add_message_handler(leave_channel_chat)
 278            .add_request_handler(send_channel_message)
 279            .add_request_handler(remove_channel_message)
 280            .add_request_handler(get_channel_messages)
 281            .add_request_handler(get_channel_messages_by_id)
 282            .add_request_handler(get_notifications)
 283            .add_request_handler(mark_notification_as_read)
 284            .add_request_handler(move_channel)
 285            .add_request_handler(follow)
 286            .add_message_handler(unfollow)
 287            .add_message_handler(update_followers)
 288            .add_request_handler(get_private_user_info)
 289            .add_message_handler(acknowledge_channel_message)
 290            .add_message_handler(acknowledge_buffer_version);
 291
 292        Arc::new(server)
 293    }
 294
 295    pub async fn start(&self) -> Result<()> {
 296        let server_id = *self.id.lock();
 297        let app_state = self.app_state.clone();
 298        let peer = self.peer.clone();
 299        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 300        let pool = self.connection_pool.clone();
 301        let live_kit_client = self.app_state.live_kit_client.clone();
 302
 303        let span = info_span!("start server");
 304        self.executor.spawn_detached(
 305            async move {
 306                tracing::info!("waiting for cleanup timeout");
 307                timeout.await;
 308                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 309                if let Some((room_ids, channel_ids)) = app_state
 310                    .db
 311                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 312                    .await
 313                    .trace_err()
 314                {
 315                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 316                    tracing::info!(
 317                        stale_channel_buffer_count = channel_ids.len(),
 318                        "retrieved stale channel buffers"
 319                    );
 320
 321                    for channel_id in channel_ids {
 322                        if let Some(refreshed_channel_buffer) = app_state
 323                            .db
 324                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 325                            .await
 326                            .trace_err()
 327                        {
 328                            for connection_id in refreshed_channel_buffer.connection_ids {
 329                                peer.send(
 330                                    connection_id,
 331                                    proto::UpdateChannelBufferCollaborators {
 332                                        channel_id: channel_id.to_proto(),
 333                                        collaborators: refreshed_channel_buffer
 334                                            .collaborators
 335                                            .clone(),
 336                                    },
 337                                )
 338                                .trace_err();
 339                            }
 340                        }
 341                    }
 342
 343                    for room_id in room_ids {
 344                        let mut contacts_to_update = HashSet::default();
 345                        let mut canceled_calls_to_user_ids = Vec::new();
 346                        let mut live_kit_room = String::new();
 347                        let mut delete_live_kit_room = false;
 348
 349                        if let Some(mut refreshed_room) = app_state
 350                            .db
 351                            .clear_stale_room_participants(room_id, server_id)
 352                            .await
 353                            .trace_err()
 354                        {
 355                            tracing::info!(
 356                                room_id = room_id.0,
 357                                new_participant_count = refreshed_room.room.participants.len(),
 358                                "refreshed room"
 359                            );
 360                            room_updated(&refreshed_room.room, &peer);
 361                            if let Some(channel_id) = refreshed_room.channel_id {
 362                                channel_updated(
 363                                    channel_id,
 364                                    &refreshed_room.room,
 365                                    &refreshed_room.channel_members,
 366                                    &peer,
 367                                    &*pool.lock(),
 368                                );
 369                            }
 370                            contacts_to_update
 371                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 372                            contacts_to_update
 373                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 374                            canceled_calls_to_user_ids =
 375                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 376                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 377                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 378                        }
 379
 380                        {
 381                            let pool = pool.lock();
 382                            for canceled_user_id in canceled_calls_to_user_ids {
 383                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 384                                    peer.send(
 385                                        connection_id,
 386                                        proto::CallCanceled {
 387                                            room_id: room_id.to_proto(),
 388                                        },
 389                                    )
 390                                    .trace_err();
 391                                }
 392                            }
 393                        }
 394
 395                        for user_id in contacts_to_update {
 396                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 397                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 398                            if let Some((busy, contacts)) = busy.zip(contacts) {
 399                                let pool = pool.lock();
 400                                let updated_contact = contact_for_user(user_id, busy, &pool);
 401                                for contact in contacts {
 402                                    if let db::Contact::Accepted {
 403                                        user_id: contact_user_id,
 404                                        ..
 405                                    } = contact
 406                                    {
 407                                        for contact_conn_id in
 408                                            pool.user_connection_ids(contact_user_id)
 409                                        {
 410                                            peer.send(
 411                                                contact_conn_id,
 412                                                proto::UpdateContacts {
 413                                                    contacts: vec![updated_contact.clone()],
 414                                                    remove_contacts: Default::default(),
 415                                                    incoming_requests: Default::default(),
 416                                                    remove_incoming_requests: Default::default(),
 417                                                    outgoing_requests: Default::default(),
 418                                                    remove_outgoing_requests: Default::default(),
 419                                                },
 420                                            )
 421                                            .trace_err();
 422                                        }
 423                                    }
 424                                }
 425                            }
 426                        }
 427
 428                        if let Some(live_kit) = live_kit_client.as_ref() {
 429                            if delete_live_kit_room {
 430                                live_kit.delete_room(live_kit_room).await.trace_err();
 431                            }
 432                        }
 433                    }
 434                }
 435
 436                app_state
 437                    .db
 438                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 439                    .await
 440                    .trace_err();
 441            }
 442            .instrument(span),
 443        );
 444        Ok(())
 445    }
 446
 447    pub fn teardown(&self) {
 448        self.peer.teardown();
 449        self.connection_pool.lock().reset();
 450        let _ = self.teardown.send(());
 451    }
 452
 453    #[cfg(test)]
 454    pub fn reset(&self, id: ServerId) {
 455        self.teardown();
 456        *self.id.lock() = id;
 457        self.peer.reset(id.0 as u32);
 458    }
 459
 460    #[cfg(test)]
 461    pub fn id(&self) -> ServerId {
 462        *self.id.lock()
 463    }
 464
 465    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 466    where
 467        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 468        Fut: 'static + Send + Future<Output = Result<()>>,
 469        M: EnvelopedMessage,
 470    {
 471        let prev_handler = self.handlers.insert(
 472            TypeId::of::<M>(),
 473            Box::new(move |envelope, session| {
 474                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 475                let span = info_span!(
 476                    "handle message",
 477                    payload_type = envelope.payload_type_name()
 478                );
 479                span.in_scope(|| {
 480                    tracing::info!(
 481                        payload_type = envelope.payload_type_name(),
 482                        "message received"
 483                    );
 484                });
 485                let start_time = Instant::now();
 486                let future = (handler)(*envelope, session);
 487                async move {
 488                    let result = future.await;
 489                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 490                    match result {
 491                        Err(error) => {
 492                            tracing::error!(%error, ?duration_ms, "error handling message")
 493                        }
 494                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 495                    }
 496                }
 497                .instrument(span)
 498                .boxed()
 499            }),
 500        );
 501        if prev_handler.is_some() {
 502            panic!("registered a handler for the same message twice");
 503        }
 504        self
 505    }
 506
 507    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 510        Fut: 'static + Send + Future<Output = Result<()>>,
 511        M: EnvelopedMessage,
 512    {
 513        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 514        self
 515    }
 516
 517    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 518    where
 519        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 520        Fut: Send + Future<Output = Result<()>>,
 521        M: RequestMessage,
 522    {
 523        let handler = Arc::new(handler);
 524        self.add_handler(move |envelope, session| {
 525            let receipt = envelope.receipt();
 526            let handler = handler.clone();
 527            async move {
 528                let peer = session.peer.clone();
 529                let responded = Arc::new(AtomicBool::default());
 530                let response = Response {
 531                    peer: peer.clone(),
 532                    responded: responded.clone(),
 533                    receipt,
 534                };
 535                match (handler)(envelope.payload, response, session).await {
 536                    Ok(()) => {
 537                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 538                            Ok(())
 539                        } else {
 540                            Err(anyhow!("handler did not send a response"))?
 541                        }
 542                    }
 543                    Err(error) => {
 544                        let proto_err = match &error {
 545                            Error::Internal(err) => err.to_proto(),
 546                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 547                        };
 548                        peer.respond_with_error(receipt, proto_err)?;
 549                        Err(error)
 550                    }
 551                }
 552            }
 553        })
 554    }
 555
 556    pub fn handle_connection(
 557        self: &Arc<Self>,
 558        connection: Connection,
 559        address: String,
 560        user: User,
 561        impersonator: Option<User>,
 562        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 563        executor: Executor,
 564    ) -> impl Future<Output = Result<()>> {
 565        let this = self.clone();
 566        let user_id = user.id;
 567        let login = user.github_login;
 568        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
 569        if let Some(impersonator) = impersonator {
 570            span.record("impersonator", &impersonator.github_login);
 571        }
 572        let mut teardown = self.teardown.subscribe();
 573        async move {
 574            let (connection_id, handle_io, mut incoming_rx) = this
 575                .peer
 576                .add_connection(connection, {
 577                    let executor = executor.clone();
 578                    move |duration| executor.sleep(duration)
 579                });
 580
 581            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 582            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 583            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 584
 585            if let Some(send_connection_id) = send_connection_id.take() {
 586                let _ = send_connection_id.send(connection_id);
 587            }
 588
 589            if !user.connected_once {
 590                this.peer.send(connection_id, proto::ShowContacts {})?;
 591                this.app_state.db.set_user_connected_once(user_id, true).await?;
 592            }
 593
 594            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 595                this.app_state.db.get_contacts(user_id),
 596                this.app_state.db.get_channels_for_user(user_id),
 597                this.app_state.db.get_channel_invites_for_user(user_id),
 598            ).await?;
 599
 600            {
 601                let mut pool = this.connection_pool.lock();
 602                pool.add_connection(connection_id, user_id, user.admin);
 603                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 604                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
 605                this.peer.send(connection_id, build_channels_update(
 606                    channels_for_user,
 607                    channel_invites
 608                ))?;
 609            }
 610
 611            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 612                this.peer.send(connection_id, incoming_call)?;
 613            }
 614
 615            let session = Session {
 616                user_id,
 617                connection_id,
 618                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 619                peer: this.peer.clone(),
 620                connection_pool: this.connection_pool.clone(),
 621                live_kit_client: this.app_state.live_kit_client.clone(),
 622                _executor: executor.clone()
 623            };
 624            update_user_contacts(user_id, &session).await?;
 625
 626            let handle_io = handle_io.fuse();
 627            futures::pin_mut!(handle_io);
 628
 629            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 630            // This prevents deadlocks when e.g., client A performs a request to client B and
 631            // client B performs a request to client A. If both clients stop processing further
 632            // messages until their respective request completes, they won't have a chance to
 633            // respond to the other client's request and cause a deadlock.
 634            //
 635            // This arrangement ensures we will attempt to process earlier messages first, but fall
 636            // back to processing messages arrived later in the spirit of making progress.
 637            let mut foreground_message_handlers = FuturesUnordered::new();
 638            let concurrent_handlers = Arc::new(Semaphore::new(256));
 639            loop {
 640                let next_message = async {
 641                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 642                    let message = incoming_rx.next().await;
 643                    (permit, message)
 644                }.fuse();
 645                futures::pin_mut!(next_message);
 646                futures::select_biased! {
 647                    _ = teardown.changed().fuse() => return Ok(()),
 648                    result = handle_io => {
 649                        if let Err(error) = result {
 650                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 651                        }
 652                        break;
 653                    }
 654                    _ = foreground_message_handlers.next() => {}
 655                    next_message = next_message => {
 656                        let (permit, message) = next_message;
 657                        if let Some(message) = message {
 658                            let type_name = message.payload_type_name();
 659                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 660                            let span_enter = span.enter();
 661                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 662                                let is_background = message.is_background();
 663                                let handle_message = (handler)(message, session.clone());
 664                                drop(span_enter);
 665
 666                                let handle_message = async move {
 667                                    handle_message.await;
 668                                    drop(permit);
 669                                }.instrument(span);
 670                                if is_background {
 671                                    executor.spawn_detached(handle_message);
 672                                } else {
 673                                    foreground_message_handlers.push(handle_message);
 674                                }
 675                            } else {
 676                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 677                            }
 678                        } else {
 679                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 680                            break;
 681                        }
 682                    }
 683                }
 684            }
 685
 686            drop(foreground_message_handlers);
 687            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 688            if let Err(error) = connection_lost(session, teardown, executor).await {
 689                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 690            }
 691
 692            Ok(())
 693        }.instrument(span)
 694    }
 695
 696    pub async fn invite_code_redeemed(
 697        self: &Arc<Self>,
 698        inviter_id: UserId,
 699        invitee_id: UserId,
 700    ) -> Result<()> {
 701        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 702            if let Some(code) = &user.invite_code {
 703                let pool = self.connection_pool.lock();
 704                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 705                for connection_id in pool.user_connection_ids(inviter_id) {
 706                    self.peer.send(
 707                        connection_id,
 708                        proto::UpdateContacts {
 709                            contacts: vec![invitee_contact.clone()],
 710                            ..Default::default()
 711                        },
 712                    )?;
 713                    self.peer.send(
 714                        connection_id,
 715                        proto::UpdateInviteInfo {
 716                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 717                            count: user.invite_count as u32,
 718                        },
 719                    )?;
 720                }
 721            }
 722        }
 723        Ok(())
 724    }
 725
 726    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 727        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 728            if let Some(invite_code) = &user.invite_code {
 729                let pool = self.connection_pool.lock();
 730                for connection_id in pool.user_connection_ids(user_id) {
 731                    self.peer.send(
 732                        connection_id,
 733                        proto::UpdateInviteInfo {
 734                            url: format!(
 735                                "{}{}",
 736                                self.app_state.config.invite_link_prefix, invite_code
 737                            ),
 738                            count: user.invite_count as u32,
 739                        },
 740                    )?;
 741                }
 742            }
 743        }
 744        Ok(())
 745    }
 746
 747    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 748        ServerSnapshot {
 749            connection_pool: ConnectionPoolGuard {
 750                guard: self.connection_pool.lock(),
 751                _not_send: PhantomData,
 752            },
 753            peer: &self.peer,
 754        }
 755    }
 756}
 757
 758impl<'a> Deref for ConnectionPoolGuard<'a> {
 759    type Target = ConnectionPool;
 760
 761    fn deref(&self) -> &Self::Target {
 762        &*self.guard
 763    }
 764}
 765
 766impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 767    fn deref_mut(&mut self) -> &mut Self::Target {
 768        &mut *self.guard
 769    }
 770}
 771
 772impl<'a> Drop for ConnectionPoolGuard<'a> {
 773    fn drop(&mut self) {
 774        #[cfg(test)]
 775        self.check_invariants();
 776    }
 777}
 778
 779fn broadcast<F>(
 780    sender_id: Option<ConnectionId>,
 781    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 782    mut f: F,
 783) where
 784    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 785{
 786    for receiver_id in receiver_ids {
 787        if Some(receiver_id) != sender_id {
 788            if let Err(error) = f(receiver_id) {
 789                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 790            }
 791        }
 792    }
 793}
 794
// Names of the custom request headers decoded by the `ProtocolVersion` and `AppVersionHeader`
// typed headers. `handle_websocket_request` extracts both.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
}
 799
 800pub struct ProtocolVersion(u32);
 801
 802impl Header for ProtocolVersion {
 803    fn name() -> &'static HeaderName {
 804        &ZED_PROTOCOL_VERSION
 805    }
 806
 807    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 808    where
 809        Self: Sized,
 810        I: Iterator<Item = &'i axum::http::HeaderValue>,
 811    {
 812        let version = values
 813            .next()
 814            .ok_or_else(axum::headers::Error::invalid)?
 815            .to_str()
 816            .map_err(|_| axum::headers::Error::invalid())?
 817            .parse()
 818            .map_err(|_| axum::headers::Error::invalid())?;
 819        Ok(Self(version))
 820    }
 821
 822    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 823        values.extend([self.0.to_string().parse().unwrap()]);
 824    }
 825}
 826
 827pub struct AppVersionHeader(SemanticVersion);
 828impl Header for AppVersionHeader {
 829    fn name() -> &'static HeaderName {
 830        &ZED_APP_VERSION
 831    }
 832
 833    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 834    where
 835        Self: Sized,
 836        I: Iterator<Item = &'i axum::http::HeaderValue>,
 837    {
 838        let version = values
 839            .next()
 840            .ok_or_else(axum::headers::Error::invalid)?
 841            .to_str()
 842            .map_err(|_| axum::headers::Error::invalid())?
 843            .parse()
 844            .map_err(|_| axum::headers::Error::invalid())?;
 845        Ok(Self(version))
 846    }
 847
 848    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 849        values.extend([self.0.to_string().parse().unwrap()]);
 850    }
 851}
 852
/// Builds the router for the collab RPC endpoints.
///
/// Layer ordering matters here. In axum, `Router::layer` only wraps routes added before the
/// call, so:
/// - `/rpc` is wrapped by the `AppState` extension and the `auth::validate_header` middleware;
/// - `/metrics`, which is added afterwards, is not;
/// - the final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to the routes registered above (i.e. `/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 864
 865pub async fn handle_websocket_request(
 866    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 867    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 868    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 869    Extension(server): Extension<Arc<Server>>,
 870    Extension(user): Extension<User>,
 871    Extension(impersonator): Extension<Impersonator>,
 872    ws: WebSocketUpgrade,
 873) -> axum::response::Response {
 874    if protocol_version != rpc::PROTOCOL_VERSION {
 875        return (
 876            StatusCode::UPGRADE_REQUIRED,
 877            "client must be upgraded".to_string(),
 878        )
 879            .into_response();
 880    }
 881
 882    // the first version of zed that sent this header was 0.121.x
 883    if let Some(version) = app_version_header.map(|header| header.0 .0) {
 884        // 0.123.0 was a nightly version with incompatible collab changes
 885        // that were reverted.
 886        if version == "0.123.0".parse().unwrap() {
 887            return (
 888                StatusCode::UPGRADE_REQUIRED,
 889                "client must be upgraded".to_string(),
 890            )
 891                .into_response();
 892        }
 893    }
 894
 895    let socket_address = socket_address.to_string();
 896    ws.on_upgrade(move |socket| {
 897        use util::ResultExt;
 898        let socket = socket
 899            .map_ok(to_tungstenite_message)
 900            .err_into()
 901            .with(|message| async move { Ok(to_axum_message(message)) });
 902        let connection = Connection::new(Box::pin(socket));
 903        async move {
 904            server
 905                .handle_connection(
 906                    connection,
 907                    socket_address,
 908                    user,
 909                    impersonator.0,
 910                    None,
 911                    Executor::Production,
 912                )
 913                .await
 914                .log_err();
 915        }
 916    })
 917}
 918
 919pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 920    let connections = server
 921        .connection_pool
 922        .lock()
 923        .connections()
 924        .filter(|connection| !connection.admin)
 925        .count();
 926
 927    METRIC_CONNECTIONS.set(connections as _);
 928
 929    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 930    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 931
 932    let encoder = prometheus::TextEncoder::new();
 933    let metric_families = prometheus::gather();
 934    let encoded_metrics = encoder
 935        .encode_to_string(&metric_families)
 936        .map_err(|err| anyhow!("{}", err))?;
 937    Ok(encoded_metrics)
 938}
 939
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool (an error here aborts the function);
/// - tell the database the connection was lost (a DB error is only traced).
///
/// It then waits for `RECONNECT_TIMEOUT` to elapse or for `teardown` to change, whichever
/// happens first (the sleep branch is polled first because of `select_biased!`). If `teardown`
/// changes first, the function returns without further cleanup.
///
/// If the timeout elapses first, it then:
/// - leaves the session's room and channel buffers (errors are traced);
/// - if the user has no other online connection, calls `decline_call(None, ..)`, presumably
///   declining any pending incoming call, and broadcasts the resulting room;
/// - updates the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Grace period, presumably so a client that reconnects quickly keeps its room/buffer
    // state. `teardown` changing cancels the wait and skips the cleanup below.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only touch call state when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 985
 986/// Acknowledges a ping from a client, used to keep the connection alive.
 987async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 988    response.send(proto::Ack {})?;
 989    Ok(())
 990}
 991
 992/// Creates a new room for calling (outside of channels)
 993async fn create_room(
 994    _request: proto::CreateRoom,
 995    response: Response<proto::CreateRoom>,
 996    session: Session,
 997) -> Result<()> {
 998    let live_kit_room = nanoid::nanoid!(30);
 999
1000    let live_kit_connection_info = {
1001        let live_kit_room = live_kit_room.clone();
1002        let live_kit = session.live_kit_client.as_ref();
1003
1004        util::async_maybe!({
1005            let live_kit = live_kit?;
1006
1007            let token = live_kit
1008                .room_token(&live_kit_room, &session.user_id.to_string())
1009                .trace_err()?;
1010
1011            Some(proto::LiveKitConnectionInfo {
1012                server_url: live_kit.url().into(),
1013                token,
1014                can_publish: true,
1015            })
1016        })
1017    }
1018    .await;
1019
1020    let room = session
1021        .db()
1022        .await
1023        .create_room(session.user_id, session.connection_id, &live_kit_room)
1024        .await?;
1025
1026    response.send(proto::CreateRoomResponse {
1027        room: Some(room.clone()),
1028        live_kit_connection_info,
1029    })?;
1030
1031    update_user_contacts(session.user_id, &session).await?;
1032    Ok(())
1033}
1034
1035/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1036async fn join_room(
1037    request: proto::JoinRoom,
1038    response: Response<proto::JoinRoom>,
1039    session: Session,
1040) -> Result<()> {
1041    let room_id = RoomId::from_proto(request.id);
1042
1043    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1044
1045    if let Some(channel_id) = channel_id {
1046        return join_channel_internal(channel_id, Box::new(response), session).await;
1047    }
1048
1049    let joined_room = {
1050        let room = session
1051            .db()
1052            .await
1053            .join_room(room_id, session.user_id, session.connection_id)
1054            .await?;
1055        room_updated(&room.room, &session.peer);
1056        room.into_inner()
1057    };
1058
1059    for connection_id in session
1060        .connection_pool()
1061        .await
1062        .user_connection_ids(session.user_id)
1063    {
1064        session
1065            .peer
1066            .send(
1067                connection_id,
1068                proto::CallCanceled {
1069                    room_id: room_id.to_proto(),
1070                },
1071            )
1072            .trace_err();
1073    }
1074
1075    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1076        if let Some(token) = live_kit
1077            .room_token(
1078                &joined_room.room.live_kit_room,
1079                &session.user_id.to_string(),
1080            )
1081            .trace_err()
1082        {
1083            Some(proto::LiveKitConnectionInfo {
1084                server_url: live_kit.url().into(),
1085                token,
1086                can_publish: true,
1087            })
1088        } else {
1089            None
1090        }
1091    } else {
1092        None
1093    };
1094
1095    response.send(proto::JoinRoomResponse {
1096        room: Some(joined_room.room),
1097        channel_id: None,
1098        live_kit_connection_info,
1099    })?;
1100
1101    update_user_contacts(session.user_id, &session).await?;
1102    Ok(())
1103}
1104
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Everything is driven by the result of `db.rejoin_room`:
/// 1. Reply to the rejoining client with the room, its reshared projects and its rejoined
///    projects (presumably projects it hosts and projects it had joined as a guest,
///    respectively), then broadcast the room update.
/// 2. For each reshared project, tell every collaborator that the project's old peer id is now
///    this connection, and forward an `UpdateProject` with the project's worktrees to them.
/// 3. For each rejoined project, tell every collaborator about the same peer-id change.
/// 4. For each rejoined project, stream the following to this connection only:
///    - each worktree's entries (chunked), diagnostic summaries and settings files;
///    - a `DiskBasedDiagnosticsUpdated` message per language server.
/// 5. After the inner block has consumed the DB result via `into_inner()`, notify the room's
///    channel (if any) and update the user's contacts.
///
/// Peer-id notifications in steps 2–3 are best effort (errors traced). Any send failure in
/// step 4 aborts the handler.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Populated at the end of the inner block so they outlive the DB result it holds.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first, with metadata only. Worktree contents are streamed further below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map the project's old connection id to this new connection for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to the other collaborators, on behalf of
            // this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream rejoined project contents to the rejoining connection. The worktrees are moved
        // out with `mem::take` so their data can be sent without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // A tiny chunk size under test, presumably to exercise multi-chunk streaming.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send one `DiskBasedDiagnosticsUpdated` update per language server of the project.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Consume the DB result; only the fields needed below are kept.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1296
1297/// leave room disconnects from the room.
1298async fn leave_room(
1299    _: proto::LeaveRoom,
1300    response: Response<proto::LeaveRoom>,
1301    session: Session,
1302) -> Result<()> {
1303    leave_room_for_session(&session).await?;
1304    response.send(proto::Ack {})?;
1305    Ok(())
1306}
1307
1308/// Updates the permissions of someone else in the room.
1309async fn set_room_participant_role(
1310    request: proto::SetRoomParticipantRole,
1311    response: Response<proto::SetRoomParticipantRole>,
1312    session: Session,
1313) -> Result<()> {
1314    let (live_kit_room, can_publish) = {
1315        let room = session
1316            .db()
1317            .await
1318            .set_room_participant_role(
1319                session.user_id,
1320                RoomId::from_proto(request.room_id),
1321                UserId::from_proto(request.user_id),
1322                ChannelRole::from(request.role()),
1323            )
1324            .await?;
1325
1326        let live_kit_room = room.live_kit_room.clone();
1327        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1328        room_updated(&room, &session.peer);
1329        (live_kit_room, can_publish)
1330    };
1331
1332    if let Some(live_kit) = session.live_kit_client.as_ref() {
1333        live_kit
1334            .update_participant(
1335                live_kit_room.clone(),
1336                request.user_id.to_string(),
1337                live_kit_server::proto::ParticipantPermission {
1338                    can_subscribe: true,
1339                    can_publish,
1340                    can_publish_data: can_publish,
1341                    hidden: false,
1342                    recorder: false,
1343                },
1344            )
1345            .await
1346            .trace_err();
1347    }
1348
1349    response.send(proto::Ack {})?;
1350    Ok(())
1351}
1352
1353/// Call someone else into the current room
1354async fn call(
1355    request: proto::Call,
1356    response: Response<proto::Call>,
1357    session: Session,
1358) -> Result<()> {
1359    let room_id = RoomId::from_proto(request.room_id);
1360    let calling_user_id = session.user_id;
1361    let calling_connection_id = session.connection_id;
1362    let called_user_id = UserId::from_proto(request.called_user_id);
1363    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1364    if !session
1365        .db()
1366        .await
1367        .has_contact(calling_user_id, called_user_id)
1368        .await?
1369    {
1370        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1371    }
1372
1373    let incoming_call = {
1374        let (room, incoming_call) = &mut *session
1375            .db()
1376            .await
1377            .call(
1378                room_id,
1379                calling_user_id,
1380                calling_connection_id,
1381                called_user_id,
1382                initial_project_id,
1383            )
1384            .await?;
1385        room_updated(&room, &session.peer);
1386        mem::take(incoming_call)
1387    };
1388    update_user_contacts(called_user_id, &session).await?;
1389
1390    let mut calls = session
1391        .connection_pool()
1392        .await
1393        .user_connection_ids(called_user_id)
1394        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1395        .collect::<FuturesUnordered<_>>();
1396
1397    while let Some(call_response) = calls.next().await {
1398        match call_response.as_ref() {
1399            Ok(_) => {
1400                response.send(proto::Ack {})?;
1401                return Ok(());
1402            }
1403            Err(_) => {
1404                call_response.trace_err();
1405            }
1406        }
1407    }
1408
1409    {
1410        let room = session
1411            .db()
1412            .await
1413            .call_failed(room_id, called_user_id)
1414            .await?;
1415        room_updated(&room, &session.peer);
1416    }
1417    update_user_contacts(called_user_id, &session).await?;
1418
1419    Err(anyhow!("failed to ring user"))?
1420}
1421
1422/// Cancel an outgoing call.
1423async fn cancel_call(
1424    request: proto::CancelCall,
1425    response: Response<proto::CancelCall>,
1426    session: Session,
1427) -> Result<()> {
1428    let called_user_id = UserId::from_proto(request.called_user_id);
1429    let room_id = RoomId::from_proto(request.room_id);
1430    {
1431        let room = session
1432            .db()
1433            .await
1434            .cancel_call(room_id, session.connection_id, called_user_id)
1435            .await?;
1436        room_updated(&room, &session.peer);
1437    }
1438
1439    for connection_id in session
1440        .connection_pool()
1441        .await
1442        .user_connection_ids(called_user_id)
1443    {
1444        session
1445            .peer
1446            .send(
1447                connection_id,
1448                proto::CallCanceled {
1449                    room_id: room_id.to_proto(),
1450                },
1451            )
1452            .trace_err();
1453    }
1454    response.send(proto::Ack {})?;
1455
1456    update_user_contacts(called_user_id, &session).await?;
1457    Ok(())
1458}
1459
1460/// Decline an incoming call.
1461async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1462    let room_id = RoomId::from_proto(message.room_id);
1463    {
1464        let room = session
1465            .db()
1466            .await
1467            .decline_call(Some(room_id), session.user_id)
1468            .await?
1469            .ok_or_else(|| anyhow!("failed to decline call"))?;
1470        room_updated(&room, &session.peer);
1471    }
1472
1473    for connection_id in session
1474        .connection_pool()
1475        .await
1476        .user_connection_ids(session.user_id)
1477    {
1478        session
1479            .peer
1480            .send(
1481                connection_id,
1482                proto::CallCanceled {
1483                    room_id: room_id.to_proto(),
1484                },
1485            )
1486            .trace_err();
1487    }
1488    update_user_contacts(session.user_id, &session).await?;
1489    Ok(())
1490}
1491
1492/// Updates other participants in the room with your current location.
1493async fn update_participant_location(
1494    request: proto::UpdateParticipantLocation,
1495    response: Response<proto::UpdateParticipantLocation>,
1496    session: Session,
1497) -> Result<()> {
1498    let room_id = RoomId::from_proto(request.room_id);
1499    let location = request
1500        .location
1501        .ok_or_else(|| anyhow!("invalid location"))?;
1502
1503    let db = session.db().await;
1504    let room = db
1505        .update_room_participant_location(room_id, session.connection_id, location)
1506        .await?;
1507
1508    room_updated(&room, &session.peer);
1509    response.send(proto::Ack {})?;
1510    Ok(())
1511}
1512
1513/// Share a project into the room.
1514async fn share_project(
1515    request: proto::ShareProject,
1516    response: Response<proto::ShareProject>,
1517    session: Session,
1518) -> Result<()> {
1519    let (project_id, room) = &*session
1520        .db()
1521        .await
1522        .share_project(
1523            RoomId::from_proto(request.room_id),
1524            session.connection_id,
1525            &request.worktrees,
1526        )
1527        .await?;
1528    response.send(proto::ShareProjectResponse {
1529        project_id: project_id.to_proto(),
1530    })?;
1531    room_updated(&room, &session.peer);
1532
1533    Ok(())
1534}
1535
1536/// Unshare a project from the room.
1537async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1538    let project_id = ProjectId::from_proto(message.project_id);
1539
1540    let (room, guest_connection_ids) = &*session
1541        .db()
1542        .await
1543        .unshare_project(project_id, session.connection_id)
1544        .await?;
1545
1546    broadcast(
1547        Some(session.connection_id),
1548        guest_connection_ids.iter().copied(),
1549        |conn_id| session.peer.send(conn_id, message.clone()),
1550    );
1551    room_updated(&room, &session.peer);
1552
1553    Ok(())
1554}
1555
1556/// Join someone elses shared project.
1557async fn join_project(
1558    request: proto::JoinProject,
1559    response: Response<proto::JoinProject>,
1560    session: Session,
1561) -> Result<()> {
1562    let project_id = ProjectId::from_proto(request.project_id);
1563    let guest_user_id = session.user_id;
1564
1565    tracing::info!(%project_id, "join project");
1566
1567    let (project, replica_id) = &mut *session
1568        .db()
1569        .await
1570        .join_project(project_id, session.connection_id)
1571        .await?;
1572
1573    let collaborators = project
1574        .collaborators
1575        .iter()
1576        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1577        .map(|collaborator| collaborator.to_proto())
1578        .collect::<Vec<_>>();
1579
1580    let worktrees = project
1581        .worktrees
1582        .iter()
1583        .map(|(id, worktree)| proto::WorktreeMetadata {
1584            id: *id,
1585            root_name: worktree.root_name.clone(),
1586            visible: worktree.visible,
1587            abs_path: worktree.abs_path.clone(),
1588        })
1589        .collect::<Vec<_>>();
1590
1591    for collaborator in &collaborators {
1592        session
1593            .peer
1594            .send(
1595                collaborator.peer_id.unwrap().into(),
1596                proto::AddProjectCollaborator {
1597                    project_id: project_id.to_proto(),
1598                    collaborator: Some(proto::Collaborator {
1599                        peer_id: Some(session.connection_id.into()),
1600                        replica_id: replica_id.0 as u32,
1601                        user_id: guest_user_id.to_proto(),
1602                    }),
1603                },
1604            )
1605            .trace_err();
1606    }
1607
1608    // First, we send the metadata associated with each worktree.
1609    response.send(proto::JoinProjectResponse {
1610        worktrees: worktrees.clone(),
1611        replica_id: replica_id.0 as u32,
1612        collaborators: collaborators.clone(),
1613        language_servers: project.language_servers.clone(),
1614    })?;
1615
1616    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1617        #[cfg(any(test, feature = "test-support"))]
1618        const MAX_CHUNK_SIZE: usize = 2;
1619        #[cfg(not(any(test, feature = "test-support")))]
1620        const MAX_CHUNK_SIZE: usize = 256;
1621
1622        // Stream this worktree's entries.
1623        let message = proto::UpdateWorktree {
1624            project_id: project_id.to_proto(),
1625            worktree_id,
1626            abs_path: worktree.abs_path.clone(),
1627            root_name: worktree.root_name,
1628            updated_entries: worktree.entries,
1629            removed_entries: Default::default(),
1630            scan_id: worktree.scan_id,
1631            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1632            updated_repositories: worktree.repository_entries.into_values().collect(),
1633            removed_repositories: Default::default(),
1634        };
1635        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1636            session.peer.send(session.connection_id, update.clone())?;
1637        }
1638
1639        // Stream this worktree's diagnostics.
1640        for summary in worktree.diagnostic_summaries {
1641            session.peer.send(
1642                session.connection_id,
1643                proto::UpdateDiagnosticSummary {
1644                    project_id: project_id.to_proto(),
1645                    worktree_id: worktree.id,
1646                    summary: Some(summary),
1647                },
1648            )?;
1649        }
1650
1651        for settings_file in worktree.settings_files {
1652            session.peer.send(
1653                session.connection_id,
1654                proto::UpdateWorktreeSettings {
1655                    project_id: project_id.to_proto(),
1656                    worktree_id: worktree.id,
1657                    path: settings_file.path,
1658                    content: Some(settings_file.content),
1659                },
1660            )?;
1661        }
1662    }
1663
1664    for language_server in &project.language_servers {
1665        session.peer.send(
1666            session.connection_id,
1667            proto::UpdateLanguageServer {
1668                project_id: project_id.to_proto(),
1669                language_server_id: language_server.id,
1670                variant: Some(
1671                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1672                        proto::LspDiskBasedDiagnosticsUpdated {},
1673                    ),
1674                ),
1675            },
1676        )?;
1677    }
1678
1679    Ok(())
1680}
1681
1682/// Leave someone elses shared project.
1683async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1684    let sender_id = session.connection_id;
1685    let project_id = ProjectId::from_proto(request.project_id);
1686
1687    let (room, project) = &*session
1688        .db()
1689        .await
1690        .leave_project(project_id, sender_id)
1691        .await?;
1692    tracing::info!(
1693        %project_id,
1694        host_user_id = %project.host_user_id,
1695        host_connection_id = ?project.host_connection_id,
1696        "leave project"
1697    );
1698
1699    project_left(&project, &session);
1700    room_updated(&room, &session.peer);
1701
1702    Ok(())
1703}
1704
1705/// Updates other participants with changes to the project
1706async fn update_project(
1707    request: proto::UpdateProject,
1708    response: Response<proto::UpdateProject>,
1709    session: Session,
1710) -> Result<()> {
1711    let project_id = ProjectId::from_proto(request.project_id);
1712    let (room, guest_connection_ids) = &*session
1713        .db()
1714        .await
1715        .update_project(project_id, session.connection_id, &request.worktrees)
1716        .await?;
1717    broadcast(
1718        Some(session.connection_id),
1719        guest_connection_ids.iter().copied(),
1720        |connection_id| {
1721            session
1722                .peer
1723                .forward_send(session.connection_id, connection_id, request.clone())
1724        },
1725    );
1726    room_updated(&room, &session.peer);
1727    response.send(proto::Ack {})?;
1728
1729    Ok(())
1730}
1731
1732/// Updates other participants with changes to the worktree
1733async fn update_worktree(
1734    request: proto::UpdateWorktree,
1735    response: Response<proto::UpdateWorktree>,
1736    session: Session,
1737) -> Result<()> {
1738    let guest_connection_ids = session
1739        .db()
1740        .await
1741        .update_worktree(&request, session.connection_id)
1742        .await?;
1743
1744    broadcast(
1745        Some(session.connection_id),
1746        guest_connection_ids.iter().copied(),
1747        |connection_id| {
1748            session
1749                .peer
1750                .forward_send(session.connection_id, connection_id, request.clone())
1751        },
1752    );
1753    response.send(proto::Ack {})?;
1754    Ok(())
1755}
1756
1757/// Updates other participants with changes to the diagnostics
1758async fn update_diagnostic_summary(
1759    message: proto::UpdateDiagnosticSummary,
1760    session: Session,
1761) -> Result<()> {
1762    let guest_connection_ids = session
1763        .db()
1764        .await
1765        .update_diagnostic_summary(&message, session.connection_id)
1766        .await?;
1767
1768    broadcast(
1769        Some(session.connection_id),
1770        guest_connection_ids.iter().copied(),
1771        |connection_id| {
1772            session
1773                .peer
1774                .forward_send(session.connection_id, connection_id, message.clone())
1775        },
1776    );
1777
1778    Ok(())
1779}
1780
1781/// Updates other participants with changes to the worktree settings
1782async fn update_worktree_settings(
1783    message: proto::UpdateWorktreeSettings,
1784    session: Session,
1785) -> Result<()> {
1786    let guest_connection_ids = session
1787        .db()
1788        .await
1789        .update_worktree_settings(&message, session.connection_id)
1790        .await?;
1791
1792    broadcast(
1793        Some(session.connection_id),
1794        guest_connection_ids.iter().copied(),
1795        |connection_id| {
1796            session
1797                .peer
1798                .forward_send(session.connection_id, connection_id, message.clone())
1799        },
1800    );
1801
1802    Ok(())
1803}
1804
1805/// Notify other participants that a  language server has started.
1806async fn start_language_server(
1807    request: proto::StartLanguageServer,
1808    session: Session,
1809) -> Result<()> {
1810    let guest_connection_ids = session
1811        .db()
1812        .await
1813        .start_language_server(&request, session.connection_id)
1814        .await?;
1815
1816    broadcast(
1817        Some(session.connection_id),
1818        guest_connection_ids.iter().copied(),
1819        |connection_id| {
1820            session
1821                .peer
1822                .forward_send(session.connection_id, connection_id, request.clone())
1823        },
1824    );
1825    Ok(())
1826}
1827
1828/// Notify other participants that a language server has changed.
1829async fn update_language_server(
1830    request: proto::UpdateLanguageServer,
1831    session: Session,
1832) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let project_connection_ids = session
1835        .db()
1836        .await
1837        .project_connection_ids(project_id, session.connection_id)
1838        .await?;
1839    broadcast(
1840        Some(session.connection_id),
1841        project_connection_ids.iter().copied(),
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    Ok(())
1849}
1850
1851/// forward a project request to the host. These requests should be read only
1852/// as guests are allowed to send them.
1853async fn forward_read_only_project_request<T>(
1854    request: T,
1855    response: Response<T>,
1856    session: Session,
1857) -> Result<()>
1858where
1859    T: EntityMessage + RequestMessage,
1860{
1861    let project_id = ProjectId::from_proto(request.remote_entity_id());
1862    let host_connection_id = session
1863        .db()
1864        .await
1865        .host_for_read_only_project_request(project_id, session.connection_id)
1866        .await?;
1867    let payload = session
1868        .peer
1869        .forward_request(session.connection_id, host_connection_id, request)
1870        .await?;
1871    response.send(payload)?;
1872    Ok(())
1873}
1874
1875/// forward a project request to the host. These requests are disallowed
1876/// for guests.
1877async fn forward_mutating_project_request<T>(
1878    request: T,
1879    response: Response<T>,
1880    session: Session,
1881) -> Result<()>
1882where
1883    T: EntityMessage + RequestMessage,
1884{
1885    let project_id = ProjectId::from_proto(request.remote_entity_id());
1886    let host_connection_id = session
1887        .db()
1888        .await
1889        .host_for_mutating_project_request(project_id, session.connection_id)
1890        .await?;
1891    let payload = session
1892        .peer
1893        .forward_request(session.connection_id, host_connection_id, request)
1894        .await?;
1895    response.send(payload)?;
1896    Ok(())
1897}
1898
1899/// Notify other participants that a new buffer has been created
1900async fn create_buffer_for_peer(
1901    request: proto::CreateBufferForPeer,
1902    session: Session,
1903) -> Result<()> {
1904    session
1905        .db()
1906        .await
1907        .check_user_is_project_host(
1908            ProjectId::from_proto(request.project_id),
1909            session.connection_id,
1910        )
1911        .await?;
1912    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1913    session
1914        .peer
1915        .forward_send(session.connection_id, peer_id.into(), request)?;
1916    Ok(())
1917}
1918
1919/// Notify other participants that a buffer has been updated. This is
1920/// allowed for guests as long as the update is limited to selections.
1921async fn update_buffer(
1922    request: proto::UpdateBuffer,
1923    response: Response<proto::UpdateBuffer>,
1924    session: Session,
1925) -> Result<()> {
1926    let project_id = ProjectId::from_proto(request.project_id);
1927    let mut guest_connection_ids;
1928    let mut host_connection_id = None;
1929
1930    let mut requires_write_permission = false;
1931
1932    for op in request.operations.iter() {
1933        match op.variant {
1934            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1935            Some(_) => requires_write_permission = true,
1936        }
1937    }
1938
1939    {
1940        let collaborators = session
1941            .db()
1942            .await
1943            .project_collaborators_for_buffer_update(
1944                project_id,
1945                session.connection_id,
1946                requires_write_permission,
1947            )
1948            .await?;
1949        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1950        for collaborator in collaborators.iter() {
1951            if collaborator.is_host {
1952                host_connection_id = Some(collaborator.connection_id);
1953            } else {
1954                guest_connection_ids.push(collaborator.connection_id);
1955            }
1956        }
1957    }
1958    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1959
1960    broadcast(
1961        Some(session.connection_id),
1962        guest_connection_ids,
1963        |connection_id| {
1964            session
1965                .peer
1966                .forward_send(session.connection_id, connection_id, request.clone())
1967        },
1968    );
1969    if host_connection_id != session.connection_id {
1970        session
1971            .peer
1972            .forward_request(session.connection_id, host_connection_id, request.clone())
1973            .await?;
1974    }
1975
1976    response.send(proto::Ack {})?;
1977    Ok(())
1978}
1979
1980/// Notify other participants that a project has been updated.
1981async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1982    request: T,
1983    session: Session,
1984) -> Result<()> {
1985    let project_id = ProjectId::from_proto(request.remote_entity_id());
1986    let project_connection_ids = session
1987        .db()
1988        .await
1989        .project_connection_ids(project_id, session.connection_id)
1990        .await?;
1991
1992    broadcast(
1993        Some(session.connection_id),
1994        project_connection_ids.iter().copied(),
1995        |connection_id| {
1996            session
1997                .peer
1998                .forward_send(session.connection_id, connection_id, request.clone())
1999        },
2000    );
2001    Ok(())
2002}
2003
2004/// Start following another user in a call.
2005async fn follow(
2006    request: proto::Follow,
2007    response: Response<proto::Follow>,
2008    session: Session,
2009) -> Result<()> {
2010    let room_id = RoomId::from_proto(request.room_id);
2011    let project_id = request.project_id.map(ProjectId::from_proto);
2012    let leader_id = request
2013        .leader_id
2014        .ok_or_else(|| anyhow!("invalid leader id"))?
2015        .into();
2016    let follower_id = session.connection_id;
2017
2018    session
2019        .db()
2020        .await
2021        .check_room_participants(room_id, leader_id, session.connection_id)
2022        .await?;
2023
2024    let response_payload = session
2025        .peer
2026        .forward_request(session.connection_id, leader_id, request)
2027        .await?;
2028    response.send(response_payload)?;
2029
2030    if let Some(project_id) = project_id {
2031        let room = session
2032            .db()
2033            .await
2034            .follow(room_id, project_id, leader_id, follower_id)
2035            .await?;
2036        room_updated(&room, &session.peer);
2037    }
2038
2039    Ok(())
2040}
2041
2042/// Stop following another user in a call.
2043async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2044    let room_id = RoomId::from_proto(request.room_id);
2045    let project_id = request.project_id.map(ProjectId::from_proto);
2046    let leader_id = request
2047        .leader_id
2048        .ok_or_else(|| anyhow!("invalid leader id"))?
2049        .into();
2050    let follower_id = session.connection_id;
2051
2052    session
2053        .db()
2054        .await
2055        .check_room_participants(room_id, leader_id, session.connection_id)
2056        .await?;
2057
2058    session
2059        .peer
2060        .forward_send(session.connection_id, leader_id, request)?;
2061
2062    if let Some(project_id) = project_id {
2063        let room = session
2064            .db()
2065            .await
2066            .unfollow(room_id, project_id, leader_id, follower_id)
2067            .await?;
2068        room_updated(&room, &session.peer);
2069    }
2070
2071    Ok(())
2072}
2073
2074/// Notify everyone following you of your current location.
2075async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2076    let room_id = RoomId::from_proto(request.room_id);
2077    let database = session.db.lock().await;
2078
2079    let connection_ids = if let Some(project_id) = request.project_id {
2080        let project_id = ProjectId::from_proto(project_id);
2081        database
2082            .project_connection_ids(project_id, session.connection_id)
2083            .await?
2084    } else {
2085        database
2086            .room_connection_ids(room_id, session.connection_id)
2087            .await?
2088    };
2089
2090    // For now, don't send view update messages back to that view's current leader.
2091    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2092        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2093        _ => None,
2094    });
2095
2096    for follower_peer_id in request.follower_ids.iter().copied() {
2097        let follower_connection_id = follower_peer_id.into();
2098        if Some(follower_peer_id) != connection_id_to_omit
2099            && connection_ids.contains(&follower_connection_id)
2100        {
2101            session.peer.forward_send(
2102                session.connection_id,
2103                follower_connection_id,
2104                request.clone(),
2105            )?;
2106        }
2107    }
2108    Ok(())
2109}
2110
2111/// Get public data about users.
2112async fn get_users(
2113    request: proto::GetUsers,
2114    response: Response<proto::GetUsers>,
2115    session: Session,
2116) -> Result<()> {
2117    let user_ids = request
2118        .user_ids
2119        .into_iter()
2120        .map(UserId::from_proto)
2121        .collect();
2122    let users = session
2123        .db()
2124        .await
2125        .get_users_by_ids(user_ids)
2126        .await?
2127        .into_iter()
2128        .map(|user| proto::User {
2129            id: user.id.to_proto(),
2130            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2131            github_login: user.github_login,
2132        })
2133        .collect();
2134    response.send(proto::UsersResponse { users })?;
2135    Ok(())
2136}
2137
2138/// Search for users (to invite) buy Github login
2139async fn fuzzy_search_users(
2140    request: proto::FuzzySearchUsers,
2141    response: Response<proto::FuzzySearchUsers>,
2142    session: Session,
2143) -> Result<()> {
2144    let query = request.query;
2145    let users = match query.len() {
2146        0 => vec![],
2147        1 | 2 => session
2148            .db()
2149            .await
2150            .get_user_by_github_login(&query)
2151            .await?
2152            .into_iter()
2153            .collect(),
2154        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2155    };
2156    let users = users
2157        .into_iter()
2158        .filter(|user| user.id != session.user_id)
2159        .map(|user| proto::User {
2160            id: user.id.to_proto(),
2161            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2162            github_login: user.github_login,
2163        })
2164        .collect();
2165    response.send(proto::UsersResponse { users })?;
2166    Ok(())
2167}
2168
2169/// Send a contact request to another user.
2170async fn request_contact(
2171    request: proto::RequestContact,
2172    response: Response<proto::RequestContact>,
2173    session: Session,
2174) -> Result<()> {
2175    let requester_id = session.user_id;
2176    let responder_id = UserId::from_proto(request.responder_id);
2177    if requester_id == responder_id {
2178        return Err(anyhow!("cannot add yourself as a contact"))?;
2179    }
2180
2181    let notifications = session
2182        .db()
2183        .await
2184        .send_contact_request(requester_id, responder_id)
2185        .await?;
2186
2187    // Update outgoing contact requests of requester
2188    let mut update = proto::UpdateContacts::default();
2189    update.outgoing_requests.push(responder_id.to_proto());
2190    for connection_id in session
2191        .connection_pool()
2192        .await
2193        .user_connection_ids(requester_id)
2194    {
2195        session.peer.send(connection_id, update.clone())?;
2196    }
2197
2198    // Update incoming contact requests of responder
2199    let mut update = proto::UpdateContacts::default();
2200    update
2201        .incoming_requests
2202        .push(proto::IncomingContactRequest {
2203            requester_id: requester_id.to_proto(),
2204        });
2205    let connection_pool = session.connection_pool().await;
2206    for connection_id in connection_pool.user_connection_ids(responder_id) {
2207        session.peer.send(connection_id, update.clone())?;
2208    }
2209
2210    send_notifications(&*connection_pool, &session.peer, notifications);
2211
2212    response.send(proto::Ack {})?;
2213    Ok(())
2214}
2215
2216/// Accept or decline a contact request
2217async fn respond_to_contact_request(
2218    request: proto::RespondToContactRequest,
2219    response: Response<proto::RespondToContactRequest>,
2220    session: Session,
2221) -> Result<()> {
2222    let responder_id = session.user_id;
2223    let requester_id = UserId::from_proto(request.requester_id);
2224    let db = session.db().await;
2225    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2226        db.dismiss_contact_notification(responder_id, requester_id)
2227            .await?;
2228    } else {
2229        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2230
2231        let notifications = db
2232            .respond_to_contact_request(responder_id, requester_id, accept)
2233            .await?;
2234        let requester_busy = db.is_user_busy(requester_id).await?;
2235        let responder_busy = db.is_user_busy(responder_id).await?;
2236
2237        let pool = session.connection_pool().await;
2238        // Update responder with new contact
2239        let mut update = proto::UpdateContacts::default();
2240        if accept {
2241            update
2242                .contacts
2243                .push(contact_for_user(requester_id, requester_busy, &pool));
2244        }
2245        update
2246            .remove_incoming_requests
2247            .push(requester_id.to_proto());
2248        for connection_id in pool.user_connection_ids(responder_id) {
2249            session.peer.send(connection_id, update.clone())?;
2250        }
2251
2252        // Update requester with new contact
2253        let mut update = proto::UpdateContacts::default();
2254        if accept {
2255            update
2256                .contacts
2257                .push(contact_for_user(responder_id, responder_busy, &pool));
2258        }
2259        update
2260            .remove_outgoing_requests
2261            .push(responder_id.to_proto());
2262
2263        for connection_id in pool.user_connection_ids(requester_id) {
2264            session.peer.send(connection_id, update.clone())?;
2265        }
2266
2267        send_notifications(&*pool, &session.peer, notifications);
2268    }
2269
2270    response.send(proto::Ack {})?;
2271    Ok(())
2272}
2273
2274/// Remove a contact.
2275async fn remove_contact(
2276    request: proto::RemoveContact,
2277    response: Response<proto::RemoveContact>,
2278    session: Session,
2279) -> Result<()> {
2280    let requester_id = session.user_id;
2281    let responder_id = UserId::from_proto(request.user_id);
2282    let db = session.db().await;
2283    let (contact_accepted, deleted_notification_id) =
2284        db.remove_contact(requester_id, responder_id).await?;
2285
2286    let pool = session.connection_pool().await;
2287    // Update outgoing contact requests of requester
2288    let mut update = proto::UpdateContacts::default();
2289    if contact_accepted {
2290        update.remove_contacts.push(responder_id.to_proto());
2291    } else {
2292        update
2293            .remove_outgoing_requests
2294            .push(responder_id.to_proto());
2295    }
2296    for connection_id in pool.user_connection_ids(requester_id) {
2297        session.peer.send(connection_id, update.clone())?;
2298    }
2299
2300    // Update incoming contact requests of responder
2301    let mut update = proto::UpdateContacts::default();
2302    if contact_accepted {
2303        update.remove_contacts.push(requester_id.to_proto());
2304    } else {
2305        update
2306            .remove_incoming_requests
2307            .push(requester_id.to_proto());
2308    }
2309    for connection_id in pool.user_connection_ids(responder_id) {
2310        session.peer.send(connection_id, update.clone())?;
2311        if let Some(notification_id) = deleted_notification_id {
2312            session.peer.send(
2313                connection_id,
2314                proto::DeleteNotification {
2315                    notification_id: notification_id.to_proto(),
2316                },
2317            )?;
2318        }
2319    }
2320
2321    response.send(proto::Ack {})?;
2322    Ok(())
2323}
2324
2325/// Creates a new channel.
2326async fn create_channel(
2327    request: proto::CreateChannel,
2328    response: Response<proto::CreateChannel>,
2329    session: Session,
2330) -> Result<()> {
2331    let db = session.db().await;
2332
2333    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2334    let (channel, owner, channel_members) = db
2335        .create_channel(&request.name, parent_id, session.user_id)
2336        .await?;
2337
2338    response.send(proto::CreateChannelResponse {
2339        channel: Some(channel.to_proto()),
2340        parent_id: request.parent_id,
2341    })?;
2342
2343    let connection_pool = session.connection_pool().await;
2344    if let Some(owner) = owner {
2345        let update = proto::UpdateUserChannels {
2346            channel_memberships: vec![proto::ChannelMembership {
2347                channel_id: owner.channel_id.to_proto(),
2348                role: owner.role.into(),
2349            }],
2350            ..Default::default()
2351        };
2352        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2353            session.peer.send(connection_id, update.clone())?;
2354        }
2355    }
2356
2357    for channel_member in channel_members {
2358        if !channel_member.role.can_see_channel(channel.visibility) {
2359            continue;
2360        }
2361
2362        let update = proto::UpdateChannels {
2363            channels: vec![channel.to_proto()],
2364            ..Default::default()
2365        };
2366        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2367            session.peer.send(connection_id, update.clone())?;
2368        }
2369    }
2370
2371    Ok(())
2372}
2373
2374/// Delete a channel
2375async fn delete_channel(
2376    request: proto::DeleteChannel,
2377    response: Response<proto::DeleteChannel>,
2378    session: Session,
2379) -> Result<()> {
2380    let db = session.db().await;
2381
2382    let channel_id = request.channel_id;
2383    let (removed_channels, member_ids) = db
2384        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2385        .await?;
2386    response.send(proto::Ack {})?;
2387
2388    // Notify members of removed channels
2389    let mut update = proto::UpdateChannels::default();
2390    update
2391        .delete_channels
2392        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2393
2394    let connection_pool = session.connection_pool().await;
2395    for member_id in member_ids {
2396        for connection_id in connection_pool.user_connection_ids(member_id) {
2397            session.peer.send(connection_id, update.clone())?;
2398        }
2399    }
2400
2401    Ok(())
2402}
2403
2404/// Invite someone to join a channel.
2405async fn invite_channel_member(
2406    request: proto::InviteChannelMember,
2407    response: Response<proto::InviteChannelMember>,
2408    session: Session,
2409) -> Result<()> {
2410    let db = session.db().await;
2411    let channel_id = ChannelId::from_proto(request.channel_id);
2412    let invitee_id = UserId::from_proto(request.user_id);
2413    let InviteMemberResult {
2414        channel,
2415        notifications,
2416    } = db
2417        .invite_channel_member(
2418            channel_id,
2419            invitee_id,
2420            session.user_id,
2421            request.role().into(),
2422        )
2423        .await?;
2424
2425    let update = proto::UpdateChannels {
2426        channel_invitations: vec![channel.to_proto()],
2427        ..Default::default()
2428    };
2429
2430    let connection_pool = session.connection_pool().await;
2431    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2432        session.peer.send(connection_id, update.clone())?;
2433    }
2434
2435    send_notifications(&*connection_pool, &session.peer, notifications);
2436
2437    response.send(proto::Ack {})?;
2438    Ok(())
2439}
2440
2441/// remove someone from a channel
2442async fn remove_channel_member(
2443    request: proto::RemoveChannelMember,
2444    response: Response<proto::RemoveChannelMember>,
2445    session: Session,
2446) -> Result<()> {
2447    let db = session.db().await;
2448    let channel_id = ChannelId::from_proto(request.channel_id);
2449    let member_id = UserId::from_proto(request.user_id);
2450
2451    let RemoveChannelMemberResult {
2452        membership_update,
2453        notification_id,
2454    } = db
2455        .remove_channel_member(channel_id, member_id, session.user_id)
2456        .await?;
2457
2458    let connection_pool = &session.connection_pool().await;
2459    notify_membership_updated(
2460        &connection_pool,
2461        membership_update,
2462        member_id,
2463        &session.peer,
2464    );
2465    for connection_id in connection_pool.user_connection_ids(member_id) {
2466        if let Some(notification_id) = notification_id {
2467            session
2468                .peer
2469                .send(
2470                    connection_id,
2471                    proto::DeleteNotification {
2472                        notification_id: notification_id.to_proto(),
2473                    },
2474                )
2475                .trace_err();
2476        }
2477    }
2478
2479    response.send(proto::Ack {})?;
2480    Ok(())
2481}
2482
2483/// Toggle the channel between public and private.
2484/// Care is taken to maintain the invariant that public channels only descend from public channels,
2485/// (though members-only channels can appear at any point in the hierarchy).
2486async fn set_channel_visibility(
2487    request: proto::SetChannelVisibility,
2488    response: Response<proto::SetChannelVisibility>,
2489    session: Session,
2490) -> Result<()> {
2491    let db = session.db().await;
2492    let channel_id = ChannelId::from_proto(request.channel_id);
2493    let visibility = request.visibility().into();
2494
2495    let (channel, channel_members) = db
2496        .set_channel_visibility(channel_id, visibility, session.user_id)
2497        .await?;
2498
2499    let connection_pool = session.connection_pool().await;
2500    for member in channel_members {
2501        let update = if member.role.can_see_channel(channel.visibility) {
2502            proto::UpdateChannels {
2503                channels: vec![channel.to_proto()],
2504                ..Default::default()
2505            }
2506        } else {
2507            proto::UpdateChannels {
2508                delete_channels: vec![channel.id.to_proto()],
2509                ..Default::default()
2510            }
2511        };
2512
2513        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2514            session.peer.send(connection_id, update.clone())?;
2515        }
2516    }
2517
2518    response.send(proto::Ack {})?;
2519    Ok(())
2520}
2521
2522/// Alter the role for a user in the channel.
2523async fn set_channel_member_role(
2524    request: proto::SetChannelMemberRole,
2525    response: Response<proto::SetChannelMemberRole>,
2526    session: Session,
2527) -> Result<()> {
2528    let db = session.db().await;
2529    let channel_id = ChannelId::from_proto(request.channel_id);
2530    let member_id = UserId::from_proto(request.user_id);
2531    let result = db
2532        .set_channel_member_role(
2533            channel_id,
2534            session.user_id,
2535            member_id,
2536            request.role().into(),
2537        )
2538        .await?;
2539
2540    match result {
2541        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2542            let connection_pool = session.connection_pool().await;
2543            notify_membership_updated(
2544                &connection_pool,
2545                membership_update,
2546                member_id,
2547                &session.peer,
2548            )
2549        }
2550        db::SetMemberRoleResult::InviteUpdated(channel) => {
2551            let update = proto::UpdateChannels {
2552                channel_invitations: vec![channel.to_proto()],
2553                ..Default::default()
2554            };
2555
2556            for connection_id in session
2557                .connection_pool()
2558                .await
2559                .user_connection_ids(member_id)
2560            {
2561                session.peer.send(connection_id, update.clone())?;
2562            }
2563        }
2564    }
2565
2566    response.send(proto::Ack {})?;
2567    Ok(())
2568}
2569
2570/// Change the name of a channel
2571async fn rename_channel(
2572    request: proto::RenameChannel,
2573    response: Response<proto::RenameChannel>,
2574    session: Session,
2575) -> Result<()> {
2576    let db = session.db().await;
2577    let channel_id = ChannelId::from_proto(request.channel_id);
2578    let (channel, channel_members) = db
2579        .rename_channel(channel_id, session.user_id, &request.name)
2580        .await?;
2581
2582    response.send(proto::RenameChannelResponse {
2583        channel: Some(channel.to_proto()),
2584    })?;
2585
2586    let connection_pool = session.connection_pool().await;
2587    for channel_member in channel_members {
2588        if !channel_member.role.can_see_channel(channel.visibility) {
2589            continue;
2590        }
2591        let update = proto::UpdateChannels {
2592            channels: vec![channel.to_proto()],
2593            ..Default::default()
2594        };
2595        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2596            session.peer.send(connection_id, update.clone())?;
2597        }
2598    }
2599
2600    Ok(())
2601}
2602
2603/// Move a channel to a new parent.
2604async fn move_channel(
2605    request: proto::MoveChannel,
2606    response: Response<proto::MoveChannel>,
2607    session: Session,
2608) -> Result<()> {
2609    let channel_id = ChannelId::from_proto(request.channel_id);
2610    let to = ChannelId::from_proto(request.to);
2611
2612    let (channels, channel_members) = session
2613        .db()
2614        .await
2615        .move_channel(channel_id, to, session.user_id)
2616        .await?;
2617
2618    let connection_pool = session.connection_pool().await;
2619    for member in channel_members {
2620        let channels = channels
2621            .iter()
2622            .filter_map(|channel| {
2623                if member.role.can_see_channel(channel.visibility) {
2624                    Some(channel.to_proto())
2625                } else {
2626                    None
2627                }
2628            })
2629            .collect::<Vec<_>>();
2630        if channels.is_empty() {
2631            continue;
2632        }
2633
2634        let update = proto::UpdateChannels {
2635            channels,
2636            ..Default::default()
2637        };
2638
2639        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2640            session.peer.send(connection_id, update.clone())?;
2641        }
2642    }
2643
2644    response.send(Ack {})?;
2645    Ok(())
2646}
2647
2648/// Get the list of channel members
2649async fn get_channel_members(
2650    request: proto::GetChannelMembers,
2651    response: Response<proto::GetChannelMembers>,
2652    session: Session,
2653) -> Result<()> {
2654    let db = session.db().await;
2655    let channel_id = ChannelId::from_proto(request.channel_id);
2656    let members = db
2657        .get_channel_participant_details(channel_id, session.user_id)
2658        .await?;
2659    response.send(proto::GetChannelMembersResponse { members })?;
2660    Ok(())
2661}
2662
2663/// Accept or decline a channel invitation.
2664async fn respond_to_channel_invite(
2665    request: proto::RespondToChannelInvite,
2666    response: Response<proto::RespondToChannelInvite>,
2667    session: Session,
2668) -> Result<()> {
2669    let db = session.db().await;
2670    let channel_id = ChannelId::from_proto(request.channel_id);
2671    let RespondToChannelInvite {
2672        membership_update,
2673        notifications,
2674    } = db
2675        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2676        .await?;
2677
2678    let connection_pool = session.connection_pool().await;
2679    if let Some(membership_update) = membership_update {
2680        notify_membership_updated(
2681            &connection_pool,
2682            membership_update,
2683            session.user_id,
2684            &session.peer,
2685        );
2686    } else {
2687        let update = proto::UpdateChannels {
2688            remove_channel_invitations: vec![channel_id.to_proto()],
2689            ..Default::default()
2690        };
2691
2692        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2693            session.peer.send(connection_id, update.clone())?;
2694        }
2695    };
2696
2697    send_notifications(&*connection_pool, &session.peer, notifications);
2698
2699    response.send(proto::Ack {})?;
2700
2701    Ok(())
2702}
2703
2704/// Join the channels' room
2705async fn join_channel(
2706    request: proto::JoinChannel,
2707    response: Response<proto::JoinChannel>,
2708    session: Session,
2709) -> Result<()> {
2710    let channel_id = ChannelId::from_proto(request.channel_id);
2711    join_channel_internal(channel_id, Box::new(response), session).await
2712}
2713
/// A responder that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` responders so that
/// `join_channel_internal` can reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Reply to a `JoinChannel` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Reply to a `JoinRoom` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2727
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. call `leave_room_for_session` for this session;
/// 2. join the channel's room in the database, which yields the joined room, an optional
///    membership change, and the user's role in the channel;
/// 3. build LiveKit connection info (guests get a guest token and `can_publish: false`,
///    all other roles get a room token and `can_publish: true`);
/// 4. reply through `response`, push any membership change to the user's connections and
///    broadcast the updated room to its participants;
/// 5. broadcast the channel's new participant list to channel members and refresh the
///    user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` and `connection_pool` guards: they are dropped when
    // the block ends, before `channel_updated` re-acquires the pool and
    // `update_user_contacts` calls `session.db()` again. NOTE(review): if these guards are
    // exclusive locks, holding them past this block would block those later calls.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when there is no LiveKit client, or when token creation fails
        // (`trace_err` turns the error into `None`, presumably after logging it).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2803
2804/// Start editing the channel notes
2805async fn join_channel_buffer(
2806    request: proto::JoinChannelBuffer,
2807    response: Response<proto::JoinChannelBuffer>,
2808    session: Session,
2809) -> Result<()> {
2810    let db = session.db().await;
2811    let channel_id = ChannelId::from_proto(request.channel_id);
2812
2813    let open_response = db
2814        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2815        .await?;
2816
2817    let collaborators = open_response.collaborators.clone();
2818    response.send(open_response)?;
2819
2820    let update = UpdateChannelBufferCollaborators {
2821        channel_id: channel_id.to_proto(),
2822        collaborators: collaborators.clone(),
2823    };
2824    channel_buffer_updated(
2825        session.connection_id,
2826        collaborators
2827            .iter()
2828            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2829        &update,
2830        &session.peer,
2831    );
2832
2833    Ok(())
2834}
2835
2836/// Edit the channel notes
2837async fn update_channel_buffer(
2838    request: proto::UpdateChannelBuffer,
2839    session: Session,
2840) -> Result<()> {
2841    let db = session.db().await;
2842    let channel_id = ChannelId::from_proto(request.channel_id);
2843
2844    let (collaborators, non_collaborators, epoch, version) = db
2845        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2846        .await?;
2847
2848    channel_buffer_updated(
2849        session.connection_id,
2850        collaborators,
2851        &proto::UpdateChannelBuffer {
2852            channel_id: channel_id.to_proto(),
2853            operations: request.operations,
2854        },
2855        &session.peer,
2856    );
2857
2858    let pool = &*session.connection_pool().await;
2859
2860    broadcast(
2861        None,
2862        non_collaborators
2863            .iter()
2864            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2865        |peer_id| {
2866            session.peer.send(
2867                peer_id.into(),
2868                proto::UpdateChannels {
2869                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2870                        channel_id: channel_id.to_proto(),
2871                        epoch: epoch as u64,
2872                        version: version.clone(),
2873                    }],
2874                    ..Default::default()
2875                },
2876            )
2877        },
2878    );
2879
2880    Ok(())
2881}
2882
2883/// Rejoin the channel notes after a connection blip
2884async fn rejoin_channel_buffers(
2885    request: proto::RejoinChannelBuffers,
2886    response: Response<proto::RejoinChannelBuffers>,
2887    session: Session,
2888) -> Result<()> {
2889    let db = session.db().await;
2890    let buffers = db
2891        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2892        .await?;
2893
2894    for rejoined_buffer in &buffers {
2895        let collaborators_to_notify = rejoined_buffer
2896            .buffer
2897            .collaborators
2898            .iter()
2899            .filter_map(|c| Some(c.peer_id?.into()));
2900        channel_buffer_updated(
2901            session.connection_id,
2902            collaborators_to_notify,
2903            &proto::UpdateChannelBufferCollaborators {
2904                channel_id: rejoined_buffer.buffer.channel_id,
2905                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2906            },
2907            &session.peer,
2908        );
2909    }
2910
2911    response.send(proto::RejoinChannelBuffersResponse {
2912        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2913    })?;
2914
2915    Ok(())
2916}
2917
2918/// Stop editing the channel notes
2919async fn leave_channel_buffer(
2920    request: proto::LeaveChannelBuffer,
2921    response: Response<proto::LeaveChannelBuffer>,
2922    session: Session,
2923) -> Result<()> {
2924    let db = session.db().await;
2925    let channel_id = ChannelId::from_proto(request.channel_id);
2926
2927    let left_buffer = db
2928        .leave_channel_buffer(channel_id, session.connection_id)
2929        .await?;
2930
2931    response.send(Ack {})?;
2932
2933    channel_buffer_updated(
2934        session.connection_id,
2935        left_buffer.connections,
2936        &proto::UpdateChannelBufferCollaborators {
2937            channel_id: channel_id.to_proto(),
2938            collaborators: left_buffer.collaborators,
2939        },
2940        &session.peer,
2941    );
2942
2943    Ok(())
2944}
2945
2946fn channel_buffer_updated<T: EnvelopedMessage>(
2947    sender_id: ConnectionId,
2948    collaborators: impl IntoIterator<Item = ConnectionId>,
2949    message: &T,
2950    peer: &Peer,
2951) {
2952    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2953        peer.send(peer_id.into(), message.clone())
2954    });
2955}
2956
2957fn send_notifications(
2958    connection_pool: &ConnectionPool,
2959    peer: &Peer,
2960    notifications: db::NotificationBatch,
2961) {
2962    for (user_id, notification) in notifications {
2963        for connection_id in connection_pool.user_connection_ids(user_id) {
2964            if let Err(error) = peer.send(
2965                connection_id,
2966                proto::AddNotification {
2967                    notification: Some(notification.clone()),
2968                },
2969            ) {
2970                tracing::error!(
2971                    "failed to send notification to {:?} {}",
2972                    connection_id,
2973                    error
2974                );
2975            }
2976        }
2977    }
2978}
2979
2980/// Send a message to the channel
2981async fn send_channel_message(
2982    request: proto::SendChannelMessage,
2983    response: Response<proto::SendChannelMessage>,
2984    session: Session,
2985) -> Result<()> {
2986    // Validate the message body.
2987    let body = request.body.trim().to_string();
2988    if body.len() > MAX_MESSAGE_LEN {
2989        return Err(anyhow!("message is too long"))?;
2990    }
2991    if body.is_empty() {
2992        return Err(anyhow!("message can't be blank"))?;
2993    }
2994
2995    // TODO: adjust mentions if body is trimmed
2996
2997    let timestamp = OffsetDateTime::now_utc();
2998    let nonce = request
2999        .nonce
3000        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3001
3002    let channel_id = ChannelId::from_proto(request.channel_id);
3003    let CreatedChannelMessage {
3004        message_id,
3005        participant_connection_ids,
3006        channel_members,
3007        notifications,
3008    } = session
3009        .db()
3010        .await
3011        .create_channel_message(
3012            channel_id,
3013            session.user_id,
3014            &body,
3015            &request.mentions,
3016            timestamp,
3017            nonce.clone().into(),
3018            match request.reply_to_message_id {
3019                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3020                None => None,
3021            },
3022        )
3023        .await?;
3024    let message = proto::ChannelMessage {
3025        sender_id: session.user_id.to_proto(),
3026        id: message_id.to_proto(),
3027        body,
3028        mentions: request.mentions,
3029        timestamp: timestamp.unix_timestamp() as u64,
3030        nonce: Some(nonce),
3031        reply_to_message_id: request.reply_to_message_id,
3032    };
3033    broadcast(
3034        Some(session.connection_id),
3035        participant_connection_ids,
3036        |connection| {
3037            session.peer.send(
3038                connection,
3039                proto::ChannelMessageSent {
3040                    channel_id: channel_id.to_proto(),
3041                    message: Some(message.clone()),
3042                },
3043            )
3044        },
3045    );
3046    response.send(proto::SendChannelMessageResponse {
3047        message: Some(message),
3048    })?;
3049
3050    let pool = &*session.connection_pool().await;
3051    broadcast(
3052        None,
3053        channel_members
3054            .iter()
3055            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3056        |peer_id| {
3057            session.peer.send(
3058                peer_id.into(),
3059                proto::UpdateChannels {
3060                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3061                        channel_id: channel_id.to_proto(),
3062                        message_id: message_id.to_proto(),
3063                    }],
3064                    ..Default::default()
3065                },
3066            )
3067        },
3068    );
3069    send_notifications(pool, &session.peer, notifications);
3070
3071    Ok(())
3072}
3073
3074/// Delete a channel message
3075async fn remove_channel_message(
3076    request: proto::RemoveChannelMessage,
3077    response: Response<proto::RemoveChannelMessage>,
3078    session: Session,
3079) -> Result<()> {
3080    let channel_id = ChannelId::from_proto(request.channel_id);
3081    let message_id = MessageId::from_proto(request.message_id);
3082    let connection_ids = session
3083        .db()
3084        .await
3085        .remove_channel_message(channel_id, message_id, session.user_id)
3086        .await?;
3087    broadcast(Some(session.connection_id), connection_ids, |connection| {
3088        session.peer.send(connection, request.clone())
3089    });
3090    response.send(proto::Ack {})?;
3091    Ok(())
3092}
3093
3094/// Mark a channel message as read
3095async fn acknowledge_channel_message(
3096    request: proto::AckChannelMessage,
3097    session: Session,
3098) -> Result<()> {
3099    let channel_id = ChannelId::from_proto(request.channel_id);
3100    let message_id = MessageId::from_proto(request.message_id);
3101    let notifications = session
3102        .db()
3103        .await
3104        .observe_channel_message(channel_id, session.user_id, message_id)
3105        .await?;
3106    send_notifications(
3107        &*session.connection_pool().await,
3108        &session.peer,
3109        notifications,
3110    );
3111    Ok(())
3112}
3113
3114/// Mark a buffer version as synced
3115async fn acknowledge_buffer_version(
3116    request: proto::AckBufferOperation,
3117    session: Session,
3118) -> Result<()> {
3119    let buffer_id = BufferId::from_proto(request.buffer_id);
3120    session
3121        .db()
3122        .await
3123        .observe_buffer_version(
3124            buffer_id,
3125            session.user_id,
3126            request.epoch as i32,
3127            &request.version,
3128        )
3129        .await?;
3130    Ok(())
3131}
3132
3133/// Start receiving chat updates for a channel
3134async fn join_channel_chat(
3135    request: proto::JoinChannelChat,
3136    response: Response<proto::JoinChannelChat>,
3137    session: Session,
3138) -> Result<()> {
3139    let channel_id = ChannelId::from_proto(request.channel_id);
3140
3141    let db = session.db().await;
3142    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3143        .await?;
3144    let messages = db
3145        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3146        .await?;
3147    response.send(proto::JoinChannelChatResponse {
3148        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3149        messages,
3150    })?;
3151    Ok(())
3152}
3153
3154/// Stop receiving chat updates for a channel
3155async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3156    let channel_id = ChannelId::from_proto(request.channel_id);
3157    session
3158        .db()
3159        .await
3160        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3161        .await?;
3162    Ok(())
3163}
3164
3165/// Retrieve the chat history for a channel
3166async fn get_channel_messages(
3167    request: proto::GetChannelMessages,
3168    response: Response<proto::GetChannelMessages>,
3169    session: Session,
3170) -> Result<()> {
3171    let channel_id = ChannelId::from_proto(request.channel_id);
3172    let messages = session
3173        .db()
3174        .await
3175        .get_channel_messages(
3176            channel_id,
3177            session.user_id,
3178            MESSAGE_COUNT_PER_PAGE,
3179            Some(MessageId::from_proto(request.before_message_id)),
3180        )
3181        .await?;
3182    response.send(proto::GetChannelMessagesResponse {
3183        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3184        messages,
3185    })?;
3186    Ok(())
3187}
3188
3189/// Retrieve specific chat messages
3190async fn get_channel_messages_by_id(
3191    request: proto::GetChannelMessagesById,
3192    response: Response<proto::GetChannelMessagesById>,
3193    session: Session,
3194) -> Result<()> {
3195    let message_ids = request
3196        .message_ids
3197        .iter()
3198        .map(|id| MessageId::from_proto(*id))
3199        .collect::<Vec<_>>();
3200    let messages = session
3201        .db()
3202        .await
3203        .get_channel_messages_by_id(session.user_id, &message_ids)
3204        .await?;
3205    response.send(proto::GetChannelMessagesResponse {
3206        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3207        messages,
3208    })?;
3209    Ok(())
3210}
3211
3212/// Retrieve the current users notifications
3213async fn get_notifications(
3214    request: proto::GetNotifications,
3215    response: Response<proto::GetNotifications>,
3216    session: Session,
3217) -> Result<()> {
3218    let notifications = session
3219        .db()
3220        .await
3221        .get_notifications(
3222            session.user_id,
3223            NOTIFICATION_COUNT_PER_PAGE,
3224            request
3225                .before_id
3226                .map(|id| db::NotificationId::from_proto(id)),
3227        )
3228        .await?;
3229    response.send(proto::GetNotificationsResponse {
3230        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3231        notifications,
3232    })?;
3233    Ok(())
3234}
3235
3236/// Mark notifications as read
3237async fn mark_notification_as_read(
3238    request: proto::MarkNotificationRead,
3239    response: Response<proto::MarkNotificationRead>,
3240    session: Session,
3241) -> Result<()> {
3242    let database = &session.db().await;
3243    let notifications = database
3244        .mark_notification_as_read_by_id(
3245            session.user_id,
3246            NotificationId::from_proto(request.notification_id),
3247        )
3248        .await?;
3249    send_notifications(
3250        &*session.connection_pool().await,
3251        &session.peer,
3252        notifications,
3253    );
3254    response.send(proto::Ack {})?;
3255    Ok(())
3256}
3257
3258/// Get the current users information
3259async fn get_private_user_info(
3260    _request: proto::GetPrivateUserInfo,
3261    response: Response<proto::GetPrivateUserInfo>,
3262    session: Session,
3263) -> Result<()> {
3264    let db = session.db().await;
3265
3266    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3267    let user = db
3268        .get_user_by_id(session.user_id)
3269        .await?
3270        .ok_or_else(|| anyhow!("user not found"))?;
3271    let flags = db.get_user_flags(session.user_id).await?;
3272
3273    response.send(proto::GetPrivateUserInfoResponse {
3274        metrics_id,
3275        staff: user.admin,
3276        flags,
3277    })?;
3278    Ok(())
3279}
3280
3281fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3282    match message {
3283        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3284        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3285        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3286        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3287        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3288            code: frame.code.into(),
3289            reason: frame.reason,
3290        })),
3291    }
3292}
3293
3294fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3295    match message {
3296        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3297        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3298        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3299        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3300        AxumMessage::Close(frame) => {
3301            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3302                code: frame.code.into(),
3303                reason: frame.reason,
3304            }))
3305        }
3306    }
3307}
3308
3309fn notify_membership_updated(
3310    connection_pool: &ConnectionPool,
3311    result: MembershipUpdated,
3312    user_id: UserId,
3313    peer: &Peer,
3314) {
3315    let user_channels_update = proto::UpdateUserChannels {
3316        channel_memberships: result
3317            .new_channels
3318            .channel_memberships
3319            .iter()
3320            .map(|cm| proto::ChannelMembership {
3321                channel_id: cm.channel_id.to_proto(),
3322                role: cm.role.into(),
3323            })
3324            .collect(),
3325        ..Default::default()
3326    };
3327    let mut update = build_channels_update(result.new_channels, vec![]);
3328    update.delete_channels = result
3329        .removed_channels
3330        .into_iter()
3331        .map(|id| id.to_proto())
3332        .collect();
3333    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3334
3335    for connection_id in connection_pool.user_connection_ids(user_id) {
3336        peer.send(connection_id, user_channels_update.clone())
3337            .trace_err();
3338        peer.send(connection_id, update.clone()).trace_err();
3339    }
3340}
3341
3342fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3343    proto::UpdateUserChannels {
3344        channel_memberships: channels
3345            .channel_memberships
3346            .iter()
3347            .map(|m| proto::ChannelMembership {
3348                channel_id: m.channel_id.to_proto(),
3349                role: m.role.into(),
3350            })
3351            .collect(),
3352        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3353        observed_channel_message_id: channels.observed_channel_messages.clone(),
3354        ..Default::default()
3355    }
3356}
3357
3358fn build_channels_update(
3359    channels: ChannelsForUser,
3360    channel_invites: Vec<db::Channel>,
3361) -> proto::UpdateChannels {
3362    let mut update = proto::UpdateChannels::default();
3363
3364    for channel in channels.channels {
3365        update.channels.push(channel.to_proto());
3366    }
3367
3368    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3369    update.latest_channel_message_ids = channels.latest_channel_messages;
3370
3371    for (channel_id, participants) in channels.channel_participants {
3372        update
3373            .channel_participants
3374            .push(proto::ChannelParticipants {
3375                channel_id: channel_id.to_proto(),
3376                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3377            });
3378    }
3379
3380    for channel in channel_invites {
3381        update.channel_invitations.push(channel.to_proto());
3382    }
3383
3384    update
3385}
3386
3387fn build_initial_contacts_update(
3388    contacts: Vec<db::Contact>,
3389    pool: &ConnectionPool,
3390) -> proto::UpdateContacts {
3391    let mut update = proto::UpdateContacts::default();
3392
3393    for contact in contacts {
3394        match contact {
3395            db::Contact::Accepted { user_id, busy } => {
3396                update.contacts.push(contact_for_user(user_id, busy, &pool));
3397            }
3398            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3399            db::Contact::Incoming { user_id } => {
3400                update
3401                    .incoming_requests
3402                    .push(proto::IncomingContactRequest {
3403                        requester_id: user_id.to_proto(),
3404                    })
3405            }
3406        }
3407    }
3408
3409    update
3410}
3411
3412fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3413    proto::Contact {
3414        user_id: user_id.to_proto(),
3415        online: pool.is_user_online(user_id),
3416        busy,
3417    }
3418}
3419
3420fn room_updated(room: &proto::Room, peer: &Peer) {
3421    broadcast(
3422        None,
3423        room.participants
3424            .iter()
3425            .filter_map(|participant| Some(participant.peer_id?.into())),
3426        |peer_id| {
3427            peer.send(
3428                peer_id.into(),
3429                proto::RoomUpdated {
3430                    room: Some(room.clone()),
3431                },
3432            )
3433        },
3434    );
3435}
3436
3437fn channel_updated(
3438    channel_id: ChannelId,
3439    room: &proto::Room,
3440    channel_members: &[UserId],
3441    peer: &Peer,
3442    pool: &ConnectionPool,
3443) {
3444    let participants = room
3445        .participants
3446        .iter()
3447        .map(|p| p.user_id)
3448        .collect::<Vec<_>>();
3449
3450    broadcast(
3451        None,
3452        channel_members
3453            .iter()
3454            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3455        |peer_id| {
3456            peer.send(
3457                peer_id.into(),
3458                proto::UpdateChannels {
3459                    channel_participants: vec![proto::ChannelParticipants {
3460                        channel_id: channel_id.to_proto(),
3461                        participant_user_ids: participants.clone(),
3462                    }],
3463                    ..Default::default()
3464                },
3465            )
3466        },
3467    );
3468}
3469
3470async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3471    let db = session.db().await;
3472
3473    let contacts = db.get_contacts(user_id).await?;
3474    let busy = db.is_user_busy(user_id).await?;
3475
3476    let pool = session.connection_pool().await;
3477    let updated_contact = contact_for_user(user_id, busy, &pool);
3478    for contact in contacts {
3479        if let db::Contact::Accepted {
3480            user_id: contact_user_id,
3481            ..
3482        } = contact
3483        {
3484            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3485                session
3486                    .peer
3487                    .send(
3488                        contact_conn_id,
3489                        proto::UpdateContacts {
3490                            contacts: vec![updated_contact.clone()],
3491                            remove_contacts: Default::default(),
3492                            incoming_requests: Default::default(),
3493                            remove_incoming_requests: Default::default(),
3494                            outgoing_requests: Default::default(),
3495                            remove_outgoing_requests: Default::default(),
3496                        },
3497                    )
3498                    .trace_err();
3499            }
3500        }
3501    }
3502    Ok(())
3503}
3504
3505async fn leave_room_for_session(session: &Session) -> Result<()> {
3506    let mut contacts_to_update = HashSet::default();
3507
3508    let room_id;
3509    let canceled_calls_to_user_ids;
3510    let live_kit_room;
3511    let delete_live_kit_room;
3512    let room;
3513    let channel_members;
3514    let channel_id;
3515
3516    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3517        contacts_to_update.insert(session.user_id);
3518
3519        for project in left_room.left_projects.values() {
3520            project_left(project, session);
3521        }
3522
3523        room_id = RoomId::from_proto(left_room.room.id);
3524        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3525        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3526        delete_live_kit_room = left_room.deleted;
3527        room = mem::take(&mut left_room.room);
3528        channel_members = mem::take(&mut left_room.channel_members);
3529        channel_id = left_room.channel_id;
3530
3531        room_updated(&room, &session.peer);
3532    } else {
3533        return Ok(());
3534    }
3535
3536    if let Some(channel_id) = channel_id {
3537        channel_updated(
3538            channel_id,
3539            &room,
3540            &channel_members,
3541            &session.peer,
3542            &*session.connection_pool().await,
3543        );
3544    }
3545
3546    {
3547        let pool = session.connection_pool().await;
3548        for canceled_user_id in canceled_calls_to_user_ids {
3549            for connection_id in pool.user_connection_ids(canceled_user_id) {
3550                session
3551                    .peer
3552                    .send(
3553                        connection_id,
3554                        proto::CallCanceled {
3555                            room_id: room_id.to_proto(),
3556                        },
3557                    )
3558                    .trace_err();
3559            }
3560            contacts_to_update.insert(canceled_user_id);
3561        }
3562    }
3563
3564    for contact_user_id in contacts_to_update {
3565        update_user_contacts(contact_user_id, &session).await?;
3566    }
3567
3568    if let Some(live_kit) = session.live_kit_client.as_ref() {
3569        live_kit
3570            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3571            .await
3572            .trace_err();
3573
3574        if delete_live_kit_room {
3575            live_kit.delete_room(live_kit_room).await.trace_err();
3576        }
3577    }
3578
3579    Ok(())
3580}
3581
3582async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3583    let left_channel_buffers = session
3584        .db()
3585        .await
3586        .leave_channel_buffers(session.connection_id)
3587        .await?;
3588
3589    for left_buffer in left_channel_buffers {
3590        channel_buffer_updated(
3591            session.connection_id,
3592            left_buffer.connections,
3593            &proto::UpdateChannelBufferCollaborators {
3594                channel_id: left_buffer.channel_id.to_proto(),
3595                collaborators: left_buffer.collaborators,
3596            },
3597            &session.peer,
3598        );
3599    }
3600
3601    Ok(())
3602}
3603
3604fn project_left(project: &db::LeftProject, session: &Session) {
3605    for connection_id in &project.connection_ids {
3606        if project.host_user_id == session.user_id {
3607            session
3608                .peer
3609                .send(
3610                    *connection_id,
3611                    proto::UnshareProject {
3612                        project_id: project.id.to_proto(),
3613                    },
3614                )
3615                .trace_err();
3616        } else {
3617            session
3618                .peer
3619                .send(
3620                    *connection_id,
3621                    proto::RemoveProjectCollaborator {
3622                        project_id: project.id.to_proto(),
3623                        peer_id: Some(session.connection_id.into()),
3624                    },
3625                )
3626                .trace_err();
3627        }
3628    }
3629}
3630
3631pub trait ResultExt {
3632    type Ok;
3633
3634    fn trace_err(self) -> Option<Self::Ok>;
3635}
3636
3637impl<T, E> ResultExt for Result<T, E>
3638where
3639    E: std::fmt::Debug,
3640{
3641    type Ok = T;
3642
3643    fn trace_err(self) -> Option<T> {
3644        match self {
3645            Ok(value) => Some(value),
3646            Err(error) => {
3647                tracing::error!("{:?}", error);
3648                None
3649            }
3650        }
3651    }
3652}