rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::{ConnectionPool, ZedVersion};
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  42        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc, OnceLock,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{field, info_span, instrument, Instrument};
  66use util::SemanticVersion;
  67
    /// Timeout related to client reconnection (30 seconds).
    // NOTE(review): no use site is visible in this part of the file — confirm its exact role.
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
    /// Delay that `Server::start` waits before it begins clearing resources left by stale servers.
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
    // NOTE(review): the three limits below are not referenced in this part of the file.
    // Judging by their names they cap the chat page size, the chat message length and the
    // notification page size — confirm at the use sites.
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
    /// Type-erased message handler as stored in `Server::handlers`.
    ///
    /// It receives the boxed envelope together with the connection's `Session` and
    /// returns a boxed future that runs the handler to completion.
  75type MessageHandler =
  76    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  77
  78struct Response<R> {
  79    peer: Arc<Peer>,
  80    receipt: Receipt<R>,
  81    responded: Arc<AtomicBool>,
  82}
  83
  84impl<R: RequestMessage> Response<R> {
  85    fn send(self, payload: R::Response) -> Result<()> {
  86        self.responded.store(true, SeqCst);
  87        self.peer.respond(self.receipt, payload)?;
  88        Ok(())
  89    }
  90}
  91
  92#[derive(Clone)]
  93struct Session {
  94    user_id: UserId,
  95    connection_id: ConnectionId,
  96    db: Arc<tokio::sync::Mutex<DbHandle>>,
  97    peer: Arc<Peer>,
  98    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  99    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 100    _executor: Executor,
 101}
 102
 103impl Session {
 104    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        let guard = self.db.lock().await;
 108        #[cfg(test)]
 109        tokio::task::yield_now().await;
 110        guard
 111    }
 112
 113    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.connection_pool.lock();
 117        ConnectionPoolGuard {
 118            guard,
 119            _not_send: PhantomData,
 120        }
 121    }
 122}
 123
 124impl fmt::Debug for Session {
 125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 126        f.debug_struct("Session")
 127            .field("user_id", &self.user_id)
 128            .field("connection_id", &self.connection_id)
 129            .finish()
 130    }
 131}
 132
 133struct DbHandle(Arc<Database>);
 134
 135impl Deref for DbHandle {
 136    type Target = Database;
 137
 138    fn deref(&self) -> &Self::Target {
 139        self.0.as_ref()
 140    }
 141}
 142
    /// State of the RPC server: the peer, the pool of live connections and the table of
    /// registered message handlers.
 143pub struct Server {
        /// Id of this server. Behind a mutex so the test-only `reset` can replace it via `&self`.
 144    id: parking_lot::Mutex<ServerId>,
        /// Peer used to register connections and to send messages to them.
 145    peer: Arc<Peer>,
        /// Connections currently known to this server, shared with every `Session`.
 146    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
        /// Application state; gives access to the database, config and optional LiveKit client.
 147    app_state: Arc<AppState>,
        /// Executor used to spawn detached tasks and to create sleep timers.
 148    executor: Executor,
        /// Handlers keyed by the `TypeId` of the message type they accept.
 149    handlers: HashMap<TypeId, MessageHandler>,
        /// Teardown flag: `teardown` publishes `true`, connection tasks subscribe to it.
 150    teardown: watch::Sender<bool>,
 151}
 152
    /// Lock guard over the shared `ConnectionPool`.
 153pub(crate) struct ConnectionPoolGuard<'a> {
        /// The underlying `parking_lot` guard.
 154    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
        /// Marker only: `Rc` is not `Send`, so this field keeps the whole guard `!Send`.
 155    _not_send: PhantomData<Rc<()>>,
 156}
 157
    /// Serializable view of the server: a borrowed peer plus the locked connection pool.
 158#[derive(Serialize)]
 159pub struct ServerSnapshot<'a> {
 160    peer: &'a Peer,
        // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
        // NOTE(review): relies on a `Deref` impl for `ConnectionPoolGuard` that is not
        // visible in this part of the file.
 161    #[serde(serialize_with = "serialize_deref")]
 162    connection_pool: ConnectionPoolGuard<'a>,
 163}
 164
 165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 166where
 167    S: Serializer,
 168    T: Deref<Target = U>,
 169    U: Serialize,
 170{
 171    Serialize::serialize(value.deref(), serializer)
 172}
 173
 174impl Server {
 175    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 176        let mut server = Self {
 177            id: parking_lot::Mutex::new(id),
 178            peer: Peer::new(id.0 as u32),
 179            app_state,
 180            executor,
 181            connection_pool: Default::default(),
 182            handlers: Default::default(),
 183            teardown: watch::channel(false).0,
 184        };
 185
 186        server
 187            .add_request_handler(ping)
 188            .add_request_handler(create_room)
 189            .add_request_handler(join_room)
 190            .add_request_handler(rejoin_room)
 191            .add_request_handler(leave_room)
 192            .add_request_handler(set_room_participant_role)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 208            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 209            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 210            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 211            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 212            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 213            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 214            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 215            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 216            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 217            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 218            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 219            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 220            .add_request_handler(
 221                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 222            )
 223            .add_request_handler(
 224                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 225            )
 226            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 227            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 228            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 229            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 230            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 231            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 232            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 233            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 234            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 235            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 236            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 237            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 238            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 239            .add_message_handler(create_buffer_for_peer)
 240            .add_request_handler(update_buffer)
 241            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 242            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 243            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 244            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 245            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 246            .add_request_handler(get_users)
 247            .add_request_handler(fuzzy_search_users)
 248            .add_request_handler(request_contact)
 249            .add_request_handler(remove_contact)
 250            .add_request_handler(respond_to_contact_request)
 251            .add_request_handler(create_channel)
 252            .add_request_handler(delete_channel)
 253            .add_request_handler(invite_channel_member)
 254            .add_request_handler(remove_channel_member)
 255            .add_request_handler(set_channel_member_role)
 256            .add_request_handler(set_channel_visibility)
 257            .add_request_handler(rename_channel)
 258            .add_request_handler(join_channel_buffer)
 259            .add_request_handler(leave_channel_buffer)
 260            .add_message_handler(update_channel_buffer)
 261            .add_request_handler(rejoin_channel_buffers)
 262            .add_request_handler(get_channel_members)
 263            .add_request_handler(respond_to_channel_invite)
 264            .add_request_handler(join_channel)
 265            .add_request_handler(join_channel_chat)
 266            .add_message_handler(leave_channel_chat)
 267            .add_request_handler(send_channel_message)
 268            .add_request_handler(remove_channel_message)
 269            .add_request_handler(get_channel_messages)
 270            .add_request_handler(get_channel_messages_by_id)
 271            .add_request_handler(get_notifications)
 272            .add_request_handler(mark_notification_as_read)
 273            .add_request_handler(move_channel)
 274            .add_request_handler(follow)
 275            .add_message_handler(unfollow)
 276            .add_message_handler(update_followers)
 277            .add_request_handler(get_private_user_info)
 278            .add_message_handler(acknowledge_channel_message)
 279            .add_message_handler(acknowledge_buffer_version);
 280
 281        Arc::new(server)
 282    }
 283
        /// Schedules the cleanup of resources left behind by stale servers.
        ///
        /// The work runs in a detached task, so this method returns `Ok(())` as soon as the
        /// task has been spawned. After waiting for `CLEANUP_TIMEOUT` the task:
        ///
        /// 1. asks the database for stale room and channel ids (`stale_server_resource_ids`)
        ///    for this environment and server id;
        /// 2. for every such channel, clears stale channel-buffer collaborators and sends the
        ///    refreshed collaborator list to the connection ids returned by the database;
        /// 3. for every such room, clears stale participants, passes the refreshed room to
        ///    `room_updated` (and to `channel_updated` when the room belongs to a channel),
        ///    sends `CallCanceled` for canceled calls, pushes contact updates for the affected
        ///    users and deletes the LiveKit room when no participants remain;
        /// 4. finally calls `delete_stale_servers`, whether or not step 1 succeeded.
        ///
        /// Every fallible step goes through `trace_err()`; a failed step is skipped instead
        /// of aborting the task.
 284    pub async fn start(&self) -> Result<()> {
            // Copy or clone everything the detached task needs, since it cannot borrow `self`.
 285        let server_id = *self.id.lock();
 286        let app_state = self.app_state.clone();
 287        let peer = self.peer.clone();
 288        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 289        let pool = self.connection_pool.clone();
 290        let live_kit_client = self.app_state.live_kit_client.clone();
 291
 292        let span = info_span!("start server");
 293        self.executor.spawn_detached(
 294            async move {
 295                tracing::info!("waiting for cleanup timeout");
 296                timeout.await;
 297                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 298                if let Some((room_ids, channel_ids)) = app_state
 299                    .db
 300                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 301                    .await
 302                    .trace_err()
 303                {
 304                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 305                    tracing::info!(
 306                        stale_channel_buffer_count = channel_ids.len(),
 307                        "retrieved stale channel buffers"
 308                    );
 309
                        // Channel buffers: clear stale collaborators and send the refreshed
                        // list to every connection id the database returned.
 310                    for channel_id in channel_ids {
 311                        if let Some(refreshed_channel_buffer) = app_state
 312                            .db
 313                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 314                            .await
 315                            .trace_err()
 316                        {
 317                            for connection_id in refreshed_channel_buffer.connection_ids {
 318                                peer.send(
 319                                    connection_id,
 320                                    proto::UpdateChannelBufferCollaborators {
 321                                        channel_id: channel_id.to_proto(),
 322                                        collaborators: refreshed_channel_buffer
 323                                            .collaborators
 324                                            .clone(),
 325                                    },
 326                                )
 327                                .trace_err();
 328                            }
 329                        }
 330                    }
 331
 332                    for room_id in room_ids {
                            // Defaults that apply when refreshing the room fails: nobody is
                            // notified and no LiveKit room is deleted.
 333                        let mut contacts_to_update = HashSet::default();
 334                        let mut canceled_calls_to_user_ids = Vec::new();
 335                        let mut live_kit_room = String::new();
 336                        let mut delete_live_kit_room = false;
 337
 338                        if let Some(mut refreshed_room) = app_state
 339                            .db
 340                            .clear_stale_room_participants(room_id, server_id)
 341                            .await
 342                            .trace_err()
 343                        {
 344                            tracing::info!(
 345                                room_id = room_id.0,
 346                                new_participant_count = refreshed_room.room.participants.len(),
 347                                "refreshed room"
 348                            );
 349                            room_updated(&refreshed_room.room, &peer);
 350                            if let Some(channel_id) = refreshed_room.channel_id {
                                    // The pool lock is held only for the duration of this call.
 351                                channel_updated(
 352                                    channel_id,
 353                                    &refreshed_room.room,
 354                                    &refreshed_room.channel_members,
 355                                    &peer,
 356                                    &*pool.lock(),
 357                                );
 358                            }
                                // Both stale participants and users whose calls were canceled
                                // get their contact entry refreshed below.
 359                            contacts_to_update
 360                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 361                            contacts_to_update
 362                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 363                            canceled_calls_to_user_ids =
 364                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 365                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 366                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 367                        }
 368
                            // Scoped so the pool lock is released before the awaits further down.
 369                        {
 370                            let pool = pool.lock();
 371                            for canceled_user_id in canceled_calls_to_user_ids {
 372                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 373                                    peer.send(
 374                                        connection_id,
 375                                        proto::CallCanceled {
 376                                            room_id: room_id.to_proto(),
 377                                        },
 378                                    )
 379                                    .trace_err();
 380                                }
 381                            }
 382                        }
 383
 384                        for user_id in contacts_to_update {
 385                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 386                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                                // Notify only when both lookups succeeded; the pool lock is
                                // taken after the awaits above.
 387                            if let Some((busy, contacts)) = busy.zip(contacts) {
 388                                let pool = pool.lock();
 389                                let updated_contact = contact_for_user(user_id, busy, &pool);
 390                                for contact in contacts {
                                        // Only accepted contacts receive the update.
 391                                    if let db::Contact::Accepted {
 392                                        user_id: contact_user_id,
 393                                        ..
 394                                    } = contact
 395                                    {
 396                                        for contact_conn_id in
 397                                            pool.user_connection_ids(contact_user_id)
 398                                        {
 399                                            peer.send(
 400                                                contact_conn_id,
 401                                                proto::UpdateContacts {
 402                                                    contacts: vec![updated_contact.clone()],
 403                                                    remove_contacts: Default::default(),
 404                                                    incoming_requests: Default::default(),
 405                                                    remove_incoming_requests: Default::default(),
 406                                                    outgoing_requests: Default::default(),
 407                                                    remove_outgoing_requests: Default::default(),
 408                                                },
 409                                            )
 410                                            .trace_err();
 411                                        }
 412                                    }
 413                                }
 414                            }
 415                        }
 416
                            // Delete the LiveKit room only when a client is configured and the
                            // refreshed room has no participants left.
 417                        if let Some(live_kit) = live_kit_client.as_ref() {
 418                            if delete_live_kit_room {
 419                                live_kit.delete_room(live_kit_room).await.trace_err();
 420                            }
 421                        }
 422                    }
 423                }
 424
                    // Runs even when the stale resource ids could not be retrieved.
 425                app_state
 426                    .db
 427                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 428                    .await
 429                    .trace_err();
 430            }
 431            .instrument(span),
 432        );
 433        Ok(())
 434    }
 435
 436    pub fn teardown(&self) {
 437        self.peer.teardown();
 438        self.connection_pool.lock().reset();
 439        let _ = self.teardown.send(true);
 440    }
 441
 442    #[cfg(test)]
 443    pub fn reset(&self, id: ServerId) {
 444        self.teardown();
 445        *self.id.lock() = id;
 446        self.peer.reset(id.0 as u32);
 447        let _ = self.teardown.send(false);
 448    }
 449
 450    #[cfg(test)]
 451    pub fn id(&self) -> ServerId {
 452        *self.id.lock()
 453    }
 454
 455    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 456    where
 457        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 458        Fut: 'static + Send + Future<Output = Result<()>>,
 459        M: EnvelopedMessage,
 460    {
 461        let prev_handler = self.handlers.insert(
 462            TypeId::of::<M>(),
 463            Box::new(move |envelope, session| {
 464                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 465                let span = info_span!(
 466                    "handle message",
 467                    payload_type = envelope.payload_type_name()
 468                );
 469                span.in_scope(|| {
 470                    tracing::info!(
 471                        payload_type = envelope.payload_type_name(),
 472                        "message received"
 473                    );
 474                });
 475                let start_time = Instant::now();
 476                let future = (handler)(*envelope, session);
 477                async move {
 478                    let result = future.await;
 479                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 480                    match result {
 481                        Err(error) => {
 482                            tracing::error!(%error, ?duration_ms, "error handling message")
 483                        }
 484                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 485                    }
 486                }
 487                .instrument(span)
 488                .boxed()
 489            }),
 490        );
 491        if prev_handler.is_some() {
 492            panic!("registered a handler for the same message twice");
 493        }
 494        self
 495    }
 496
 497    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 498    where
 499        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 500        Fut: 'static + Send + Future<Output = Result<()>>,
 501        M: EnvelopedMessage,
 502    {
 503        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 504        self
 505    }
 506
 507    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 510        Fut: Send + Future<Output = Result<()>>,
 511        M: RequestMessage,
 512    {
 513        let handler = Arc::new(handler);
 514        self.add_handler(move |envelope, session| {
 515            let receipt = envelope.receipt();
 516            let handler = handler.clone();
 517            async move {
 518                let peer = session.peer.clone();
 519                let responded = Arc::new(AtomicBool::default());
 520                let response = Response {
 521                    peer: peer.clone(),
 522                    responded: responded.clone(),
 523                    receipt,
 524                };
 525                match (handler)(envelope.payload, response, session).await {
 526                    Ok(()) => {
 527                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 528                            Ok(())
 529                        } else {
 530                            Err(anyhow!("handler did not send a response"))?
 531                        }
 532                    }
 533                    Err(error) => {
 534                        let proto_err = match &error {
 535                            Error::Internal(err) => err.to_proto(),
 536                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 537                        };
 538                        peer.respond_with_error(receipt, proto_err)?;
 539                        Err(error)
 540                    }
 541                }
 542            }
 543        })
 544    }
 545
        /// Returns a future that serves one client connection from handshake to sign-out.
        ///
        /// The future registers the connection with the peer, sends `Hello`, sends the
        /// initial contacts and channel updates, forwards a pending incoming call if there
        /// is one, and then dispatches incoming messages to the registered handlers until
        /// the I/O future finishes, the message stream ends or the server is torn down.
        /// When the loop is left with `break`, `connection_lost` is called for the session.
        ///
        /// * `address`, `user` - identify the client in log output; `user` also supplies the
        ///   user id, the `connected_once` flag and the `admin` flag.
        /// * `zed_version` - stored in the connection pool alongside the connection.
        /// * `impersonator` - when present, its GitHub login is recorded on the tracing span.
        /// * `send_connection_id` - when present, receives the new connection id right after
        ///   the `Hello` message has been sent.
        /// * `executor` - creates the peer's timers and runs background message handlers.
        ///
        /// The future fails if the server is already tearing down or if any setup step before
        /// the message loop fails. A teardown signalled during the loop makes it return
        /// `Ok(())` immediately, without calling `connection_lost`.
        // NOTE(review): the `?` returns between `pool.add_connection` and the message loop
        // leave without reaching `connection_lost` — confirm that the pool entry is cleaned
        // up elsewhere in that case.
 546    pub fn handle_connection(
 547        self: &Arc<Self>,
 548        connection: Connection,
 549        address: String,
 550        user: User,
 551        zed_version: ZedVersion,
 552        impersonator: Option<User>,
 553        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 554        executor: Executor,
 555    ) -> impl Future<Output = Result<()>> {
 556        let this = self.clone();
 557        let user_id = user.id;
 558        let login = user.github_login;
 559        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
 560        if let Some(impersonator) = impersonator {
 561            span.record("impersonator", &impersonator.github_login);
 562        }
            // Subscribed here, before the returned future is first polled.
 563        let mut teardown = self.teardown.subscribe();
 564        async move {
 565            if *teardown.borrow() {
 566                return Err(anyhow!("server is tearing down"))?;
 567            }
                // The closure lets the peer create its timers on the supplied executor.
 568            let (connection_id, handle_io, mut incoming_rx) = this
 569                .peer
 570                .add_connection(connection, {
 571                    let executor = executor.clone();
 572                    move |duration| executor.sleep(duration)
 573                });
 574
 575            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 576            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 577            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 578
 579            if let Some(send_connection_id) = send_connection_id.take() {
 580                let _ = send_connection_id.send(connection_id);
 581            }
 582
                // First connection of this user: ask the client to show the contacts panel
                // and remember that the user has connected once.
 583            if !user.connected_once {
 584                this.peer.send(connection_id, proto::ShowContacts {})?;
 585                this.app_state.db.set_user_connected_once(user_id, true).await?;
 586            }
 587
                // The three lookups are awaited together; the first error aborts the setup.
 588            let (contacts, channels_for_user, channel_invites) = future::try_join3(
 589                this.app_state.db.get_contacts(user_id),
 590                this.app_state.db.get_channels_for_user(user_id),
 591                this.app_state.db.get_channel_invites_for_user(user_id),
 592            ).await?;
 593
                // Scoped so the pool lock is released before the next await.
 594            {
 595                let mut pool = this.connection_pool.lock();
 596                pool.add_connection(connection_id, user_id, user.admin, zed_version);
 597                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 598                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
 599                this.peer.send(connection_id, build_channels_update(
 600                    channels_for_user,
 601                    channel_invites
 602                ))?;
 603            }
 604
 605            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 606                this.peer.send(connection_id, incoming_call)?;
 607            }
 608
 609            let session = Session {
 610                user_id,
 611                connection_id,
 612                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 613                peer: this.peer.clone(),
 614                connection_pool: this.connection_pool.clone(),
 615                live_kit_client: this.app_state.live_kit_client.clone(),
 616                _executor: executor.clone()
 617            };
 618            update_user_contacts(user_id, &session).await?;
 619
 620            let handle_io = handle_io.fuse();
 621            futures::pin_mut!(handle_io);
 622
 623            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 624            // This prevents deadlocks when e.g., client A performs a request to client B and
 625            // client B performs a request to client A. If both clients stop processing further
 626            // messages until their respective request completes, they won't have a chance to
 627            // respond to the other client's request and cause a deadlock.
 628            //
 629            // This arrangement ensures we will attempt to process earlier messages first, but fall
 630            // back to processing messages arrived later in the spirit of making progress.
 631            let mut foreground_message_handlers = FuturesUnordered::new();
                // Each in-flight handler holds one of these 256 permits until it completes.
 632            let concurrent_handlers = Arc::new(Semaphore::new(256));
 633            loop {
                    // A permit is acquired before the next message is pulled, so no further
                    // message is read while all permits are taken.
 634                let next_message = async {
 635                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 636                    let message = incoming_rx.next().await;
 637                    (permit, message)
 638                }.fuse();
 639                futures::pin_mut!(next_message);
                    // `select_biased!` polls the branches in the order they are written, so
                    // teardown and I/O completion take priority over message handling.
 640                futures::select_biased! {
 641                    _ = teardown.changed().fuse() => return Ok(()),
 642                    result = handle_io => {
 643                        if let Err(error) = result {
 644                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 645                        }
 646                        break;
 647                    }
                        // Drives the pending foreground handlers; a finished one needs no
                        // further action here.
 648                    _ = foreground_message_handlers.next() => {}
 649                    next_message = next_message => {
 650                        let (permit, message) = next_message;
 651                        if let Some(message) = message {
 652                            let type_name = message.payload_type_name();
 653                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 654                            let span_enter = span.enter();
 655                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 656                                let is_background = message.is_background();
 657                                let handle_message = (handler)(message, session.clone());
 658                                drop(span_enter);
 659
                                    // The permit is released only once the handler has finished.
 660                                let handle_message = async move {
 661                                    handle_message.await;
 662                                    drop(permit);
 663                                }.instrument(span);
                                    // Background messages run as detached tasks on the executor;
                                    // foreground ones are driven by this loop.
 664                                if is_background {
 665                                    executor.spawn_detached(handle_message);
 666                                } else {
 667                                    foreground_message_handlers.push(handle_message);
 668                                }
 669                            } else {
 670                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 671                            }
 672                        } else {
                                // The incoming stream ended.
 673                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 674                            break;
 675                        }
 676                    }
 677                }
 678            }
 679
                // Dropping the set drops every foreground handler that has not finished yet.
 680            drop(foreground_message_handlers);
 681            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 682            if let Err(error) = connection_lost(session, teardown, executor).await {
 683                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 684            }
 685
 686            Ok(())
 687        }.instrument(span)
 688    }
 689
 690    pub async fn invite_code_redeemed(
 691        self: &Arc<Self>,
 692        inviter_id: UserId,
 693        invitee_id: UserId,
 694    ) -> Result<()> {
 695        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 696            if let Some(code) = &user.invite_code {
 697                let pool = self.connection_pool.lock();
 698                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 699                for connection_id in pool.user_connection_ids(inviter_id) {
 700                    self.peer.send(
 701                        connection_id,
 702                        proto::UpdateContacts {
 703                            contacts: vec![invitee_contact.clone()],
 704                            ..Default::default()
 705                        },
 706                    )?;
 707                    self.peer.send(
 708                        connection_id,
 709                        proto::UpdateInviteInfo {
 710                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 711                            count: user.invite_count as u32,
 712                        },
 713                    )?;
 714                }
 715            }
 716        }
 717        Ok(())
 718    }
 719
 720    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 721        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 722            if let Some(invite_code) = &user.invite_code {
 723                let pool = self.connection_pool.lock();
 724                for connection_id in pool.user_connection_ids(user_id) {
 725                    self.peer.send(
 726                        connection_id,
 727                        proto::UpdateInviteInfo {
 728                            url: format!(
 729                                "{}{}",
 730                                self.app_state.config.invite_link_prefix, invite_code
 731                            ),
 732                            count: user.invite_count as u32,
 733                        },
 734                    )?;
 735                }
 736            }
 737        }
 738        Ok(())
 739    }
 740
 741    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 742        ServerSnapshot {
 743            connection_pool: ConnectionPoolGuard {
 744                guard: self.connection_pool.lock(),
 745                _not_send: PhantomData,
 746            },
 747            peer: &self.peer,
 748        }
 749    }
 750}
 751
 752impl<'a> Deref for ConnectionPoolGuard<'a> {
 753    type Target = ConnectionPool;
 754
 755    fn deref(&self) -> &Self::Target {
 756        &*self.guard
 757    }
 758}
 759
 760impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 761    fn deref_mut(&mut self) -> &mut Self::Target {
 762        &mut *self.guard
 763    }
 764}
 765
 766impl<'a> Drop for ConnectionPoolGuard<'a> {
 767    fn drop(&mut self) {
 768        #[cfg(test)]
 769        self.check_invariants();
 770    }
 771}
 772
 773fn broadcast<F>(
 774    sender_id: Option<ConnectionId>,
 775    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 776    mut f: F,
 777) where
 778    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 779{
 780    for receiver_id in receiver_ids {
 781        if Some(receiver_id) != sender_id {
 782            if let Err(error) = f(receiver_id) {
 783                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 784            }
 785        }
 786    }
 787}
 788
 789pub struct ProtocolVersion(u32);
 790
 791impl Header for ProtocolVersion {
 792    fn name() -> &'static HeaderName {
 793        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 794        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 795    }
 796
 797    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 798    where
 799        Self: Sized,
 800        I: Iterator<Item = &'i axum::http::HeaderValue>,
 801    {
 802        let version = values
 803            .next()
 804            .ok_or_else(axum::headers::Error::invalid)?
 805            .to_str()
 806            .map_err(|_| axum::headers::Error::invalid())?
 807            .parse()
 808            .map_err(|_| axum::headers::Error::invalid())?;
 809        Ok(Self(version))
 810    }
 811
 812    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 813        values.extend([self.0.to_string().parse().unwrap()]);
 814    }
 815}
 816
 817pub struct AppVersionHeader(SemanticVersion);
 818impl Header for AppVersionHeader {
 819    fn name() -> &'static HeaderName {
 820        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 821        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 822    }
 823
 824    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 825    where
 826        Self: Sized,
 827        I: Iterator<Item = &'i axum::http::HeaderValue>,
 828    {
 829        let version = values
 830            .next()
 831            .ok_or_else(axum::headers::Error::invalid)?
 832            .to_str()
 833            .map_err(|_| axum::headers::Error::invalid())?
 834            .parse()
 835            .map_err(|_| axum::headers::Error::invalid())?;
 836        Ok(Self(version))
 837    }
 838
 839    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 840        values.extend([self.0.to_string().parse().unwrap()]);
 841    }
 842}
 843
 844pub fn routes(server: Arc<Server>) -> Router<Body> {
 845    Router::new()
 846        .route("/rpc", get(handle_websocket_request))
 847        .layer(
 848            ServiceBuilder::new()
 849                .layer(Extension(server.app_state.clone()))
 850                .layer(middleware::from_fn(auth::validate_header)),
 851        )
 852        .route("/metrics", get(handle_metrics))
 853        .layer(Extension(server))
 854}
 855
 856pub async fn handle_websocket_request(
 857    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 858    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 859    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 860    Extension(server): Extension<Arc<Server>>,
 861    Extension(user): Extension<User>,
 862    Extension(impersonator): Extension<Impersonator>,
 863    ws: WebSocketUpgrade,
 864) -> axum::response::Response {
 865    if protocol_version != rpc::PROTOCOL_VERSION {
 866        return (
 867            StatusCode::UPGRADE_REQUIRED,
 868            "client must be upgraded".to_string(),
 869        )
 870            .into_response();
 871    }
 872
 873    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 874        return (
 875            StatusCode::UPGRADE_REQUIRED,
 876            "no version header found".to_string(),
 877        )
 878            .into_response();
 879    };
 880
 881    if !version.is_supported() {
 882        return (
 883            StatusCode::UPGRADE_REQUIRED,
 884            "client must be upgraded".to_string(),
 885        )
 886            .into_response();
 887    }
 888
 889    let socket_address = socket_address.to_string();
 890    ws.on_upgrade(move |socket| {
 891        use util::ResultExt;
 892        let socket = socket
 893            .map_ok(to_tungstenite_message)
 894            .err_into()
 895            .with(|message| async move { Ok(to_axum_message(message)) });
 896        let connection = Connection::new(Box::pin(socket));
 897        async move {
 898            server
 899                .handle_connection(
 900                    connection,
 901                    socket_address,
 902                    user,
 903                    version,
 904                    impersonator.0,
 905                    None,
 906                    Executor::Production,
 907                )
 908                .await
 909                .log_err();
 910        }
 911    })
 912}
 913
 914pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 915    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 916    let connections_metric = CONNECTIONS_METRIC
 917        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 918
 919    let connections = server
 920        .connection_pool
 921        .lock()
 922        .connections()
 923        .filter(|connection| !connection.admin)
 924        .count();
 925    connections_metric.set(connections as _);
 926
 927    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 928    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 929        register_int_gauge!(
 930            "shared_projects",
 931            "number of open projects with one or more guests"
 932        )
 933        .unwrap()
 934    });
 935
 936    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 937    shared_projects_metric.set(shared_projects as _);
 938
 939    let encoder = prometheus::TextEncoder::new();
 940    let metric_families = prometheus::gather();
 941    let encoded_metrics = encoder
 942        .encode_to_string(&metric_families)
 943        .map_err(|err| anyhow!("{}", err))?;
 944    Ok(encoded_metrics)
 945}
 946
 947#[instrument(err, skip(executor))]
 948async fn connection_lost(
 949    session: Session,
 950    mut teardown: watch::Receiver<bool>,
 951    executor: Executor,
 952) -> Result<()> {
 953    session.peer.disconnect(session.connection_id);
 954    session
 955        .connection_pool()
 956        .await
 957        .remove_connection(session.connection_id)?;
 958
 959    session
 960        .db()
 961        .await
 962        .connection_lost(session.connection_id)
 963        .await
 964        .trace_err();
 965
 966    futures::select_biased! {
 967        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 968            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 969            leave_room_for_session(&session).await.trace_err();
 970            leave_channel_buffers_for_session(&session)
 971                .await
 972                .trace_err();
 973
 974            if !session
 975                .connection_pool()
 976                .await
 977                .is_user_online(session.user_id)
 978            {
 979                let db = session.db().await;
 980                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 981                    room_updated(&room, &session.peer);
 982                }
 983            }
 984
 985            update_user_contacts(session.user_id, &session).await?;
 986        }
 987        _ = teardown.changed().fuse() => {}
 988    }
 989
 990    Ok(())
 991}
 992
 993/// Acknowledges a ping from a client, used to keep the connection alive.
 994async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 995    response.send(proto::Ack {})?;
 996    Ok(())
 997}
 998
 999/// Creates a new room for calling (outside of channels)
1000async fn create_room(
1001    _request: proto::CreateRoom,
1002    response: Response<proto::CreateRoom>,
1003    session: Session,
1004) -> Result<()> {
1005    let live_kit_room = nanoid::nanoid!(30);
1006
1007    let live_kit_connection_info = {
1008        let live_kit_room = live_kit_room.clone();
1009        let live_kit = session.live_kit_client.as_ref();
1010
1011        util::async_maybe!({
1012            let live_kit = live_kit?;
1013
1014            let token = live_kit
1015                .room_token(&live_kit_room, &session.user_id.to_string())
1016                .trace_err()?;
1017
1018            Some(proto::LiveKitConnectionInfo {
1019                server_url: live_kit.url().into(),
1020                token,
1021                can_publish: true,
1022            })
1023        })
1024    }
1025    .await;
1026
1027    let room = session
1028        .db()
1029        .await
1030        .create_room(session.user_id, session.connection_id, &live_kit_room)
1031        .await?;
1032
1033    response.send(proto::CreateRoomResponse {
1034        room: Some(room.clone()),
1035        live_kit_connection_info,
1036    })?;
1037
1038    update_user_contacts(session.user_id, &session).await?;
1039    Ok(())
1040}
1041
1042/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1043async fn join_room(
1044    request: proto::JoinRoom,
1045    response: Response<proto::JoinRoom>,
1046    session: Session,
1047) -> Result<()> {
1048    let room_id = RoomId::from_proto(request.id);
1049
1050    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1051
1052    if let Some(channel_id) = channel_id {
1053        return join_channel_internal(channel_id, Box::new(response), session).await;
1054    }
1055
1056    let joined_room = {
1057        let room = session
1058            .db()
1059            .await
1060            .join_room(room_id, session.user_id, session.connection_id)
1061            .await?;
1062        room_updated(&room.room, &session.peer);
1063        room.into_inner()
1064    };
1065
1066    for connection_id in session
1067        .connection_pool()
1068        .await
1069        .user_connection_ids(session.user_id)
1070    {
1071        session
1072            .peer
1073            .send(
1074                connection_id,
1075                proto::CallCanceled {
1076                    room_id: room_id.to_proto(),
1077                },
1078            )
1079            .trace_err();
1080    }
1081
1082    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1083        if let Some(token) = live_kit
1084            .room_token(
1085                &joined_room.room.live_kit_room,
1086                &session.user_id.to_string(),
1087            )
1088            .trace_err()
1089        {
1090            Some(proto::LiveKitConnectionInfo {
1091                server_url: live_kit.url().into(),
1092                token,
1093                can_publish: true,
1094            })
1095        } else {
1096            None
1097        }
1098    } else {
1099        None
1100    };
1101
1102    response.send(proto::JoinRoomResponse {
1103        room: Some(joined_room.room),
1104        channel_id: None,
1105        live_kit_connection_info,
1106    })?;
1107
1108    update_user_contacts(session.user_id, &session).await?;
1109    Ok(())
1110}
1111
1112/// Rejoin room is used to reconnect to a room after connection errors.
1113async fn rejoin_room(
1114    request: proto::RejoinRoom,
1115    response: Response<proto::RejoinRoom>,
1116    session: Session,
1117) -> Result<()> {
1118    let room;
1119    let channel_id;
1120    let channel_members;
1121    {
1122        let mut rejoined_room = session
1123            .db()
1124            .await
1125            .rejoin_room(request, session.user_id, session.connection_id)
1126            .await?;
1127
1128        response.send(proto::RejoinRoomResponse {
1129            room: Some(rejoined_room.room.clone()),
1130            reshared_projects: rejoined_room
1131                .reshared_projects
1132                .iter()
1133                .map(|project| proto::ResharedProject {
1134                    id: project.id.to_proto(),
1135                    collaborators: project
1136                        .collaborators
1137                        .iter()
1138                        .map(|collaborator| collaborator.to_proto())
1139                        .collect(),
1140                })
1141                .collect(),
1142            rejoined_projects: rejoined_room
1143                .rejoined_projects
1144                .iter()
1145                .map(|rejoined_project| proto::RejoinedProject {
1146                    id: rejoined_project.id.to_proto(),
1147                    worktrees: rejoined_project
1148                        .worktrees
1149                        .iter()
1150                        .map(|worktree| proto::WorktreeMetadata {
1151                            id: worktree.id,
1152                            root_name: worktree.root_name.clone(),
1153                            visible: worktree.visible,
1154                            abs_path: worktree.abs_path.clone(),
1155                        })
1156                        .collect(),
1157                    collaborators: rejoined_project
1158                        .collaborators
1159                        .iter()
1160                        .map(|collaborator| collaborator.to_proto())
1161                        .collect(),
1162                    language_servers: rejoined_project.language_servers.clone(),
1163                })
1164                .collect(),
1165        })?;
1166        room_updated(&rejoined_room.room, &session.peer);
1167
1168        for project in &rejoined_room.reshared_projects {
1169            for collaborator in &project.collaborators {
1170                session
1171                    .peer
1172                    .send(
1173                        collaborator.connection_id,
1174                        proto::UpdateProjectCollaborator {
1175                            project_id: project.id.to_proto(),
1176                            old_peer_id: Some(project.old_connection_id.into()),
1177                            new_peer_id: Some(session.connection_id.into()),
1178                        },
1179                    )
1180                    .trace_err();
1181            }
1182
1183            broadcast(
1184                Some(session.connection_id),
1185                project
1186                    .collaborators
1187                    .iter()
1188                    .map(|collaborator| collaborator.connection_id),
1189                |connection_id| {
1190                    session.peer.forward_send(
1191                        session.connection_id,
1192                        connection_id,
1193                        proto::UpdateProject {
1194                            project_id: project.id.to_proto(),
1195                            worktrees: project.worktrees.clone(),
1196                        },
1197                    )
1198                },
1199            );
1200        }
1201
1202        for project in &rejoined_room.rejoined_projects {
1203            for collaborator in &project.collaborators {
1204                session
1205                    .peer
1206                    .send(
1207                        collaborator.connection_id,
1208                        proto::UpdateProjectCollaborator {
1209                            project_id: project.id.to_proto(),
1210                            old_peer_id: Some(project.old_connection_id.into()),
1211                            new_peer_id: Some(session.connection_id.into()),
1212                        },
1213                    )
1214                    .trace_err();
1215            }
1216        }
1217
1218        for project in &mut rejoined_room.rejoined_projects {
1219            for worktree in mem::take(&mut project.worktrees) {
1220                #[cfg(any(test, feature = "test-support"))]
1221                const MAX_CHUNK_SIZE: usize = 2;
1222                #[cfg(not(any(test, feature = "test-support")))]
1223                const MAX_CHUNK_SIZE: usize = 256;
1224
1225                // Stream this worktree's entries.
1226                let message = proto::UpdateWorktree {
1227                    project_id: project.id.to_proto(),
1228                    worktree_id: worktree.id,
1229                    abs_path: worktree.abs_path.clone(),
1230                    root_name: worktree.root_name,
1231                    updated_entries: worktree.updated_entries,
1232                    removed_entries: worktree.removed_entries,
1233                    scan_id: worktree.scan_id,
1234                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1235                    updated_repositories: worktree.updated_repositories,
1236                    removed_repositories: worktree.removed_repositories,
1237                };
1238                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1239                    session.peer.send(session.connection_id, update.clone())?;
1240                }
1241
1242                // Stream this worktree's diagnostics.
1243                for summary in worktree.diagnostic_summaries {
1244                    session.peer.send(
1245                        session.connection_id,
1246                        proto::UpdateDiagnosticSummary {
1247                            project_id: project.id.to_proto(),
1248                            worktree_id: worktree.id,
1249                            summary: Some(summary),
1250                        },
1251                    )?;
1252                }
1253
1254                for settings_file in worktree.settings_files {
1255                    session.peer.send(
1256                        session.connection_id,
1257                        proto::UpdateWorktreeSettings {
1258                            project_id: project.id.to_proto(),
1259                            worktree_id: worktree.id,
1260                            path: settings_file.path,
1261                            content: Some(settings_file.content),
1262                        },
1263                    )?;
1264                }
1265            }
1266
1267            for language_server in &project.language_servers {
1268                session.peer.send(
1269                    session.connection_id,
1270                    proto::UpdateLanguageServer {
1271                        project_id: project.id.to_proto(),
1272                        language_server_id: language_server.id,
1273                        variant: Some(
1274                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1275                                proto::LspDiskBasedDiagnosticsUpdated {},
1276                            ),
1277                        ),
1278                    },
1279                )?;
1280            }
1281        }
1282
1283        let rejoined_room = rejoined_room.into_inner();
1284
1285        room = rejoined_room.room;
1286        channel_id = rejoined_room.channel_id;
1287        channel_members = rejoined_room.channel_members;
1288    }
1289
1290    if let Some(channel_id) = channel_id {
1291        channel_updated(
1292            channel_id,
1293            &room,
1294            &channel_members,
1295            &session.peer,
1296            &*session.connection_pool().await,
1297        );
1298    }
1299
1300    update_user_contacts(session.user_id, &session).await?;
1301    Ok(())
1302}
1303
1304/// leave room disconnects from the room.
1305async fn leave_room(
1306    _: proto::LeaveRoom,
1307    response: Response<proto::LeaveRoom>,
1308    session: Session,
1309) -> Result<()> {
1310    leave_room_for_session(&session).await?;
1311    response.send(proto::Ack {})?;
1312    Ok(())
1313}
1314
1315/// Updates the permissions of someone else in the room.
1316async fn set_room_participant_role(
1317    request: proto::SetRoomParticipantRole,
1318    response: Response<proto::SetRoomParticipantRole>,
1319    session: Session,
1320) -> Result<()> {
1321    let user_id = UserId::from_proto(request.user_id);
1322    let role = ChannelRole::from(request.role());
1323
1324    if role == ChannelRole::Talker {
1325        let pool = session.connection_pool().await;
1326
1327        for connection in pool.user_connections(user_id) {
1328            if !connection.zed_version.supports_talker_role() {
1329                Err(anyhow!(
1330                    "This user is on zed {} which does not support unmute",
1331                    connection.zed_version
1332                ))?;
1333            }
1334        }
1335    }
1336
1337    let (live_kit_room, can_publish) = {
1338        let room = session
1339            .db()
1340            .await
1341            .set_room_participant_role(
1342                session.user_id,
1343                RoomId::from_proto(request.room_id),
1344                user_id,
1345                role,
1346            )
1347            .await?;
1348
1349        let live_kit_room = room.live_kit_room.clone();
1350        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1351        room_updated(&room, &session.peer);
1352        (live_kit_room, can_publish)
1353    };
1354
1355    if let Some(live_kit) = session.live_kit_client.as_ref() {
1356        live_kit
1357            .update_participant(
1358                live_kit_room.clone(),
1359                request.user_id.to_string(),
1360                live_kit_server::proto::ParticipantPermission {
1361                    can_subscribe: true,
1362                    can_publish,
1363                    can_publish_data: can_publish,
1364                    hidden: false,
1365                    recorder: false,
1366                },
1367            )
1368            .await
1369            .trace_err();
1370    }
1371
1372    response.send(proto::Ack {})?;
1373    Ok(())
1374}
1375
1376/// Call someone else into the current room
1377async fn call(
1378    request: proto::Call,
1379    response: Response<proto::Call>,
1380    session: Session,
1381) -> Result<()> {
1382    let room_id = RoomId::from_proto(request.room_id);
1383    let calling_user_id = session.user_id;
1384    let calling_connection_id = session.connection_id;
1385    let called_user_id = UserId::from_proto(request.called_user_id);
1386    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1387    if !session
1388        .db()
1389        .await
1390        .has_contact(calling_user_id, called_user_id)
1391        .await?
1392    {
1393        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1394    }
1395
1396    let incoming_call = {
1397        let (room, incoming_call) = &mut *session
1398            .db()
1399            .await
1400            .call(
1401                room_id,
1402                calling_user_id,
1403                calling_connection_id,
1404                called_user_id,
1405                initial_project_id,
1406            )
1407            .await?;
1408        room_updated(&room, &session.peer);
1409        mem::take(incoming_call)
1410    };
1411    update_user_contacts(called_user_id, &session).await?;
1412
1413    let mut calls = session
1414        .connection_pool()
1415        .await
1416        .user_connection_ids(called_user_id)
1417        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1418        .collect::<FuturesUnordered<_>>();
1419
1420    while let Some(call_response) = calls.next().await {
1421        match call_response.as_ref() {
1422            Ok(_) => {
1423                response.send(proto::Ack {})?;
1424                return Ok(());
1425            }
1426            Err(_) => {
1427                call_response.trace_err();
1428            }
1429        }
1430    }
1431
1432    {
1433        let room = session
1434            .db()
1435            .await
1436            .call_failed(room_id, called_user_id)
1437            .await?;
1438        room_updated(&room, &session.peer);
1439    }
1440    update_user_contacts(called_user_id, &session).await?;
1441
1442    Err(anyhow!("failed to ring user"))?
1443}
1444
1445/// Cancel an outgoing call.
1446async fn cancel_call(
1447    request: proto::CancelCall,
1448    response: Response<proto::CancelCall>,
1449    session: Session,
1450) -> Result<()> {
1451    let called_user_id = UserId::from_proto(request.called_user_id);
1452    let room_id = RoomId::from_proto(request.room_id);
1453    {
1454        let room = session
1455            .db()
1456            .await
1457            .cancel_call(room_id, session.connection_id, called_user_id)
1458            .await?;
1459        room_updated(&room, &session.peer);
1460    }
1461
1462    for connection_id in session
1463        .connection_pool()
1464        .await
1465        .user_connection_ids(called_user_id)
1466    {
1467        session
1468            .peer
1469            .send(
1470                connection_id,
1471                proto::CallCanceled {
1472                    room_id: room_id.to_proto(),
1473                },
1474            )
1475            .trace_err();
1476    }
1477    response.send(proto::Ack {})?;
1478
1479    update_user_contacts(called_user_id, &session).await?;
1480    Ok(())
1481}
1482
1483/// Decline an incoming call.
1484async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1485    let room_id = RoomId::from_proto(message.room_id);
1486    {
1487        let room = session
1488            .db()
1489            .await
1490            .decline_call(Some(room_id), session.user_id)
1491            .await?
1492            .ok_or_else(|| anyhow!("failed to decline call"))?;
1493        room_updated(&room, &session.peer);
1494    }
1495
1496    for connection_id in session
1497        .connection_pool()
1498        .await
1499        .user_connection_ids(session.user_id)
1500    {
1501        session
1502            .peer
1503            .send(
1504                connection_id,
1505                proto::CallCanceled {
1506                    room_id: room_id.to_proto(),
1507                },
1508            )
1509            .trace_err();
1510    }
1511    update_user_contacts(session.user_id, &session).await?;
1512    Ok(())
1513}
1514
1515/// Updates other participants in the room with your current location.
1516async fn update_participant_location(
1517    request: proto::UpdateParticipantLocation,
1518    response: Response<proto::UpdateParticipantLocation>,
1519    session: Session,
1520) -> Result<()> {
1521    let room_id = RoomId::from_proto(request.room_id);
1522    let location = request
1523        .location
1524        .ok_or_else(|| anyhow!("invalid location"))?;
1525
1526    let db = session.db().await;
1527    let room = db
1528        .update_room_participant_location(room_id, session.connection_id, location)
1529        .await?;
1530
1531    room_updated(&room, &session.peer);
1532    response.send(proto::Ack {})?;
1533    Ok(())
1534}
1535
1536/// Share a project into the room.
1537async fn share_project(
1538    request: proto::ShareProject,
1539    response: Response<proto::ShareProject>,
1540    session: Session,
1541) -> Result<()> {
1542    let (project_id, room) = &*session
1543        .db()
1544        .await
1545        .share_project(
1546            RoomId::from_proto(request.room_id),
1547            session.connection_id,
1548            &request.worktrees,
1549        )
1550        .await?;
1551    response.send(proto::ShareProjectResponse {
1552        project_id: project_id.to_proto(),
1553    })?;
1554    room_updated(&room, &session.peer);
1555
1556    Ok(())
1557}
1558
1559/// Unshare a project from the room.
1560async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1561    let project_id = ProjectId::from_proto(message.project_id);
1562
1563    let (room, guest_connection_ids) = &*session
1564        .db()
1565        .await
1566        .unshare_project(project_id, session.connection_id)
1567        .await?;
1568
1569    broadcast(
1570        Some(session.connection_id),
1571        guest_connection_ids.iter().copied(),
1572        |conn_id| session.peer.send(conn_id, message.clone()),
1573    );
1574    room_updated(&room, &session.peer);
1575
1576    Ok(())
1577}
1578
1579/// Join someone elses shared project.
1580async fn join_project(
1581    request: proto::JoinProject,
1582    response: Response<proto::JoinProject>,
1583    session: Session,
1584) -> Result<()> {
1585    let project_id = ProjectId::from_proto(request.project_id);
1586    let guest_user_id = session.user_id;
1587
1588    tracing::info!(%project_id, "join project");
1589
1590    let (project, replica_id) = &mut *session
1591        .db()
1592        .await
1593        .join_project(project_id, session.connection_id)
1594        .await?;
1595
1596    let collaborators = project
1597        .collaborators
1598        .iter()
1599        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1600        .map(|collaborator| collaborator.to_proto())
1601        .collect::<Vec<_>>();
1602
1603    let worktrees = project
1604        .worktrees
1605        .iter()
1606        .map(|(id, worktree)| proto::WorktreeMetadata {
1607            id: *id,
1608            root_name: worktree.root_name.clone(),
1609            visible: worktree.visible,
1610            abs_path: worktree.abs_path.clone(),
1611        })
1612        .collect::<Vec<_>>();
1613
1614    for collaborator in &collaborators {
1615        session
1616            .peer
1617            .send(
1618                collaborator.peer_id.unwrap().into(),
1619                proto::AddProjectCollaborator {
1620                    project_id: project_id.to_proto(),
1621                    collaborator: Some(proto::Collaborator {
1622                        peer_id: Some(session.connection_id.into()),
1623                        replica_id: replica_id.0 as u32,
1624                        user_id: guest_user_id.to_proto(),
1625                    }),
1626                },
1627            )
1628            .trace_err();
1629    }
1630
1631    // First, we send the metadata associated with each worktree.
1632    response.send(proto::JoinProjectResponse {
1633        worktrees: worktrees.clone(),
1634        replica_id: replica_id.0 as u32,
1635        collaborators: collaborators.clone(),
1636        language_servers: project.language_servers.clone(),
1637    })?;
1638
1639    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1640        #[cfg(any(test, feature = "test-support"))]
1641        const MAX_CHUNK_SIZE: usize = 2;
1642        #[cfg(not(any(test, feature = "test-support")))]
1643        const MAX_CHUNK_SIZE: usize = 256;
1644
1645        // Stream this worktree's entries.
1646        let message = proto::UpdateWorktree {
1647            project_id: project_id.to_proto(),
1648            worktree_id,
1649            abs_path: worktree.abs_path.clone(),
1650            root_name: worktree.root_name,
1651            updated_entries: worktree.entries,
1652            removed_entries: Default::default(),
1653            scan_id: worktree.scan_id,
1654            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1655            updated_repositories: worktree.repository_entries.into_values().collect(),
1656            removed_repositories: Default::default(),
1657        };
1658        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1659            session.peer.send(session.connection_id, update.clone())?;
1660        }
1661
1662        // Stream this worktree's diagnostics.
1663        for summary in worktree.diagnostic_summaries {
1664            session.peer.send(
1665                session.connection_id,
1666                proto::UpdateDiagnosticSummary {
1667                    project_id: project_id.to_proto(),
1668                    worktree_id: worktree.id,
1669                    summary: Some(summary),
1670                },
1671            )?;
1672        }
1673
1674        for settings_file in worktree.settings_files {
1675            session.peer.send(
1676                session.connection_id,
1677                proto::UpdateWorktreeSettings {
1678                    project_id: project_id.to_proto(),
1679                    worktree_id: worktree.id,
1680                    path: settings_file.path,
1681                    content: Some(settings_file.content),
1682                },
1683            )?;
1684        }
1685    }
1686
1687    for language_server in &project.language_servers {
1688        session.peer.send(
1689            session.connection_id,
1690            proto::UpdateLanguageServer {
1691                project_id: project_id.to_proto(),
1692                language_server_id: language_server.id,
1693                variant: Some(
1694                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1695                        proto::LspDiskBasedDiagnosticsUpdated {},
1696                    ),
1697                ),
1698            },
1699        )?;
1700    }
1701
1702    Ok(())
1703}
1704
1705/// Leave someone elses shared project.
1706async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1707    let sender_id = session.connection_id;
1708    let project_id = ProjectId::from_proto(request.project_id);
1709
1710    let (room, project) = &*session
1711        .db()
1712        .await
1713        .leave_project(project_id, sender_id)
1714        .await?;
1715    tracing::info!(
1716        %project_id,
1717        host_user_id = %project.host_user_id,
1718        host_connection_id = ?project.host_connection_id,
1719        "leave project"
1720    );
1721
1722    project_left(&project, &session);
1723    room_updated(&room, &session.peer);
1724
1725    Ok(())
1726}
1727
1728/// Updates other participants with changes to the project
1729async fn update_project(
1730    request: proto::UpdateProject,
1731    response: Response<proto::UpdateProject>,
1732    session: Session,
1733) -> Result<()> {
1734    let project_id = ProjectId::from_proto(request.project_id);
1735    let (room, guest_connection_ids) = &*session
1736        .db()
1737        .await
1738        .update_project(project_id, session.connection_id, &request.worktrees)
1739        .await?;
1740    broadcast(
1741        Some(session.connection_id),
1742        guest_connection_ids.iter().copied(),
1743        |connection_id| {
1744            session
1745                .peer
1746                .forward_send(session.connection_id, connection_id, request.clone())
1747        },
1748    );
1749    room_updated(&room, &session.peer);
1750    response.send(proto::Ack {})?;
1751
1752    Ok(())
1753}
1754
1755/// Updates other participants with changes to the worktree
1756async fn update_worktree(
1757    request: proto::UpdateWorktree,
1758    response: Response<proto::UpdateWorktree>,
1759    session: Session,
1760) -> Result<()> {
1761    let guest_connection_ids = session
1762        .db()
1763        .await
1764        .update_worktree(&request, session.connection_id)
1765        .await?;
1766
1767    broadcast(
1768        Some(session.connection_id),
1769        guest_connection_ids.iter().copied(),
1770        |connection_id| {
1771            session
1772                .peer
1773                .forward_send(session.connection_id, connection_id, request.clone())
1774        },
1775    );
1776    response.send(proto::Ack {})?;
1777    Ok(())
1778}
1779
1780/// Updates other participants with changes to the diagnostics
1781async fn update_diagnostic_summary(
1782    message: proto::UpdateDiagnosticSummary,
1783    session: Session,
1784) -> Result<()> {
1785    let guest_connection_ids = session
1786        .db()
1787        .await
1788        .update_diagnostic_summary(&message, session.connection_id)
1789        .await?;
1790
1791    broadcast(
1792        Some(session.connection_id),
1793        guest_connection_ids.iter().copied(),
1794        |connection_id| {
1795            session
1796                .peer
1797                .forward_send(session.connection_id, connection_id, message.clone())
1798        },
1799    );
1800
1801    Ok(())
1802}
1803
1804/// Updates other participants with changes to the worktree settings
1805async fn update_worktree_settings(
1806    message: proto::UpdateWorktreeSettings,
1807    session: Session,
1808) -> Result<()> {
1809    let guest_connection_ids = session
1810        .db()
1811        .await
1812        .update_worktree_settings(&message, session.connection_id)
1813        .await?;
1814
1815    broadcast(
1816        Some(session.connection_id),
1817        guest_connection_ids.iter().copied(),
1818        |connection_id| {
1819            session
1820                .peer
1821                .forward_send(session.connection_id, connection_id, message.clone())
1822        },
1823    );
1824
1825    Ok(())
1826}
1827
1828/// Notify other participants that a  language server has started.
1829async fn start_language_server(
1830    request: proto::StartLanguageServer,
1831    session: Session,
1832) -> Result<()> {
1833    let guest_connection_ids = session
1834        .db()
1835        .await
1836        .start_language_server(&request, session.connection_id)
1837        .await?;
1838
1839    broadcast(
1840        Some(session.connection_id),
1841        guest_connection_ids.iter().copied(),
1842        |connection_id| {
1843            session
1844                .peer
1845                .forward_send(session.connection_id, connection_id, request.clone())
1846        },
1847    );
1848    Ok(())
1849}
1850
1851/// Notify other participants that a language server has changed.
1852async fn update_language_server(
1853    request: proto::UpdateLanguageServer,
1854    session: Session,
1855) -> Result<()> {
1856    let project_id = ProjectId::from_proto(request.project_id);
1857    let project_connection_ids = session
1858        .db()
1859        .await
1860        .project_connection_ids(project_id, session.connection_id)
1861        .await?;
1862    broadcast(
1863        Some(session.connection_id),
1864        project_connection_ids.iter().copied(),
1865        |connection_id| {
1866            session
1867                .peer
1868                .forward_send(session.connection_id, connection_id, request.clone())
1869        },
1870    );
1871    Ok(())
1872}
1873
1874/// forward a project request to the host. These requests should be read only
1875/// as guests are allowed to send them.
1876async fn forward_read_only_project_request<T>(
1877    request: T,
1878    response: Response<T>,
1879    session: Session,
1880) -> Result<()>
1881where
1882    T: EntityMessage + RequestMessage,
1883{
1884    let project_id = ProjectId::from_proto(request.remote_entity_id());
1885    let host_connection_id = session
1886        .db()
1887        .await
1888        .host_for_read_only_project_request(project_id, session.connection_id)
1889        .await?;
1890    let payload = session
1891        .peer
1892        .forward_request(session.connection_id, host_connection_id, request)
1893        .await?;
1894    response.send(payload)?;
1895    Ok(())
1896}
1897
1898/// forward a project request to the host. These requests are disallowed
1899/// for guests.
1900async fn forward_mutating_project_request<T>(
1901    request: T,
1902    response: Response<T>,
1903    session: Session,
1904) -> Result<()>
1905where
1906    T: EntityMessage + RequestMessage,
1907{
1908    let project_id = ProjectId::from_proto(request.remote_entity_id());
1909    let host_connection_id = session
1910        .db()
1911        .await
1912        .host_for_mutating_project_request(project_id, session.connection_id)
1913        .await?;
1914    let payload = session
1915        .peer
1916        .forward_request(session.connection_id, host_connection_id, request)
1917        .await?;
1918    response.send(payload)?;
1919    Ok(())
1920}
1921
1922/// Notify other participants that a new buffer has been created
1923async fn create_buffer_for_peer(
1924    request: proto::CreateBufferForPeer,
1925    session: Session,
1926) -> Result<()> {
1927    session
1928        .db()
1929        .await
1930        .check_user_is_project_host(
1931            ProjectId::from_proto(request.project_id),
1932            session.connection_id,
1933        )
1934        .await?;
1935    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1936    session
1937        .peer
1938        .forward_send(session.connection_id, peer_id.into(), request)?;
1939    Ok(())
1940}
1941
1942/// Notify other participants that a buffer has been updated. This is
1943/// allowed for guests as long as the update is limited to selections.
1944async fn update_buffer(
1945    request: proto::UpdateBuffer,
1946    response: Response<proto::UpdateBuffer>,
1947    session: Session,
1948) -> Result<()> {
1949    let project_id = ProjectId::from_proto(request.project_id);
1950    let mut guest_connection_ids;
1951    let mut host_connection_id = None;
1952
1953    let mut requires_write_permission = false;
1954
1955    for op in request.operations.iter() {
1956        match op.variant {
1957            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1958            Some(_) => requires_write_permission = true,
1959        }
1960    }
1961
1962    {
1963        let collaborators = session
1964            .db()
1965            .await
1966            .project_collaborators_for_buffer_update(
1967                project_id,
1968                session.connection_id,
1969                requires_write_permission,
1970            )
1971            .await?;
1972        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1973        for collaborator in collaborators.iter() {
1974            if collaborator.is_host {
1975                host_connection_id = Some(collaborator.connection_id);
1976            } else {
1977                guest_connection_ids.push(collaborator.connection_id);
1978            }
1979        }
1980    }
1981    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1982
1983    broadcast(
1984        Some(session.connection_id),
1985        guest_connection_ids,
1986        |connection_id| {
1987            session
1988                .peer
1989                .forward_send(session.connection_id, connection_id, request.clone())
1990        },
1991    );
1992    if host_connection_id != session.connection_id {
1993        session
1994            .peer
1995            .forward_request(session.connection_id, host_connection_id, request.clone())
1996            .await?;
1997    }
1998
1999    response.send(proto::Ack {})?;
2000    Ok(())
2001}
2002
2003/// Notify other participants that a project has been updated.
2004async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2005    request: T,
2006    session: Session,
2007) -> Result<()> {
2008    let project_id = ProjectId::from_proto(request.remote_entity_id());
2009    let project_connection_ids = session
2010        .db()
2011        .await
2012        .project_connection_ids(project_id, session.connection_id)
2013        .await?;
2014
2015    broadcast(
2016        Some(session.connection_id),
2017        project_connection_ids.iter().copied(),
2018        |connection_id| {
2019            session
2020                .peer
2021                .forward_send(session.connection_id, connection_id, request.clone())
2022        },
2023    );
2024    Ok(())
2025}
2026
2027/// Start following another user in a call.
2028async fn follow(
2029    request: proto::Follow,
2030    response: Response<proto::Follow>,
2031    session: Session,
2032) -> Result<()> {
2033    let room_id = RoomId::from_proto(request.room_id);
2034    let project_id = request.project_id.map(ProjectId::from_proto);
2035    let leader_id = request
2036        .leader_id
2037        .ok_or_else(|| anyhow!("invalid leader id"))?
2038        .into();
2039    let follower_id = session.connection_id;
2040
2041    session
2042        .db()
2043        .await
2044        .check_room_participants(room_id, leader_id, session.connection_id)
2045        .await?;
2046
2047    let response_payload = session
2048        .peer
2049        .forward_request(session.connection_id, leader_id, request)
2050        .await?;
2051    response.send(response_payload)?;
2052
2053    if let Some(project_id) = project_id {
2054        let room = session
2055            .db()
2056            .await
2057            .follow(room_id, project_id, leader_id, follower_id)
2058            .await?;
2059        room_updated(&room, &session.peer);
2060    }
2061
2062    Ok(())
2063}
2064
2065/// Stop following another user in a call.
2066async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2067    let room_id = RoomId::from_proto(request.room_id);
2068    let project_id = request.project_id.map(ProjectId::from_proto);
2069    let leader_id = request
2070        .leader_id
2071        .ok_or_else(|| anyhow!("invalid leader id"))?
2072        .into();
2073    let follower_id = session.connection_id;
2074
2075    session
2076        .db()
2077        .await
2078        .check_room_participants(room_id, leader_id, session.connection_id)
2079        .await?;
2080
2081    session
2082        .peer
2083        .forward_send(session.connection_id, leader_id, request)?;
2084
2085    if let Some(project_id) = project_id {
2086        let room = session
2087            .db()
2088            .await
2089            .unfollow(room_id, project_id, leader_id, follower_id)
2090            .await?;
2091        room_updated(&room, &session.peer);
2092    }
2093
2094    Ok(())
2095}
2096
2097/// Notify everyone following you of your current location.
2098async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2099    let room_id = RoomId::from_proto(request.room_id);
2100    let database = session.db.lock().await;
2101
2102    let connection_ids = if let Some(project_id) = request.project_id {
2103        let project_id = ProjectId::from_proto(project_id);
2104        database
2105            .project_connection_ids(project_id, session.connection_id)
2106            .await?
2107    } else {
2108        database
2109            .room_connection_ids(room_id, session.connection_id)
2110            .await?
2111    };
2112
2113    // For now, don't send view update messages back to that view's current leader.
2114    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2115        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2116        _ => None,
2117    });
2118
2119    for connection_id in connection_ids.iter().cloned() {
2120        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2121            session
2122                .peer
2123                .forward_send(session.connection_id, connection_id, request.clone())?;
2124        }
2125    }
2126    Ok(())
2127}
2128
2129/// Get public data about users.
2130async fn get_users(
2131    request: proto::GetUsers,
2132    response: Response<proto::GetUsers>,
2133    session: Session,
2134) -> Result<()> {
2135    let user_ids = request
2136        .user_ids
2137        .into_iter()
2138        .map(UserId::from_proto)
2139        .collect();
2140    let users = session
2141        .db()
2142        .await
2143        .get_users_by_ids(user_ids)
2144        .await?
2145        .into_iter()
2146        .map(|user| proto::User {
2147            id: user.id.to_proto(),
2148            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2149            github_login: user.github_login,
2150        })
2151        .collect();
2152    response.send(proto::UsersResponse { users })?;
2153    Ok(())
2154}
2155
2156/// Search for users (to invite) buy Github login
2157async fn fuzzy_search_users(
2158    request: proto::FuzzySearchUsers,
2159    response: Response<proto::FuzzySearchUsers>,
2160    session: Session,
2161) -> Result<()> {
2162    let query = request.query;
2163    let users = match query.len() {
2164        0 => vec![],
2165        1 | 2 => session
2166            .db()
2167            .await
2168            .get_user_by_github_login(&query)
2169            .await?
2170            .into_iter()
2171            .collect(),
2172        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2173    };
2174    let users = users
2175        .into_iter()
2176        .filter(|user| user.id != session.user_id)
2177        .map(|user| proto::User {
2178            id: user.id.to_proto(),
2179            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2180            github_login: user.github_login,
2181        })
2182        .collect();
2183    response.send(proto::UsersResponse { users })?;
2184    Ok(())
2185}
2186
2187/// Send a contact request to another user.
2188async fn request_contact(
2189    request: proto::RequestContact,
2190    response: Response<proto::RequestContact>,
2191    session: Session,
2192) -> Result<()> {
2193    let requester_id = session.user_id;
2194    let responder_id = UserId::from_proto(request.responder_id);
2195    if requester_id == responder_id {
2196        return Err(anyhow!("cannot add yourself as a contact"))?;
2197    }
2198
2199    let notifications = session
2200        .db()
2201        .await
2202        .send_contact_request(requester_id, responder_id)
2203        .await?;
2204
2205    // Update outgoing contact requests of requester
2206    let mut update = proto::UpdateContacts::default();
2207    update.outgoing_requests.push(responder_id.to_proto());
2208    for connection_id in session
2209        .connection_pool()
2210        .await
2211        .user_connection_ids(requester_id)
2212    {
2213        session.peer.send(connection_id, update.clone())?;
2214    }
2215
2216    // Update incoming contact requests of responder
2217    let mut update = proto::UpdateContacts::default();
2218    update
2219        .incoming_requests
2220        .push(proto::IncomingContactRequest {
2221            requester_id: requester_id.to_proto(),
2222        });
2223    let connection_pool = session.connection_pool().await;
2224    for connection_id in connection_pool.user_connection_ids(responder_id) {
2225        session.peer.send(connection_id, update.clone())?;
2226    }
2227
2228    send_notifications(&*connection_pool, &session.peer, notifications);
2229
2230    response.send(proto::Ack {})?;
2231    Ok(())
2232}
2233
2234/// Accept or decline a contact request
2235async fn respond_to_contact_request(
2236    request: proto::RespondToContactRequest,
2237    response: Response<proto::RespondToContactRequest>,
2238    session: Session,
2239) -> Result<()> {
2240    let responder_id = session.user_id;
2241    let requester_id = UserId::from_proto(request.requester_id);
2242    let db = session.db().await;
2243    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2244        db.dismiss_contact_notification(responder_id, requester_id)
2245            .await?;
2246    } else {
2247        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2248
2249        let notifications = db
2250            .respond_to_contact_request(responder_id, requester_id, accept)
2251            .await?;
2252        let requester_busy = db.is_user_busy(requester_id).await?;
2253        let responder_busy = db.is_user_busy(responder_id).await?;
2254
2255        let pool = session.connection_pool().await;
2256        // Update responder with new contact
2257        let mut update = proto::UpdateContacts::default();
2258        if accept {
2259            update
2260                .contacts
2261                .push(contact_for_user(requester_id, requester_busy, &pool));
2262        }
2263        update
2264            .remove_incoming_requests
2265            .push(requester_id.to_proto());
2266        for connection_id in pool.user_connection_ids(responder_id) {
2267            session.peer.send(connection_id, update.clone())?;
2268        }
2269
2270        // Update requester with new contact
2271        let mut update = proto::UpdateContacts::default();
2272        if accept {
2273            update
2274                .contacts
2275                .push(contact_for_user(responder_id, responder_busy, &pool));
2276        }
2277        update
2278            .remove_outgoing_requests
2279            .push(responder_id.to_proto());
2280
2281        for connection_id in pool.user_connection_ids(requester_id) {
2282            session.peer.send(connection_id, update.clone())?;
2283        }
2284
2285        send_notifications(&*pool, &session.peer, notifications);
2286    }
2287
2288    response.send(proto::Ack {})?;
2289    Ok(())
2290}
2291
2292/// Remove a contact.
2293async fn remove_contact(
2294    request: proto::RemoveContact,
2295    response: Response<proto::RemoveContact>,
2296    session: Session,
2297) -> Result<()> {
2298    let requester_id = session.user_id;
2299    let responder_id = UserId::from_proto(request.user_id);
2300    let db = session.db().await;
2301    let (contact_accepted, deleted_notification_id) =
2302        db.remove_contact(requester_id, responder_id).await?;
2303
2304    let pool = session.connection_pool().await;
2305    // Update outgoing contact requests of requester
2306    let mut update = proto::UpdateContacts::default();
2307    if contact_accepted {
2308        update.remove_contacts.push(responder_id.to_proto());
2309    } else {
2310        update
2311            .remove_outgoing_requests
2312            .push(responder_id.to_proto());
2313    }
2314    for connection_id in pool.user_connection_ids(requester_id) {
2315        session.peer.send(connection_id, update.clone())?;
2316    }
2317
2318    // Update incoming contact requests of responder
2319    let mut update = proto::UpdateContacts::default();
2320    if contact_accepted {
2321        update.remove_contacts.push(requester_id.to_proto());
2322    } else {
2323        update
2324            .remove_incoming_requests
2325            .push(requester_id.to_proto());
2326    }
2327    for connection_id in pool.user_connection_ids(responder_id) {
2328        session.peer.send(connection_id, update.clone())?;
2329        if let Some(notification_id) = deleted_notification_id {
2330            session.peer.send(
2331                connection_id,
2332                proto::DeleteNotification {
2333                    notification_id: notification_id.to_proto(),
2334                },
2335            )?;
2336        }
2337    }
2338
2339    response.send(proto::Ack {})?;
2340    Ok(())
2341}
2342
2343/// Creates a new channel.
2344async fn create_channel(
2345    request: proto::CreateChannel,
2346    response: Response<proto::CreateChannel>,
2347    session: Session,
2348) -> Result<()> {
2349    let db = session.db().await;
2350
2351    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2352    let (channel, owner, channel_members) = db
2353        .create_channel(&request.name, parent_id, session.user_id)
2354        .await?;
2355
2356    response.send(proto::CreateChannelResponse {
2357        channel: Some(channel.to_proto()),
2358        parent_id: request.parent_id,
2359    })?;
2360
2361    let connection_pool = session.connection_pool().await;
2362    if let Some(owner) = owner {
2363        let update = proto::UpdateUserChannels {
2364            channel_memberships: vec![proto::ChannelMembership {
2365                channel_id: owner.channel_id.to_proto(),
2366                role: owner.role.into(),
2367            }],
2368            ..Default::default()
2369        };
2370        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2371            session.peer.send(connection_id, update.clone())?;
2372        }
2373    }
2374
2375    for channel_member in channel_members {
2376        if !channel_member.role.can_see_channel(channel.visibility) {
2377            continue;
2378        }
2379
2380        let update = proto::UpdateChannels {
2381            channels: vec![channel.to_proto()],
2382            ..Default::default()
2383        };
2384        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2385            session.peer.send(connection_id, update.clone())?;
2386        }
2387    }
2388
2389    Ok(())
2390}
2391
2392/// Delete a channel
2393async fn delete_channel(
2394    request: proto::DeleteChannel,
2395    response: Response<proto::DeleteChannel>,
2396    session: Session,
2397) -> Result<()> {
2398    let db = session.db().await;
2399
2400    let channel_id = request.channel_id;
2401    let (removed_channels, member_ids) = db
2402        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2403        .await?;
2404    response.send(proto::Ack {})?;
2405
2406    // Notify members of removed channels
2407    let mut update = proto::UpdateChannels::default();
2408    update
2409        .delete_channels
2410        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2411
2412    let connection_pool = session.connection_pool().await;
2413    for member_id in member_ids {
2414        for connection_id in connection_pool.user_connection_ids(member_id) {
2415            session.peer.send(connection_id, update.clone())?;
2416        }
2417    }
2418
2419    Ok(())
2420}
2421
2422/// Invite someone to join a channel.
2423async fn invite_channel_member(
2424    request: proto::InviteChannelMember,
2425    response: Response<proto::InviteChannelMember>,
2426    session: Session,
2427) -> Result<()> {
2428    let db = session.db().await;
2429    let channel_id = ChannelId::from_proto(request.channel_id);
2430    let invitee_id = UserId::from_proto(request.user_id);
2431    let InviteMemberResult {
2432        channel,
2433        notifications,
2434    } = db
2435        .invite_channel_member(
2436            channel_id,
2437            invitee_id,
2438            session.user_id,
2439            request.role().into(),
2440        )
2441        .await?;
2442
2443    let update = proto::UpdateChannels {
2444        channel_invitations: vec![channel.to_proto()],
2445        ..Default::default()
2446    };
2447
2448    let connection_pool = session.connection_pool().await;
2449    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2450        session.peer.send(connection_id, update.clone())?;
2451    }
2452
2453    send_notifications(&*connection_pool, &session.peer, notifications);
2454
2455    response.send(proto::Ack {})?;
2456    Ok(())
2457}
2458
2459/// remove someone from a channel
2460async fn remove_channel_member(
2461    request: proto::RemoveChannelMember,
2462    response: Response<proto::RemoveChannelMember>,
2463    session: Session,
2464) -> Result<()> {
2465    let db = session.db().await;
2466    let channel_id = ChannelId::from_proto(request.channel_id);
2467    let member_id = UserId::from_proto(request.user_id);
2468
2469    let RemoveChannelMemberResult {
2470        membership_update,
2471        notification_id,
2472    } = db
2473        .remove_channel_member(channel_id, member_id, session.user_id)
2474        .await?;
2475
2476    let connection_pool = &session.connection_pool().await;
2477    notify_membership_updated(
2478        &connection_pool,
2479        membership_update,
2480        member_id,
2481        &session.peer,
2482    );
2483    for connection_id in connection_pool.user_connection_ids(member_id) {
2484        if let Some(notification_id) = notification_id {
2485            session
2486                .peer
2487                .send(
2488                    connection_id,
2489                    proto::DeleteNotification {
2490                        notification_id: notification_id.to_proto(),
2491                    },
2492                )
2493                .trace_err();
2494        }
2495    }
2496
2497    response.send(proto::Ack {})?;
2498    Ok(())
2499}
2500
2501/// Toggle the channel between public and private.
2502/// Care is taken to maintain the invariant that public channels only descend from public channels,
2503/// (though members-only channels can appear at any point in the hierarchy).
2504async fn set_channel_visibility(
2505    request: proto::SetChannelVisibility,
2506    response: Response<proto::SetChannelVisibility>,
2507    session: Session,
2508) -> Result<()> {
2509    let db = session.db().await;
2510    let channel_id = ChannelId::from_proto(request.channel_id);
2511    let visibility = request.visibility().into();
2512
2513    let (channel, channel_members) = db
2514        .set_channel_visibility(channel_id, visibility, session.user_id)
2515        .await?;
2516
2517    let connection_pool = session.connection_pool().await;
2518    for member in channel_members {
2519        let update = if member.role.can_see_channel(channel.visibility) {
2520            proto::UpdateChannels {
2521                channels: vec![channel.to_proto()],
2522                ..Default::default()
2523            }
2524        } else {
2525            proto::UpdateChannels {
2526                delete_channels: vec![channel.id.to_proto()],
2527                ..Default::default()
2528            }
2529        };
2530
2531        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2532            session.peer.send(connection_id, update.clone())?;
2533        }
2534    }
2535
2536    response.send(proto::Ack {})?;
2537    Ok(())
2538}
2539
2540/// Alter the role for a user in the channel.
2541async fn set_channel_member_role(
2542    request: proto::SetChannelMemberRole,
2543    response: Response<proto::SetChannelMemberRole>,
2544    session: Session,
2545) -> Result<()> {
2546    let db = session.db().await;
2547    let channel_id = ChannelId::from_proto(request.channel_id);
2548    let member_id = UserId::from_proto(request.user_id);
2549    let result = db
2550        .set_channel_member_role(
2551            channel_id,
2552            session.user_id,
2553            member_id,
2554            request.role().into(),
2555        )
2556        .await?;
2557
2558    match result {
2559        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2560            let connection_pool = session.connection_pool().await;
2561            notify_membership_updated(
2562                &connection_pool,
2563                membership_update,
2564                member_id,
2565                &session.peer,
2566            )
2567        }
2568        db::SetMemberRoleResult::InviteUpdated(channel) => {
2569            let update = proto::UpdateChannels {
2570                channel_invitations: vec![channel.to_proto()],
2571                ..Default::default()
2572            };
2573
2574            for connection_id in session
2575                .connection_pool()
2576                .await
2577                .user_connection_ids(member_id)
2578            {
2579                session.peer.send(connection_id, update.clone())?;
2580            }
2581        }
2582    }
2583
2584    response.send(proto::Ack {})?;
2585    Ok(())
2586}
2587
2588/// Change the name of a channel
2589async fn rename_channel(
2590    request: proto::RenameChannel,
2591    response: Response<proto::RenameChannel>,
2592    session: Session,
2593) -> Result<()> {
2594    let db = session.db().await;
2595    let channel_id = ChannelId::from_proto(request.channel_id);
2596    let (channel, channel_members) = db
2597        .rename_channel(channel_id, session.user_id, &request.name)
2598        .await?;
2599
2600    response.send(proto::RenameChannelResponse {
2601        channel: Some(channel.to_proto()),
2602    })?;
2603
2604    let connection_pool = session.connection_pool().await;
2605    for channel_member in channel_members {
2606        if !channel_member.role.can_see_channel(channel.visibility) {
2607            continue;
2608        }
2609        let update = proto::UpdateChannels {
2610            channels: vec![channel.to_proto()],
2611            ..Default::default()
2612        };
2613        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2614            session.peer.send(connection_id, update.clone())?;
2615        }
2616    }
2617
2618    Ok(())
2619}
2620
2621/// Move a channel to a new parent.
2622async fn move_channel(
2623    request: proto::MoveChannel,
2624    response: Response<proto::MoveChannel>,
2625    session: Session,
2626) -> Result<()> {
2627    let channel_id = ChannelId::from_proto(request.channel_id);
2628    let to = ChannelId::from_proto(request.to);
2629
2630    let (channels, channel_members) = session
2631        .db()
2632        .await
2633        .move_channel(channel_id, to, session.user_id)
2634        .await?;
2635
2636    let connection_pool = session.connection_pool().await;
2637    for member in channel_members {
2638        let channels = channels
2639            .iter()
2640            .filter_map(|channel| {
2641                if member.role.can_see_channel(channel.visibility) {
2642                    Some(channel.to_proto())
2643                } else {
2644                    None
2645                }
2646            })
2647            .collect::<Vec<_>>();
2648        if channels.is_empty() {
2649            continue;
2650        }
2651
2652        let update = proto::UpdateChannels {
2653            channels,
2654            ..Default::default()
2655        };
2656
2657        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2658            session.peer.send(connection_id, update.clone())?;
2659        }
2660    }
2661
2662    response.send(Ack {})?;
2663    Ok(())
2664}
2665
2666/// Get the list of channel members
2667async fn get_channel_members(
2668    request: proto::GetChannelMembers,
2669    response: Response<proto::GetChannelMembers>,
2670    session: Session,
2671) -> Result<()> {
2672    let db = session.db().await;
2673    let channel_id = ChannelId::from_proto(request.channel_id);
2674    let members = db
2675        .get_channel_participant_details(channel_id, session.user_id)
2676        .await?;
2677    response.send(proto::GetChannelMembersResponse { members })?;
2678    Ok(())
2679}
2680
2681/// Accept or decline a channel invitation.
2682async fn respond_to_channel_invite(
2683    request: proto::RespondToChannelInvite,
2684    response: Response<proto::RespondToChannelInvite>,
2685    session: Session,
2686) -> Result<()> {
2687    let db = session.db().await;
2688    let channel_id = ChannelId::from_proto(request.channel_id);
2689    let RespondToChannelInvite {
2690        membership_update,
2691        notifications,
2692    } = db
2693        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2694        .await?;
2695
2696    let connection_pool = session.connection_pool().await;
2697    if let Some(membership_update) = membership_update {
2698        notify_membership_updated(
2699            &connection_pool,
2700            membership_update,
2701            session.user_id,
2702            &session.peer,
2703        );
2704    } else {
2705        let update = proto::UpdateChannels {
2706            remove_channel_invitations: vec![channel_id.to_proto()],
2707            ..Default::default()
2708        };
2709
2710        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2711            session.peer.send(connection_id, update.clone())?;
2712        }
2713    };
2714
2715    send_notifications(&*connection_pool, &session.peer, notifications);
2716
2717    response.send(proto::Ack {})?;
2718
2719    Ok(())
2720}
2721
2722/// Join the channels' room
2723async fn join_channel(
2724    request: proto::JoinChannel,
2725    response: Response<proto::JoinChannel>,
2726    session: Session,
2727) -> Result<()> {
2728    let channel_id = ChannelId::from_proto(request.channel_id);
2729    join_channel_internal(channel_id, Box::new(response), session).await
2730}
2731
2732trait JoinChannelInternalResponse {
2733    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2734}
2735impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2736    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2737        Response::<proto::JoinChannel>::send(self, result)
2738    }
2739}
2740impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2741    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2742        Response::<proto::JoinRoom>::send(self, result)
2743    }
2744}
2745
2746async fn join_channel_internal(
2747    channel_id: ChannelId,
2748    response: Box<impl JoinChannelInternalResponse>,
2749    session: Session,
2750) -> Result<()> {
2751    let joined_room = {
2752        leave_room_for_session(&session).await?;
2753        let db = session.db().await;
2754
2755        let (joined_room, membership_updated, role) = db
2756            .join_channel(channel_id, session.user_id, session.connection_id)
2757            .await?;
2758
2759        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2760            let (can_publish, token) = if role == ChannelRole::Guest {
2761                (
2762                    false,
2763                    live_kit
2764                        .guest_token(
2765                            &joined_room.room.live_kit_room,
2766                            &session.user_id.to_string(),
2767                        )
2768                        .trace_err()?,
2769                )
2770            } else {
2771                (
2772                    true,
2773                    live_kit
2774                        .room_token(
2775                            &joined_room.room.live_kit_room,
2776                            &session.user_id.to_string(),
2777                        )
2778                        .trace_err()?,
2779                )
2780            };
2781
2782            Some(LiveKitConnectionInfo {
2783                server_url: live_kit.url().into(),
2784                token,
2785                can_publish,
2786            })
2787        });
2788
2789        response.send(proto::JoinRoomResponse {
2790            room: Some(joined_room.room.clone()),
2791            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2792            live_kit_connection_info,
2793        })?;
2794
2795        let connection_pool = session.connection_pool().await;
2796        if let Some(membership_updated) = membership_updated {
2797            notify_membership_updated(
2798                &connection_pool,
2799                membership_updated,
2800                session.user_id,
2801                &session.peer,
2802            );
2803        }
2804
2805        room_updated(&joined_room.room, &session.peer);
2806
2807        joined_room
2808    };
2809
2810    channel_updated(
2811        channel_id,
2812        &joined_room.room,
2813        &joined_room.channel_members,
2814        &session.peer,
2815        &*session.connection_pool().await,
2816    );
2817
2818    update_user_contacts(session.user_id, &session).await?;
2819    Ok(())
2820}
2821
2822/// Start editing the channel notes
2823async fn join_channel_buffer(
2824    request: proto::JoinChannelBuffer,
2825    response: Response<proto::JoinChannelBuffer>,
2826    session: Session,
2827) -> Result<()> {
2828    let db = session.db().await;
2829    let channel_id = ChannelId::from_proto(request.channel_id);
2830
2831    let open_response = db
2832        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2833        .await?;
2834
2835    let collaborators = open_response.collaborators.clone();
2836    response.send(open_response)?;
2837
2838    let update = UpdateChannelBufferCollaborators {
2839        channel_id: channel_id.to_proto(),
2840        collaborators: collaborators.clone(),
2841    };
2842    channel_buffer_updated(
2843        session.connection_id,
2844        collaborators
2845            .iter()
2846            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2847        &update,
2848        &session.peer,
2849    );
2850
2851    Ok(())
2852}
2853
2854/// Edit the channel notes
2855async fn update_channel_buffer(
2856    request: proto::UpdateChannelBuffer,
2857    session: Session,
2858) -> Result<()> {
2859    let db = session.db().await;
2860    let channel_id = ChannelId::from_proto(request.channel_id);
2861
2862    let (collaborators, non_collaborators, epoch, version) = db
2863        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2864        .await?;
2865
2866    channel_buffer_updated(
2867        session.connection_id,
2868        collaborators,
2869        &proto::UpdateChannelBuffer {
2870            channel_id: channel_id.to_proto(),
2871            operations: request.operations,
2872        },
2873        &session.peer,
2874    );
2875
2876    let pool = &*session.connection_pool().await;
2877
2878    broadcast(
2879        None,
2880        non_collaborators
2881            .iter()
2882            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2883        |peer_id| {
2884            session.peer.send(
2885                peer_id.into(),
2886                proto::UpdateChannels {
2887                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2888                        channel_id: channel_id.to_proto(),
2889                        epoch: epoch as u64,
2890                        version: version.clone(),
2891                    }],
2892                    ..Default::default()
2893                },
2894            )
2895        },
2896    );
2897
2898    Ok(())
2899}
2900
2901/// Rejoin the channel notes after a connection blip
2902async fn rejoin_channel_buffers(
2903    request: proto::RejoinChannelBuffers,
2904    response: Response<proto::RejoinChannelBuffers>,
2905    session: Session,
2906) -> Result<()> {
2907    let db = session.db().await;
2908    let buffers = db
2909        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2910        .await?;
2911
2912    for rejoined_buffer in &buffers {
2913        let collaborators_to_notify = rejoined_buffer
2914            .buffer
2915            .collaborators
2916            .iter()
2917            .filter_map(|c| Some(c.peer_id?.into()));
2918        channel_buffer_updated(
2919            session.connection_id,
2920            collaborators_to_notify,
2921            &proto::UpdateChannelBufferCollaborators {
2922                channel_id: rejoined_buffer.buffer.channel_id,
2923                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2924            },
2925            &session.peer,
2926        );
2927    }
2928
2929    response.send(proto::RejoinChannelBuffersResponse {
2930        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2931    })?;
2932
2933    Ok(())
2934}
2935
2936/// Stop editing the channel notes
2937async fn leave_channel_buffer(
2938    request: proto::LeaveChannelBuffer,
2939    response: Response<proto::LeaveChannelBuffer>,
2940    session: Session,
2941) -> Result<()> {
2942    let db = session.db().await;
2943    let channel_id = ChannelId::from_proto(request.channel_id);
2944
2945    let left_buffer = db
2946        .leave_channel_buffer(channel_id, session.connection_id)
2947        .await?;
2948
2949    response.send(Ack {})?;
2950
2951    channel_buffer_updated(
2952        session.connection_id,
2953        left_buffer.connections,
2954        &proto::UpdateChannelBufferCollaborators {
2955            channel_id: channel_id.to_proto(),
2956            collaborators: left_buffer.collaborators,
2957        },
2958        &session.peer,
2959    );
2960
2961    Ok(())
2962}
2963
2964fn channel_buffer_updated<T: EnvelopedMessage>(
2965    sender_id: ConnectionId,
2966    collaborators: impl IntoIterator<Item = ConnectionId>,
2967    message: &T,
2968    peer: &Peer,
2969) {
2970    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2971        peer.send(peer_id.into(), message.clone())
2972    });
2973}
2974
2975fn send_notifications(
2976    connection_pool: &ConnectionPool,
2977    peer: &Peer,
2978    notifications: db::NotificationBatch,
2979) {
2980    for (user_id, notification) in notifications {
2981        for connection_id in connection_pool.user_connection_ids(user_id) {
2982            if let Err(error) = peer.send(
2983                connection_id,
2984                proto::AddNotification {
2985                    notification: Some(notification.clone()),
2986                },
2987            ) {
2988                tracing::error!(
2989                    "failed to send notification to {:?} {}",
2990                    connection_id,
2991                    error
2992                );
2993            }
2994        }
2995    }
2996}
2997
2998/// Send a message to the channel
2999async fn send_channel_message(
3000    request: proto::SendChannelMessage,
3001    response: Response<proto::SendChannelMessage>,
3002    session: Session,
3003) -> Result<()> {
3004    // Validate the message body.
3005    let body = request.body.trim().to_string();
3006    if body.len() > MAX_MESSAGE_LEN {
3007        return Err(anyhow!("message is too long"))?;
3008    }
3009    if body.is_empty() {
3010        return Err(anyhow!("message can't be blank"))?;
3011    }
3012
3013    // TODO: adjust mentions if body is trimmed
3014
3015    let timestamp = OffsetDateTime::now_utc();
3016    let nonce = request
3017        .nonce
3018        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3019
3020    let channel_id = ChannelId::from_proto(request.channel_id);
3021    let CreatedChannelMessage {
3022        message_id,
3023        participant_connection_ids,
3024        channel_members,
3025        notifications,
3026    } = session
3027        .db()
3028        .await
3029        .create_channel_message(
3030            channel_id,
3031            session.user_id,
3032            &body,
3033            &request.mentions,
3034            timestamp,
3035            nonce.clone().into(),
3036            match request.reply_to_message_id {
3037                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3038                None => None,
3039            },
3040        )
3041        .await?;
3042    let message = proto::ChannelMessage {
3043        sender_id: session.user_id.to_proto(),
3044        id: message_id.to_proto(),
3045        body,
3046        mentions: request.mentions,
3047        timestamp: timestamp.unix_timestamp() as u64,
3048        nonce: Some(nonce),
3049        reply_to_message_id: request.reply_to_message_id,
3050    };
3051    broadcast(
3052        Some(session.connection_id),
3053        participant_connection_ids,
3054        |connection| {
3055            session.peer.send(
3056                connection,
3057                proto::ChannelMessageSent {
3058                    channel_id: channel_id.to_proto(),
3059                    message: Some(message.clone()),
3060                },
3061            )
3062        },
3063    );
3064    response.send(proto::SendChannelMessageResponse {
3065        message: Some(message),
3066    })?;
3067
3068    let pool = &*session.connection_pool().await;
3069    broadcast(
3070        None,
3071        channel_members
3072            .iter()
3073            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3074        |peer_id| {
3075            session.peer.send(
3076                peer_id.into(),
3077                proto::UpdateChannels {
3078                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3079                        channel_id: channel_id.to_proto(),
3080                        message_id: message_id.to_proto(),
3081                    }],
3082                    ..Default::default()
3083                },
3084            )
3085        },
3086    );
3087    send_notifications(pool, &session.peer, notifications);
3088
3089    Ok(())
3090}
3091
3092/// Delete a channel message
3093async fn remove_channel_message(
3094    request: proto::RemoveChannelMessage,
3095    response: Response<proto::RemoveChannelMessage>,
3096    session: Session,
3097) -> Result<()> {
3098    let channel_id = ChannelId::from_proto(request.channel_id);
3099    let message_id = MessageId::from_proto(request.message_id);
3100    let connection_ids = session
3101        .db()
3102        .await
3103        .remove_channel_message(channel_id, message_id, session.user_id)
3104        .await?;
3105    broadcast(Some(session.connection_id), connection_ids, |connection| {
3106        session.peer.send(connection, request.clone())
3107    });
3108    response.send(proto::Ack {})?;
3109    Ok(())
3110}
3111
3112/// Mark a channel message as read
3113async fn acknowledge_channel_message(
3114    request: proto::AckChannelMessage,
3115    session: Session,
3116) -> Result<()> {
3117    let channel_id = ChannelId::from_proto(request.channel_id);
3118    let message_id = MessageId::from_proto(request.message_id);
3119    let notifications = session
3120        .db()
3121        .await
3122        .observe_channel_message(channel_id, session.user_id, message_id)
3123        .await?;
3124    send_notifications(
3125        &*session.connection_pool().await,
3126        &session.peer,
3127        notifications,
3128    );
3129    Ok(())
3130}
3131
3132/// Mark a buffer version as synced
3133async fn acknowledge_buffer_version(
3134    request: proto::AckBufferOperation,
3135    session: Session,
3136) -> Result<()> {
3137    let buffer_id = BufferId::from_proto(request.buffer_id);
3138    session
3139        .db()
3140        .await
3141        .observe_buffer_version(
3142            buffer_id,
3143            session.user_id,
3144            request.epoch as i32,
3145            &request.version,
3146        )
3147        .await?;
3148    Ok(())
3149}
3150
3151/// Start receiving chat updates for a channel
3152async fn join_channel_chat(
3153    request: proto::JoinChannelChat,
3154    response: Response<proto::JoinChannelChat>,
3155    session: Session,
3156) -> Result<()> {
3157    let channel_id = ChannelId::from_proto(request.channel_id);
3158
3159    let db = session.db().await;
3160    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3161        .await?;
3162    let messages = db
3163        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3164        .await?;
3165    response.send(proto::JoinChannelChatResponse {
3166        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3167        messages,
3168    })?;
3169    Ok(())
3170}
3171
3172/// Stop receiving chat updates for a channel
3173async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175    session
3176        .db()
3177        .await
3178        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3179        .await?;
3180    Ok(())
3181}
3182
3183/// Retrieve the chat history for a channel
3184async fn get_channel_messages(
3185    request: proto::GetChannelMessages,
3186    response: Response<proto::GetChannelMessages>,
3187    session: Session,
3188) -> Result<()> {
3189    let channel_id = ChannelId::from_proto(request.channel_id);
3190    let messages = session
3191        .db()
3192        .await
3193        .get_channel_messages(
3194            channel_id,
3195            session.user_id,
3196            MESSAGE_COUNT_PER_PAGE,
3197            Some(MessageId::from_proto(request.before_message_id)),
3198        )
3199        .await?;
3200    response.send(proto::GetChannelMessagesResponse {
3201        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3202        messages,
3203    })?;
3204    Ok(())
3205}
3206
3207/// Retrieve specific chat messages
3208async fn get_channel_messages_by_id(
3209    request: proto::GetChannelMessagesById,
3210    response: Response<proto::GetChannelMessagesById>,
3211    session: Session,
3212) -> Result<()> {
3213    let message_ids = request
3214        .message_ids
3215        .iter()
3216        .map(|id| MessageId::from_proto(*id))
3217        .collect::<Vec<_>>();
3218    let messages = session
3219        .db()
3220        .await
3221        .get_channel_messages_by_id(session.user_id, &message_ids)
3222        .await?;
3223    response.send(proto::GetChannelMessagesResponse {
3224        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3225        messages,
3226    })?;
3227    Ok(())
3228}
3229
3230/// Retrieve the current users notifications
3231async fn get_notifications(
3232    request: proto::GetNotifications,
3233    response: Response<proto::GetNotifications>,
3234    session: Session,
3235) -> Result<()> {
3236    let notifications = session
3237        .db()
3238        .await
3239        .get_notifications(
3240            session.user_id,
3241            NOTIFICATION_COUNT_PER_PAGE,
3242            request
3243                .before_id
3244                .map(|id| db::NotificationId::from_proto(id)),
3245        )
3246        .await?;
3247    response.send(proto::GetNotificationsResponse {
3248        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3249        notifications,
3250    })?;
3251    Ok(())
3252}
3253
3254/// Mark notifications as read
3255async fn mark_notification_as_read(
3256    request: proto::MarkNotificationRead,
3257    response: Response<proto::MarkNotificationRead>,
3258    session: Session,
3259) -> Result<()> {
3260    let database = &session.db().await;
3261    let notifications = database
3262        .mark_notification_as_read_by_id(
3263            session.user_id,
3264            NotificationId::from_proto(request.notification_id),
3265        )
3266        .await?;
3267    send_notifications(
3268        &*session.connection_pool().await,
3269        &session.peer,
3270        notifications,
3271    );
3272    response.send(proto::Ack {})?;
3273    Ok(())
3274}
3275
3276/// Get the current users information
3277async fn get_private_user_info(
3278    _request: proto::GetPrivateUserInfo,
3279    response: Response<proto::GetPrivateUserInfo>,
3280    session: Session,
3281) -> Result<()> {
3282    let db = session.db().await;
3283
3284    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3285    let user = db
3286        .get_user_by_id(session.user_id)
3287        .await?
3288        .ok_or_else(|| anyhow!("user not found"))?;
3289    let flags = db.get_user_flags(session.user_id).await?;
3290
3291    response.send(proto::GetPrivateUserInfoResponse {
3292        metrics_id,
3293        staff: user.admin,
3294        flags,
3295    })?;
3296    Ok(())
3297}
3298
3299fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3300    match message {
3301        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3302        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3303        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3304        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3305        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3306            code: frame.code.into(),
3307            reason: frame.reason,
3308        })),
3309    }
3310}
3311
3312fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3313    match message {
3314        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3315        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3316        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3317        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3318        AxumMessage::Close(frame) => {
3319            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3320                code: frame.code.into(),
3321                reason: frame.reason,
3322            }))
3323        }
3324    }
3325}
3326
3327fn notify_membership_updated(
3328    connection_pool: &ConnectionPool,
3329    result: MembershipUpdated,
3330    user_id: UserId,
3331    peer: &Peer,
3332) {
3333    let user_channels_update = proto::UpdateUserChannels {
3334        channel_memberships: result
3335            .new_channels
3336            .channel_memberships
3337            .iter()
3338            .map(|cm| proto::ChannelMembership {
3339                channel_id: cm.channel_id.to_proto(),
3340                role: cm.role.into(),
3341            })
3342            .collect(),
3343        ..Default::default()
3344    };
3345    let mut update = build_channels_update(result.new_channels, vec![]);
3346    update.delete_channels = result
3347        .removed_channels
3348        .into_iter()
3349        .map(|id| id.to_proto())
3350        .collect();
3351    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3352
3353    for connection_id in connection_pool.user_connection_ids(user_id) {
3354        peer.send(connection_id, user_channels_update.clone())
3355            .trace_err();
3356        peer.send(connection_id, update.clone()).trace_err();
3357    }
3358}
3359
3360fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3361    proto::UpdateUserChannels {
3362        channel_memberships: channels
3363            .channel_memberships
3364            .iter()
3365            .map(|m| proto::ChannelMembership {
3366                channel_id: m.channel_id.to_proto(),
3367                role: m.role.into(),
3368            })
3369            .collect(),
3370        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3371        observed_channel_message_id: channels.observed_channel_messages.clone(),
3372        ..Default::default()
3373    }
3374}
3375
3376fn build_channels_update(
3377    channels: ChannelsForUser,
3378    channel_invites: Vec<db::Channel>,
3379) -> proto::UpdateChannels {
3380    let mut update = proto::UpdateChannels::default();
3381
3382    for channel in channels.channels {
3383        update.channels.push(channel.to_proto());
3384    }
3385
3386    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3387    update.latest_channel_message_ids = channels.latest_channel_messages;
3388
3389    for (channel_id, participants) in channels.channel_participants {
3390        update
3391            .channel_participants
3392            .push(proto::ChannelParticipants {
3393                channel_id: channel_id.to_proto(),
3394                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3395            });
3396    }
3397
3398    for channel in channel_invites {
3399        update.channel_invitations.push(channel.to_proto());
3400    }
3401    for project in channels.hosted_projects {
3402        update.hosted_projects.push(project);
3403    }
3404
3405    update
3406}
3407
3408fn build_initial_contacts_update(
3409    contacts: Vec<db::Contact>,
3410    pool: &ConnectionPool,
3411) -> proto::UpdateContacts {
3412    let mut update = proto::UpdateContacts::default();
3413
3414    for contact in contacts {
3415        match contact {
3416            db::Contact::Accepted { user_id, busy } => {
3417                update.contacts.push(contact_for_user(user_id, busy, &pool));
3418            }
3419            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3420            db::Contact::Incoming { user_id } => {
3421                update
3422                    .incoming_requests
3423                    .push(proto::IncomingContactRequest {
3424                        requester_id: user_id.to_proto(),
3425                    })
3426            }
3427        }
3428    }
3429
3430    update
3431}
3432
3433fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3434    proto::Contact {
3435        user_id: user_id.to_proto(),
3436        online: pool.is_user_online(user_id),
3437        busy,
3438    }
3439}
3440
3441fn room_updated(room: &proto::Room, peer: &Peer) {
3442    broadcast(
3443        None,
3444        room.participants
3445            .iter()
3446            .filter_map(|participant| Some(participant.peer_id?.into())),
3447        |peer_id| {
3448            peer.send(
3449                peer_id.into(),
3450                proto::RoomUpdated {
3451                    room: Some(room.clone()),
3452                },
3453            )
3454        },
3455    );
3456}
3457
3458fn channel_updated(
3459    channel_id: ChannelId,
3460    room: &proto::Room,
3461    channel_members: &[UserId],
3462    peer: &Peer,
3463    pool: &ConnectionPool,
3464) {
3465    let participants = room
3466        .participants
3467        .iter()
3468        .map(|p| p.user_id)
3469        .collect::<Vec<_>>();
3470
3471    broadcast(
3472        None,
3473        channel_members
3474            .iter()
3475            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3476        |peer_id| {
3477            peer.send(
3478                peer_id.into(),
3479                proto::UpdateChannels {
3480                    channel_participants: vec![proto::ChannelParticipants {
3481                        channel_id: channel_id.to_proto(),
3482                        participant_user_ids: participants.clone(),
3483                    }],
3484                    ..Default::default()
3485                },
3486            )
3487        },
3488    );
3489}
3490
3491async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3492    let db = session.db().await;
3493
3494    let contacts = db.get_contacts(user_id).await?;
3495    let busy = db.is_user_busy(user_id).await?;
3496
3497    let pool = session.connection_pool().await;
3498    let updated_contact = contact_for_user(user_id, busy, &pool);
3499    for contact in contacts {
3500        if let db::Contact::Accepted {
3501            user_id: contact_user_id,
3502            ..
3503        } = contact
3504        {
3505            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3506                session
3507                    .peer
3508                    .send(
3509                        contact_conn_id,
3510                        proto::UpdateContacts {
3511                            contacts: vec![updated_contact.clone()],
3512                            remove_contacts: Default::default(),
3513                            incoming_requests: Default::default(),
3514                            remove_incoming_requests: Default::default(),
3515                            outgoing_requests: Default::default(),
3516                            remove_outgoing_requests: Default::default(),
3517                        },
3518                    )
3519                    .trace_err();
3520            }
3521        }
3522    }
3523    Ok(())
3524}
3525
3526async fn leave_room_for_session(session: &Session) -> Result<()> {
3527    let mut contacts_to_update = HashSet::default();
3528
3529    let room_id;
3530    let canceled_calls_to_user_ids;
3531    let live_kit_room;
3532    let delete_live_kit_room;
3533    let room;
3534    let channel_members;
3535    let channel_id;
3536
3537    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3538        contacts_to_update.insert(session.user_id);
3539
3540        for project in left_room.left_projects.values() {
3541            project_left(project, session);
3542        }
3543
3544        room_id = RoomId::from_proto(left_room.room.id);
3545        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3546        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3547        delete_live_kit_room = left_room.deleted;
3548        room = mem::take(&mut left_room.room);
3549        channel_members = mem::take(&mut left_room.channel_members);
3550        channel_id = left_room.channel_id;
3551
3552        room_updated(&room, &session.peer);
3553    } else {
3554        return Ok(());
3555    }
3556
3557    if let Some(channel_id) = channel_id {
3558        channel_updated(
3559            channel_id,
3560            &room,
3561            &channel_members,
3562            &session.peer,
3563            &*session.connection_pool().await,
3564        );
3565    }
3566
3567    {
3568        let pool = session.connection_pool().await;
3569        for canceled_user_id in canceled_calls_to_user_ids {
3570            for connection_id in pool.user_connection_ids(canceled_user_id) {
3571                session
3572                    .peer
3573                    .send(
3574                        connection_id,
3575                        proto::CallCanceled {
3576                            room_id: room_id.to_proto(),
3577                        },
3578                    )
3579                    .trace_err();
3580            }
3581            contacts_to_update.insert(canceled_user_id);
3582        }
3583    }
3584
3585    for contact_user_id in contacts_to_update {
3586        update_user_contacts(contact_user_id, &session).await?;
3587    }
3588
3589    if let Some(live_kit) = session.live_kit_client.as_ref() {
3590        live_kit
3591            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3592            .await
3593            .trace_err();
3594
3595        if delete_live_kit_room {
3596            live_kit.delete_room(live_kit_room).await.trace_err();
3597        }
3598    }
3599
3600    Ok(())
3601}
3602
3603async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3604    let left_channel_buffers = session
3605        .db()
3606        .await
3607        .leave_channel_buffers(session.connection_id)
3608        .await?;
3609
3610    for left_buffer in left_channel_buffers {
3611        channel_buffer_updated(
3612            session.connection_id,
3613            left_buffer.connections,
3614            &proto::UpdateChannelBufferCollaborators {
3615                channel_id: left_buffer.channel_id.to_proto(),
3616                collaborators: left_buffer.collaborators,
3617            },
3618            &session.peer,
3619        );
3620    }
3621
3622    Ok(())
3623}
3624
3625fn project_left(project: &db::LeftProject, session: &Session) {
3626    for connection_id in &project.connection_ids {
3627        if project.host_user_id == session.user_id {
3628            session
3629                .peer
3630                .send(
3631                    *connection_id,
3632                    proto::UnshareProject {
3633                        project_id: project.id.to_proto(),
3634                    },
3635                )
3636                .trace_err();
3637        } else {
3638            session
3639                .peer
3640                .send(
3641                    *connection_id,
3642                    proto::RemoveProjectCollaborator {
3643                        project_id: project.id.to_proto(),
3644                        peer_id: Some(session.connection_id.into()),
3645                    },
3646                )
3647                .trace_err();
3648        }
3649    }
3650}
3651
3652pub trait ResultExt {
3653    type Ok;
3654
3655    fn trace_err(self) -> Option<Self::Ok>;
3656}
3657
3658impl<T, E> ResultExt for Result<T, E>
3659where
3660    E: std::fmt::Debug,
3661{
3662    type Ok = T;
3663
3664    fn trace_err(self) -> Option<T> {
3665        match self {
3666            Ok(value) => Some(value),
3667            Err(error) => {
3668                tracing::error!("{:?}", error);
3669                None
3670            }
3671        }
3672    }
3673}