rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::{ConnectionPool, ZedVersion};
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  42        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  43    },
  44    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc, OnceLock,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{field, info_span, instrument, Instrument};
  66use util::SemanticVersion;
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
  75type MessageHandler =
  76    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  77
  78struct Response<R> {
  79    peer: Arc<Peer>,
  80    receipt: Receipt<R>,
  81    responded: Arc<AtomicBool>,
  82}
  83
  84impl<R: RequestMessage> Response<R> {
  85    fn send(self, payload: R::Response) -> Result<()> {
  86        self.responded.store(true, SeqCst);
  87        self.peer.respond(self.receipt, payload)?;
  88        Ok(())
  89    }
  90}
  91
  92#[derive(Clone)]
  93struct Session {
  94    user_id: UserId,
  95    connection_id: ConnectionId,
  96    db: Arc<tokio::sync::Mutex<DbHandle>>,
  97    peer: Arc<Peer>,
  98    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  99    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 100    _executor: Executor,
 101}
 102
 103impl Session {
 104    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        let guard = self.db.lock().await;
 108        #[cfg(test)]
 109        tokio::task::yield_now().await;
 110        guard
 111    }
 112
 113    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 114        #[cfg(test)]
 115        tokio::task::yield_now().await;
 116        let guard = self.connection_pool.lock();
 117        ConnectionPoolGuard {
 118            guard,
 119            _not_send: PhantomData,
 120        }
 121    }
 122}
 123
 124impl fmt::Debug for Session {
 125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 126        f.debug_struct("Session")
 127            .field("user_id", &self.user_id)
 128            .field("connection_id", &self.connection_id)
 129            .finish()
 130    }
 131}
 132
 133struct DbHandle(Arc<Database>);
 134
 135impl Deref for DbHandle {
 136    type Target = Database;
 137
 138    fn deref(&self) -> &Self::Target {
 139        self.0.as_ref()
 140    }
 141}
 142
 143pub struct Server {
 144    id: parking_lot::Mutex<ServerId>,
 145    peer: Arc<Peer>,
 146    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 147    app_state: Arc<AppState>,
 148    executor: Executor,
 149    handlers: HashMap<TypeId, MessageHandler>,
 150    teardown: watch::Sender<()>,
 151}
 152
 153pub(crate) struct ConnectionPoolGuard<'a> {
 154    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 155    _not_send: PhantomData<Rc<()>>,
 156}
 157
 158#[derive(Serialize)]
 159pub struct ServerSnapshot<'a> {
 160    peer: &'a Peer,
 161    #[serde(serialize_with = "serialize_deref")]
 162    connection_pool: ConnectionPoolGuard<'a>,
 163}
 164
 165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 166where
 167    S: Serializer,
 168    T: Deref<Target = U>,
 169    U: Serialize,
 170{
 171    Serialize::serialize(value.deref(), serializer)
 172}
 173
 174impl Server {
 175    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 176        let mut server = Self {
 177            id: parking_lot::Mutex::new(id),
 178            peer: Peer::new(id.0 as u32),
 179            app_state,
 180            executor,
 181            connection_pool: Default::default(),
 182            handlers: Default::default(),
 183            teardown: watch::channel(()).0,
 184        };
 185
 186        server
 187            .add_request_handler(ping)
 188            .add_request_handler(create_room)
 189            .add_request_handler(join_room)
 190            .add_request_handler(rejoin_room)
 191            .add_request_handler(leave_room)
 192            .add_request_handler(set_room_participant_role)
 193            .add_request_handler(call)
 194            .add_request_handler(cancel_call)
 195            .add_message_handler(decline_call)
 196            .add_request_handler(update_participant_location)
 197            .add_request_handler(share_project)
 198            .add_message_handler(unshare_project)
 199            .add_request_handler(join_project)
 200            .add_message_handler(leave_project)
 201            .add_request_handler(update_project)
 202            .add_request_handler(update_worktree)
 203            .add_message_handler(start_language_server)
 204            .add_message_handler(update_language_server)
 205            .add_message_handler(update_diagnostic_summary)
 206            .add_message_handler(update_worktree_settings)
 207            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 208            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 209            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 210            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 211            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 212            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 213            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 214            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 215            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 216            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 217            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 218            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 219            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 220            .add_request_handler(
 221                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 222            )
 223            .add_request_handler(
 224                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 225            )
 226            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 227            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 228            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 229            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 230            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 231            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 232            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 233            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 234            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 235            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 236            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 237            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 238            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 239            .add_message_handler(create_buffer_for_peer)
 240            .add_request_handler(update_buffer)
 241            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 242            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 243            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 244            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 245            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 246            .add_request_handler(get_users)
 247            .add_request_handler(fuzzy_search_users)
 248            .add_request_handler(request_contact)
 249            .add_request_handler(remove_contact)
 250            .add_request_handler(respond_to_contact_request)
 251            .add_request_handler(create_channel)
 252            .add_request_handler(delete_channel)
 253            .add_request_handler(invite_channel_member)
 254            .add_request_handler(remove_channel_member)
 255            .add_request_handler(set_channel_member_role)
 256            .add_request_handler(set_channel_visibility)
 257            .add_request_handler(rename_channel)
 258            .add_request_handler(join_channel_buffer)
 259            .add_request_handler(leave_channel_buffer)
 260            .add_message_handler(update_channel_buffer)
 261            .add_request_handler(rejoin_channel_buffers)
 262            .add_request_handler(get_channel_members)
 263            .add_request_handler(respond_to_channel_invite)
 264            .add_request_handler(join_channel)
 265            .add_request_handler(join_channel_chat)
 266            .add_message_handler(leave_channel_chat)
 267            .add_request_handler(send_channel_message)
 268            .add_request_handler(remove_channel_message)
 269            .add_request_handler(get_channel_messages)
 270            .add_request_handler(get_channel_messages_by_id)
 271            .add_request_handler(get_notifications)
 272            .add_request_handler(mark_notification_as_read)
 273            .add_request_handler(move_channel)
 274            .add_request_handler(follow)
 275            .add_message_handler(unfollow)
 276            .add_message_handler(update_followers)
 277            .add_request_handler(get_private_user_info)
 278            .add_message_handler(acknowledge_channel_message)
 279            .add_message_handler(acknowledge_buffer_version);
 280
 281        Arc::new(server)
 282    }
 283
    /// Spawns a detached background task that cleans up resources the database
    /// reports as stale for this environment and `server_id`.
    ///
    /// The task first waits `CLEANUP_TIMEOUT`. It then fetches the stale room
    /// and channel-buffer ids and does the following:
    ///
    /// * For each stale channel buffer, it clears the stale collaborators and
    ///   sends the refreshed collaborator list to the remaining connections.
    /// * For each stale room, it clears stale participants and broadcasts the
    ///   room (and channel, if any) update. It then sends `CallCanceled` to
    ///   users whose calls were canceled and pushes refreshed contact entries
    ///   to accepted contacts.
    /// * It deletes the LiveKit room when the room has no participants left.
    ///
    /// Finally it deletes stale server records. That step runs even if the
    /// stale-resource lookup failed.
    ///
    /// Errors inside the task are logged via `trace_err` rather than returned.
    /// This method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                // If the lookup fails, skip straight to deleting stale servers below.
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every connection still in it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if clearing the room fails: nobody is
                        // notified and no LiveKit room is deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls were canceled
                            // get their contact entry refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each canceled callee. The pool lock is
                        // scoped to this block.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry to the connections
                        // of their accepted contacts. The DB queries are awaited before the
                        // pool lock is taken, so the lock is never held across an await.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once the room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs whether or not the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 435
 436    pub fn teardown(&self) {
 437        self.peer.teardown();
 438        self.connection_pool.lock().reset();
 439        let _ = self.teardown.send(());
 440    }
 441
 442    #[cfg(test)]
 443    pub fn reset(&self, id: ServerId) {
 444        self.teardown();
 445        *self.id.lock() = id;
 446        self.peer.reset(id.0 as u32);
 447    }
 448
 449    #[cfg(test)]
 450    pub fn id(&self) -> ServerId {
 451        *self.id.lock()
 452    }
 453
 454    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 455    where
 456        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 457        Fut: 'static + Send + Future<Output = Result<()>>,
 458        M: EnvelopedMessage,
 459    {
 460        let prev_handler = self.handlers.insert(
 461            TypeId::of::<M>(),
 462            Box::new(move |envelope, session| {
 463                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 464                let span = info_span!(
 465                    "handle message",
 466                    payload_type = envelope.payload_type_name()
 467                );
 468                span.in_scope(|| {
 469                    tracing::info!(
 470                        payload_type = envelope.payload_type_name(),
 471                        "message received"
 472                    );
 473                });
 474                let start_time = Instant::now();
 475                let future = (handler)(*envelope, session);
 476                async move {
 477                    let result = future.await;
 478                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 479                    match result {
 480                        Err(error) => {
 481                            tracing::error!(%error, ?duration_ms, "error handling message")
 482                        }
 483                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 484                    }
 485                }
 486                .instrument(span)
 487                .boxed()
 488            }),
 489        );
 490        if prev_handler.is_some() {
 491            panic!("registered a handler for the same message twice");
 492        }
 493        self
 494    }
 495
 496    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 497    where
 498        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 499        Fut: 'static + Send + Future<Output = Result<()>>,
 500        M: EnvelopedMessage,
 501    {
 502        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 503        self
 504    }
 505
 506    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 507    where
 508        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 509        Fut: Send + Future<Output = Result<()>>,
 510        M: RequestMessage,
 511    {
 512        let handler = Arc::new(handler);
 513        self.add_handler(move |envelope, session| {
 514            let receipt = envelope.receipt();
 515            let handler = handler.clone();
 516            async move {
 517                let peer = session.peer.clone();
 518                let responded = Arc::new(AtomicBool::default());
 519                let response = Response {
 520                    peer: peer.clone(),
 521                    responded: responded.clone(),
 522                    receipt,
 523                };
 524                match (handler)(envelope.payload, response, session).await {
 525                    Ok(()) => {
 526                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 527                            Ok(())
 528                        } else {
 529                            Err(anyhow!("handler did not send a response"))?
 530                        }
 531                    }
 532                    Err(error) => {
 533                        let proto_err = match &error {
 534                            Error::Internal(err) => err.to_proto(),
 535                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 536                        };
 537                        peer.respond_with_error(receipt, proto_err)?;
 538                        Err(error)
 539                    }
 540                }
 541            }
 542        })
 543    }
 544
    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// The future performs these steps:
    ///
    /// 1. Registers `connection` with the peer and sends `Hello`.
    /// 2. Hands the new `ConnectionId` to `send_connection_id`, if one was given.
    /// 3. On the user's first-ever connection, sends `ShowContacts` and records
    ///    that the user has connected once.
    /// 4. Adds the connection to the pool and sends the initial contacts and
    ///    channel updates, plus any pending incoming call.
    /// 5. Dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or the server is torn down.
    ///
    /// Except on teardown, it finishes by calling `connection_lost`.
    ///
    /// # Errors
    /// A failure during setup (before the message loop) is returned through `?`.
    /// Errors after setup are logged, and the future resolves to `Ok(())`.
    ///
    /// NOTE(review): an early `?` return after `pool.add_connection` skips
    /// `connection_lost`, so the pool entry may never be cleaned up — confirm.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // `teardown.changed()` in the loop below ends this connection when
        // `Server::teardown` signals.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three lookups run concurrently; the first error aborts setup.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial state while holding the
            // pool lock; there is no await inside this block.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) may be in flight for this
            // connection. A permit is acquired before the next message is even read, and it
            // is released only after that message's handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are checked in order: teardown, then I/O completion, then
                // draining a finished foreground handler, then the next message.
                futures::select_biased! {
                    // Server teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span here; the wrapper future below is
                                // instrumented with the same span instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Free the concurrency slot only after the handler finishes.
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached. Foreground ones are polled
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 685
 686    pub async fn invite_code_redeemed(
 687        self: &Arc<Self>,
 688        inviter_id: UserId,
 689        invitee_id: UserId,
 690    ) -> Result<()> {
 691        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 692            if let Some(code) = &user.invite_code {
 693                let pool = self.connection_pool.lock();
 694                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 695                for connection_id in pool.user_connection_ids(inviter_id) {
 696                    self.peer.send(
 697                        connection_id,
 698                        proto::UpdateContacts {
 699                            contacts: vec![invitee_contact.clone()],
 700                            ..Default::default()
 701                        },
 702                    )?;
 703                    self.peer.send(
 704                        connection_id,
 705                        proto::UpdateInviteInfo {
 706                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 707                            count: user.invite_count as u32,
 708                        },
 709                    )?;
 710                }
 711            }
 712        }
 713        Ok(())
 714    }
 715
 716    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 717        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 718            if let Some(invite_code) = &user.invite_code {
 719                let pool = self.connection_pool.lock();
 720                for connection_id in pool.user_connection_ids(user_id) {
 721                    self.peer.send(
 722                        connection_id,
 723                        proto::UpdateInviteInfo {
 724                            url: format!(
 725                                "{}{}",
 726                                self.app_state.config.invite_link_prefix, invite_code
 727                            ),
 728                            count: user.invite_count as u32,
 729                        },
 730                    )?;
 731                }
 732            }
 733        }
 734        Ok(())
 735    }
 736
 737    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 738        ServerSnapshot {
 739            connection_pool: ConnectionPoolGuard {
 740                guard: self.connection_pool.lock(),
 741                _not_send: PhantomData,
 742            },
 743            peer: &self.peer,
 744        }
 745    }
 746}
 747
 748impl<'a> Deref for ConnectionPoolGuard<'a> {
 749    type Target = ConnectionPool;
 750
 751    fn deref(&self) -> &Self::Target {
 752        &*self.guard
 753    }
 754}
 755
 756impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 757    fn deref_mut(&mut self) -> &mut Self::Target {
 758        &mut *self.guard
 759    }
 760}
 761
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` on the pool when the guard is
    /// released; in other builds this is a no-op.
    ///
    /// `drop` runs before the `guard` field itself is dropped, so the check
    /// executes while the connection pool lock is still held.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 768
 769fn broadcast<F>(
 770    sender_id: Option<ConnectionId>,
 771    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 772    mut f: F,
 773) where
 774    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 775{
 776    for receiver_id in receiver_ids {
 777        if Some(receiver_id) != sender_id {
 778            if let Err(error) = f(receiver_id) {
 779                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 780            }
 781        }
 782    }
 783}
 784
 785pub struct ProtocolVersion(u32);
 786
 787impl Header for ProtocolVersion {
 788    fn name() -> &'static HeaderName {
 789        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 790        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 791    }
 792
 793    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 794    where
 795        Self: Sized,
 796        I: Iterator<Item = &'i axum::http::HeaderValue>,
 797    {
 798        let version = values
 799            .next()
 800            .ok_or_else(axum::headers::Error::invalid)?
 801            .to_str()
 802            .map_err(|_| axum::headers::Error::invalid())?
 803            .parse()
 804            .map_err(|_| axum::headers::Error::invalid())?;
 805        Ok(Self(version))
 806    }
 807
 808    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 809        values.extend([self.0.to_string().parse().unwrap()]);
 810    }
 811}
 812
 813pub struct AppVersionHeader(SemanticVersion);
 814impl Header for AppVersionHeader {
 815    fn name() -> &'static HeaderName {
 816        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 817        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 818    }
 819
 820    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 821    where
 822        Self: Sized,
 823        I: Iterator<Item = &'i axum::http::HeaderValue>,
 824    {
 825        let version = values
 826            .next()
 827            .ok_or_else(axum::headers::Error::invalid)?
 828            .to_str()
 829            .map_err(|_| axum::headers::Error::invalid())?
 830            .parse()
 831            .map_err(|_| axum::headers::Error::invalid())?;
 832        Ok(Self(version))
 833    }
 834
 835    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 836        values.extend([self.0.to_string().parse().unwrap()]);
 837    }
 838}
 839
/// Builds the collab server's HTTP router.
///
/// Routes:
/// - `/rpc`: upgraded to the websocket RPC connection by `handle_websocket_request`.
/// - `/metrics`: Prometheus metrics served by `handle_metrics`.
///
/// `Router::layer` only wraps routes registered *before* it, so the
/// `ServiceBuilder` stack (the `AppState` extension plus the
/// `auth::validate_header` middleware) applies to `/rpc` only. `/metrics` is
/// added afterwards and is not behind `validate_header`.
/// NOTE(review): confirm `/metrics` is intentionally unauthenticated.
/// The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 851
 852pub async fn handle_websocket_request(
 853    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 854    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 855    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 856    Extension(server): Extension<Arc<Server>>,
 857    Extension(user): Extension<User>,
 858    Extension(impersonator): Extension<Impersonator>,
 859    ws: WebSocketUpgrade,
 860) -> axum::response::Response {
 861    if protocol_version != rpc::PROTOCOL_VERSION {
 862        return (
 863            StatusCode::UPGRADE_REQUIRED,
 864            "client must be upgraded".to_string(),
 865        )
 866            .into_response();
 867    }
 868
 869    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 870        return (
 871            StatusCode::UPGRADE_REQUIRED,
 872            "no version header found".to_string(),
 873        )
 874            .into_response();
 875    };
 876
 877    if !version.is_supported() {
 878        return (
 879            StatusCode::UPGRADE_REQUIRED,
 880            "client must be upgraded".to_string(),
 881        )
 882            .into_response();
 883    }
 884
 885    let socket_address = socket_address.to_string();
 886    ws.on_upgrade(move |socket| {
 887        use util::ResultExt;
 888        let socket = socket
 889            .map_ok(to_tungstenite_message)
 890            .err_into()
 891            .with(|message| async move { Ok(to_axum_message(message)) });
 892        let connection = Connection::new(Box::pin(socket));
 893        async move {
 894            server
 895                .handle_connection(
 896                    connection,
 897                    socket_address,
 898                    user,
 899                    version,
 900                    impersonator.0,
 901                    None,
 902                    Executor::Production,
 903                )
 904                .await
 905                .log_err();
 906        }
 907    })
 908}
 909
 910pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 911    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 912    let connections_metric = CONNECTIONS_METRIC
 913        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
 914
 915    let connections = server
 916        .connection_pool
 917        .lock()
 918        .connections()
 919        .filter(|connection| !connection.admin)
 920        .count();
 921    connections_metric.set(connections as _);
 922
 923    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
 924    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
 925        register_int_gauge!(
 926            "shared_projects",
 927            "number of open projects with one or more guests"
 928        )
 929        .unwrap()
 930    });
 931
 932    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 933    shared_projects_metric.set(shared_projects as _);
 934
 935    let encoder = prometheus::TextEncoder::new();
 936    let metric_families = prometheus::gather();
 937    let encoded_metrics = encoder
 938        .encode_to_string(&metric_families)
 939        .map_err(|err| anyhow!("{}", err))?;
 940    Ok(encoded_metrics)
 941}
 942
/// Cleans up after a client's connection has been lost.
///
/// Immediately:
/// - the connection is disconnected from the peer and removed from the
///   connection pool (a removal error is returned);
/// - the database is told the connection was lost (errors are only logged).
///
/// Then `select_biased!` waits for whichever happens first. The sleep branch is
/// listed first, so it wins if both are ready.
/// - `RECONNECT_TIMEOUT` elapses on `executor`. `leave_room_for_session` and
///   `leave_channel_buffers_for_session` run (errors are logged). If the user
///   has no connection left in the pool, `decline_call(None, ..)` runs and any
///   returned room is broadcast. Finally `update_user_contacts` runs for the
///   user. The delay presumably gives the client a chance to reconnect (compare
///   `rejoin_room`) before these resources are released.
/// - `teardown` changes (presumably server shutdown). No further cleanup is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call once none of the user's connections
            // remain in the pool (the pool guard is dropped after this condition).
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                // NOTE(review): `None` room id presumably means "whatever call is
                // pending for this user" — confirm against `db::decline_call`.
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 988
 989/// Acknowledges a ping from a client, used to keep the connection alive.
 990async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 991    response.send(proto::Ack {})?;
 992    Ok(())
 993}
 994
 995/// Creates a new room for calling (outside of channels)
 996async fn create_room(
 997    _request: proto::CreateRoom,
 998    response: Response<proto::CreateRoom>,
 999    session: Session,
1000) -> Result<()> {
1001    let live_kit_room = nanoid::nanoid!(30);
1002
1003    let live_kit_connection_info = {
1004        let live_kit_room = live_kit_room.clone();
1005        let live_kit = session.live_kit_client.as_ref();
1006
1007        util::async_maybe!({
1008            let live_kit = live_kit?;
1009
1010            let token = live_kit
1011                .room_token(&live_kit_room, &session.user_id.to_string())
1012                .trace_err()?;
1013
1014            Some(proto::LiveKitConnectionInfo {
1015                server_url: live_kit.url().into(),
1016                token,
1017                can_publish: true,
1018            })
1019        })
1020    }
1021    .await;
1022
1023    let room = session
1024        .db()
1025        .await
1026        .create_room(session.user_id, session.connection_id, &live_kit_room)
1027        .await?;
1028
1029    response.send(proto::CreateRoomResponse {
1030        room: Some(room.clone()),
1031        live_kit_connection_info,
1032    })?;
1033
1034    update_user_contacts(session.user_id, &session).await?;
1035    Ok(())
1036}
1037
1038/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1039async fn join_room(
1040    request: proto::JoinRoom,
1041    response: Response<proto::JoinRoom>,
1042    session: Session,
1043) -> Result<()> {
1044    let room_id = RoomId::from_proto(request.id);
1045
1046    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1047
1048    if let Some(channel_id) = channel_id {
1049        return join_channel_internal(channel_id, Box::new(response), session).await;
1050    }
1051
1052    let joined_room = {
1053        let room = session
1054            .db()
1055            .await
1056            .join_room(room_id, session.user_id, session.connection_id)
1057            .await?;
1058        room_updated(&room.room, &session.peer);
1059        room.into_inner()
1060    };
1061
1062    for connection_id in session
1063        .connection_pool()
1064        .await
1065        .user_connection_ids(session.user_id)
1066    {
1067        session
1068            .peer
1069            .send(
1070                connection_id,
1071                proto::CallCanceled {
1072                    room_id: room_id.to_proto(),
1073                },
1074            )
1075            .trace_err();
1076    }
1077
1078    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1079        if let Some(token) = live_kit
1080            .room_token(
1081                &joined_room.room.live_kit_room,
1082                &session.user_id.to_string(),
1083            )
1084            .trace_err()
1085        {
1086            Some(proto::LiveKitConnectionInfo {
1087                server_url: live_kit.url().into(),
1088                token,
1089                can_publish: true,
1090            })
1091        } else {
1092            None
1093        }
1094    } else {
1095        None
1096    };
1097
1098    response.send(proto::JoinRoomResponse {
1099        room: Some(joined_room.room),
1100        channel_id: None,
1101        live_kit_connection_info,
1102    })?;
1103
1104    update_user_contacts(session.user_id, &session).await?;
1105    Ok(())
1106}
1107
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Flow:
/// 1. Re-registers the caller in the room through `db.rejoin_room`.
/// 2. Replies with the room plus two project lists: `reshared_projects`
///    (presumably projects the caller hosts) and `rejoined_projects`
///    (presumably projects the caller had joined as a guest).
/// 3. Broadcasts the room update.
/// 4. Tells each project's collaborators that the caller's peer id changed
///    from `old_connection_id` to the current connection.
/// 5. Streams every rejoined project's worktrees, diagnostics, settings files
///    and language-server status back to the caller.
/// 6. Broadcasts a channel update when the room belongs to a channel.
/// 7. Runs `update_user_contacts` for the caller.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Only these three values outlive the block below, which scopes the
    // database result (`rejoined_room`).
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // The response is sent before `room_updated` broadcasts the room.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Swap the caller's old peer id for the new one on every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators, forwarded
            // as coming from the caller (the caller's own connection is skipped).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Send each rejoined project's state to the caller. Worktrees are moved
        // out with `mem::take` so their fields can be consumed by value.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in test builds, presumably so tests exercise
                // `split_worktree_update` splitting.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report every language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1299
1300/// leave room disconnects from the room.
1301async fn leave_room(
1302    _: proto::LeaveRoom,
1303    response: Response<proto::LeaveRoom>,
1304    session: Session,
1305) -> Result<()> {
1306    leave_room_for_session(&session).await?;
1307    response.send(proto::Ack {})?;
1308    Ok(())
1309}
1310
1311/// Updates the permissions of someone else in the room.
1312async fn set_room_participant_role(
1313    request: proto::SetRoomParticipantRole,
1314    response: Response<proto::SetRoomParticipantRole>,
1315    session: Session,
1316) -> Result<()> {
1317    let user_id = UserId::from_proto(request.user_id);
1318    let role = ChannelRole::from(request.role());
1319
1320    if role == ChannelRole::Talker {
1321        let pool = session.connection_pool().await;
1322
1323        for connection in pool.user_connections(user_id) {
1324            if !connection.zed_version.supports_talker_role() {
1325                Err(anyhow!(
1326                    "This user is on zed {} which does not support unmute",
1327                    connection.zed_version
1328                ))?;
1329            }
1330        }
1331    }
1332
1333    let (live_kit_room, can_publish) = {
1334        let room = session
1335            .db()
1336            .await
1337            .set_room_participant_role(
1338                session.user_id,
1339                RoomId::from_proto(request.room_id),
1340                user_id,
1341                role,
1342            )
1343            .await?;
1344
1345        let live_kit_room = room.live_kit_room.clone();
1346        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1347        room_updated(&room, &session.peer);
1348        (live_kit_room, can_publish)
1349    };
1350
1351    if let Some(live_kit) = session.live_kit_client.as_ref() {
1352        live_kit
1353            .update_participant(
1354                live_kit_room.clone(),
1355                request.user_id.to_string(),
1356                live_kit_server::proto::ParticipantPermission {
1357                    can_subscribe: true,
1358                    can_publish,
1359                    can_publish_data: can_publish,
1360                    hidden: false,
1361                    recorder: false,
1362                },
1363            )
1364            .await
1365            .trace_err();
1366    }
1367
1368    response.send(proto::Ack {})?;
1369    Ok(())
1370}
1371
1372/// Call someone else into the current room
1373async fn call(
1374    request: proto::Call,
1375    response: Response<proto::Call>,
1376    session: Session,
1377) -> Result<()> {
1378    let room_id = RoomId::from_proto(request.room_id);
1379    let calling_user_id = session.user_id;
1380    let calling_connection_id = session.connection_id;
1381    let called_user_id = UserId::from_proto(request.called_user_id);
1382    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1383    if !session
1384        .db()
1385        .await
1386        .has_contact(calling_user_id, called_user_id)
1387        .await?
1388    {
1389        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1390    }
1391
1392    let incoming_call = {
1393        let (room, incoming_call) = &mut *session
1394            .db()
1395            .await
1396            .call(
1397                room_id,
1398                calling_user_id,
1399                calling_connection_id,
1400                called_user_id,
1401                initial_project_id,
1402            )
1403            .await?;
1404        room_updated(&room, &session.peer);
1405        mem::take(incoming_call)
1406    };
1407    update_user_contacts(called_user_id, &session).await?;
1408
1409    let mut calls = session
1410        .connection_pool()
1411        .await
1412        .user_connection_ids(called_user_id)
1413        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1414        .collect::<FuturesUnordered<_>>();
1415
1416    while let Some(call_response) = calls.next().await {
1417        match call_response.as_ref() {
1418            Ok(_) => {
1419                response.send(proto::Ack {})?;
1420                return Ok(());
1421            }
1422            Err(_) => {
1423                call_response.trace_err();
1424            }
1425        }
1426    }
1427
1428    {
1429        let room = session
1430            .db()
1431            .await
1432            .call_failed(room_id, called_user_id)
1433            .await?;
1434        room_updated(&room, &session.peer);
1435    }
1436    update_user_contacts(called_user_id, &session).await?;
1437
1438    Err(anyhow!("failed to ring user"))?
1439}
1440
1441/// Cancel an outgoing call.
1442async fn cancel_call(
1443    request: proto::CancelCall,
1444    response: Response<proto::CancelCall>,
1445    session: Session,
1446) -> Result<()> {
1447    let called_user_id = UserId::from_proto(request.called_user_id);
1448    let room_id = RoomId::from_proto(request.room_id);
1449    {
1450        let room = session
1451            .db()
1452            .await
1453            .cancel_call(room_id, session.connection_id, called_user_id)
1454            .await?;
1455        room_updated(&room, &session.peer);
1456    }
1457
1458    for connection_id in session
1459        .connection_pool()
1460        .await
1461        .user_connection_ids(called_user_id)
1462    {
1463        session
1464            .peer
1465            .send(
1466                connection_id,
1467                proto::CallCanceled {
1468                    room_id: room_id.to_proto(),
1469                },
1470            )
1471            .trace_err();
1472    }
1473    response.send(proto::Ack {})?;
1474
1475    update_user_contacts(called_user_id, &session).await?;
1476    Ok(())
1477}
1478
1479/// Decline an incoming call.
1480async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1481    let room_id = RoomId::from_proto(message.room_id);
1482    {
1483        let room = session
1484            .db()
1485            .await
1486            .decline_call(Some(room_id), session.user_id)
1487            .await?
1488            .ok_or_else(|| anyhow!("failed to decline call"))?;
1489        room_updated(&room, &session.peer);
1490    }
1491
1492    for connection_id in session
1493        .connection_pool()
1494        .await
1495        .user_connection_ids(session.user_id)
1496    {
1497        session
1498            .peer
1499            .send(
1500                connection_id,
1501                proto::CallCanceled {
1502                    room_id: room_id.to_proto(),
1503                },
1504            )
1505            .trace_err();
1506    }
1507    update_user_contacts(session.user_id, &session).await?;
1508    Ok(())
1509}
1510
1511/// Updates other participants in the room with your current location.
1512async fn update_participant_location(
1513    request: proto::UpdateParticipantLocation,
1514    response: Response<proto::UpdateParticipantLocation>,
1515    session: Session,
1516) -> Result<()> {
1517    let room_id = RoomId::from_proto(request.room_id);
1518    let location = request
1519        .location
1520        .ok_or_else(|| anyhow!("invalid location"))?;
1521
1522    let db = session.db().await;
1523    let room = db
1524        .update_room_participant_location(room_id, session.connection_id, location)
1525        .await?;
1526
1527    room_updated(&room, &session.peer);
1528    response.send(proto::Ack {})?;
1529    Ok(())
1530}
1531
1532/// Share a project into the room.
1533async fn share_project(
1534    request: proto::ShareProject,
1535    response: Response<proto::ShareProject>,
1536    session: Session,
1537) -> Result<()> {
1538    let (project_id, room) = &*session
1539        .db()
1540        .await
1541        .share_project(
1542            RoomId::from_proto(request.room_id),
1543            session.connection_id,
1544            &request.worktrees,
1545        )
1546        .await?;
1547    response.send(proto::ShareProjectResponse {
1548        project_id: project_id.to_proto(),
1549    })?;
1550    room_updated(&room, &session.peer);
1551
1552    Ok(())
1553}
1554
1555/// Unshare a project from the room.
1556async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1557    let project_id = ProjectId::from_proto(message.project_id);
1558
1559    let (room, guest_connection_ids) = &*session
1560        .db()
1561        .await
1562        .unshare_project(project_id, session.connection_id)
1563        .await?;
1564
1565    broadcast(
1566        Some(session.connection_id),
1567        guest_connection_ids.iter().copied(),
1568        |conn_id| session.peer.send(conn_id, message.clone()),
1569    );
1570    room_updated(&room, &session.peer);
1571
1572    Ok(())
1573}
1574
1575/// Join someone elses shared project.
1576async fn join_project(
1577    request: proto::JoinProject,
1578    response: Response<proto::JoinProject>,
1579    session: Session,
1580) -> Result<()> {
1581    let project_id = ProjectId::from_proto(request.project_id);
1582    let guest_user_id = session.user_id;
1583
1584    tracing::info!(%project_id, "join project");
1585
1586    let (project, replica_id) = &mut *session
1587        .db()
1588        .await
1589        .join_project(project_id, session.connection_id)
1590        .await?;
1591
1592    let collaborators = project
1593        .collaborators
1594        .iter()
1595        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1596        .map(|collaborator| collaborator.to_proto())
1597        .collect::<Vec<_>>();
1598
1599    let worktrees = project
1600        .worktrees
1601        .iter()
1602        .map(|(id, worktree)| proto::WorktreeMetadata {
1603            id: *id,
1604            root_name: worktree.root_name.clone(),
1605            visible: worktree.visible,
1606            abs_path: worktree.abs_path.clone(),
1607        })
1608        .collect::<Vec<_>>();
1609
1610    for collaborator in &collaborators {
1611        session
1612            .peer
1613            .send(
1614                collaborator.peer_id.unwrap().into(),
1615                proto::AddProjectCollaborator {
1616                    project_id: project_id.to_proto(),
1617                    collaborator: Some(proto::Collaborator {
1618                        peer_id: Some(session.connection_id.into()),
1619                        replica_id: replica_id.0 as u32,
1620                        user_id: guest_user_id.to_proto(),
1621                    }),
1622                },
1623            )
1624            .trace_err();
1625    }
1626
1627    // First, we send the metadata associated with each worktree.
1628    response.send(proto::JoinProjectResponse {
1629        worktrees: worktrees.clone(),
1630        replica_id: replica_id.0 as u32,
1631        collaborators: collaborators.clone(),
1632        language_servers: project.language_servers.clone(),
1633    })?;
1634
1635    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1636        #[cfg(any(test, feature = "test-support"))]
1637        const MAX_CHUNK_SIZE: usize = 2;
1638        #[cfg(not(any(test, feature = "test-support")))]
1639        const MAX_CHUNK_SIZE: usize = 256;
1640
1641        // Stream this worktree's entries.
1642        let message = proto::UpdateWorktree {
1643            project_id: project_id.to_proto(),
1644            worktree_id,
1645            abs_path: worktree.abs_path.clone(),
1646            root_name: worktree.root_name,
1647            updated_entries: worktree.entries,
1648            removed_entries: Default::default(),
1649            scan_id: worktree.scan_id,
1650            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1651            updated_repositories: worktree.repository_entries.into_values().collect(),
1652            removed_repositories: Default::default(),
1653        };
1654        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1655            session.peer.send(session.connection_id, update.clone())?;
1656        }
1657
1658        // Stream this worktree's diagnostics.
1659        for summary in worktree.diagnostic_summaries {
1660            session.peer.send(
1661                session.connection_id,
1662                proto::UpdateDiagnosticSummary {
1663                    project_id: project_id.to_proto(),
1664                    worktree_id: worktree.id,
1665                    summary: Some(summary),
1666                },
1667            )?;
1668        }
1669
1670        for settings_file in worktree.settings_files {
1671            session.peer.send(
1672                session.connection_id,
1673                proto::UpdateWorktreeSettings {
1674                    project_id: project_id.to_proto(),
1675                    worktree_id: worktree.id,
1676                    path: settings_file.path,
1677                    content: Some(settings_file.content),
1678                },
1679            )?;
1680        }
1681    }
1682
1683    for language_server in &project.language_servers {
1684        session.peer.send(
1685            session.connection_id,
1686            proto::UpdateLanguageServer {
1687                project_id: project_id.to_proto(),
1688                language_server_id: language_server.id,
1689                variant: Some(
1690                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1691                        proto::LspDiskBasedDiagnosticsUpdated {},
1692                    ),
1693                ),
1694            },
1695        )?;
1696    }
1697
1698    Ok(())
1699}
1700
1701/// Leave someone elses shared project.
1702async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1703    let sender_id = session.connection_id;
1704    let project_id = ProjectId::from_proto(request.project_id);
1705
1706    let (room, project) = &*session
1707        .db()
1708        .await
1709        .leave_project(project_id, sender_id)
1710        .await?;
1711    tracing::info!(
1712        %project_id,
1713        host_user_id = %project.host_user_id,
1714        host_connection_id = ?project.host_connection_id,
1715        "leave project"
1716    );
1717
1718    project_left(&project, &session);
1719    room_updated(&room, &session.peer);
1720
1721    Ok(())
1722}
1723
1724/// Updates other participants with changes to the project
1725async fn update_project(
1726    request: proto::UpdateProject,
1727    response: Response<proto::UpdateProject>,
1728    session: Session,
1729) -> Result<()> {
1730    let project_id = ProjectId::from_proto(request.project_id);
1731    let (room, guest_connection_ids) = &*session
1732        .db()
1733        .await
1734        .update_project(project_id, session.connection_id, &request.worktrees)
1735        .await?;
1736    broadcast(
1737        Some(session.connection_id),
1738        guest_connection_ids.iter().copied(),
1739        |connection_id| {
1740            session
1741                .peer
1742                .forward_send(session.connection_id, connection_id, request.clone())
1743        },
1744    );
1745    room_updated(&room, &session.peer);
1746    response.send(proto::Ack {})?;
1747
1748    Ok(())
1749}
1750
1751/// Updates other participants with changes to the worktree
1752async fn update_worktree(
1753    request: proto::UpdateWorktree,
1754    response: Response<proto::UpdateWorktree>,
1755    session: Session,
1756) -> Result<()> {
1757    let guest_connection_ids = session
1758        .db()
1759        .await
1760        .update_worktree(&request, session.connection_id)
1761        .await?;
1762
1763    broadcast(
1764        Some(session.connection_id),
1765        guest_connection_ids.iter().copied(),
1766        |connection_id| {
1767            session
1768                .peer
1769                .forward_send(session.connection_id, connection_id, request.clone())
1770        },
1771    );
1772    response.send(proto::Ack {})?;
1773    Ok(())
1774}
1775
1776/// Updates other participants with changes to the diagnostics
1777async fn update_diagnostic_summary(
1778    message: proto::UpdateDiagnosticSummary,
1779    session: Session,
1780) -> Result<()> {
1781    let guest_connection_ids = session
1782        .db()
1783        .await
1784        .update_diagnostic_summary(&message, session.connection_id)
1785        .await?;
1786
1787    broadcast(
1788        Some(session.connection_id),
1789        guest_connection_ids.iter().copied(),
1790        |connection_id| {
1791            session
1792                .peer
1793                .forward_send(session.connection_id, connection_id, message.clone())
1794        },
1795    );
1796
1797    Ok(())
1798}
1799
1800/// Updates other participants with changes to the worktree settings
1801async fn update_worktree_settings(
1802    message: proto::UpdateWorktreeSettings,
1803    session: Session,
1804) -> Result<()> {
1805    let guest_connection_ids = session
1806        .db()
1807        .await
1808        .update_worktree_settings(&message, session.connection_id)
1809        .await?;
1810
1811    broadcast(
1812        Some(session.connection_id),
1813        guest_connection_ids.iter().copied(),
1814        |connection_id| {
1815            session
1816                .peer
1817                .forward_send(session.connection_id, connection_id, message.clone())
1818        },
1819    );
1820
1821    Ok(())
1822}
1823
1824/// Notify other participants that a  language server has started.
1825async fn start_language_server(
1826    request: proto::StartLanguageServer,
1827    session: Session,
1828) -> Result<()> {
1829    let guest_connection_ids = session
1830        .db()
1831        .await
1832        .start_language_server(&request, session.connection_id)
1833        .await?;
1834
1835    broadcast(
1836        Some(session.connection_id),
1837        guest_connection_ids.iter().copied(),
1838        |connection_id| {
1839            session
1840                .peer
1841                .forward_send(session.connection_id, connection_id, request.clone())
1842        },
1843    );
1844    Ok(())
1845}
1846
1847/// Notify other participants that a language server has changed.
1848async fn update_language_server(
1849    request: proto::UpdateLanguageServer,
1850    session: Session,
1851) -> Result<()> {
1852    let project_id = ProjectId::from_proto(request.project_id);
1853    let project_connection_ids = session
1854        .db()
1855        .await
1856        .project_connection_ids(project_id, session.connection_id)
1857        .await?;
1858    broadcast(
1859        Some(session.connection_id),
1860        project_connection_ids.iter().copied(),
1861        |connection_id| {
1862            session
1863                .peer
1864                .forward_send(session.connection_id, connection_id, request.clone())
1865        },
1866    );
1867    Ok(())
1868}
1869
1870/// forward a project request to the host. These requests should be read only
1871/// as guests are allowed to send them.
1872async fn forward_read_only_project_request<T>(
1873    request: T,
1874    response: Response<T>,
1875    session: Session,
1876) -> Result<()>
1877where
1878    T: EntityMessage + RequestMessage,
1879{
1880    let project_id = ProjectId::from_proto(request.remote_entity_id());
1881    let host_connection_id = session
1882        .db()
1883        .await
1884        .host_for_read_only_project_request(project_id, session.connection_id)
1885        .await?;
1886    let payload = session
1887        .peer
1888        .forward_request(session.connection_id, host_connection_id, request)
1889        .await?;
1890    response.send(payload)?;
1891    Ok(())
1892}
1893
1894/// forward a project request to the host. These requests are disallowed
1895/// for guests.
1896async fn forward_mutating_project_request<T>(
1897    request: T,
1898    response: Response<T>,
1899    session: Session,
1900) -> Result<()>
1901where
1902    T: EntityMessage + RequestMessage,
1903{
1904    let project_id = ProjectId::from_proto(request.remote_entity_id());
1905    let host_connection_id = session
1906        .db()
1907        .await
1908        .host_for_mutating_project_request(project_id, session.connection_id)
1909        .await?;
1910    let payload = session
1911        .peer
1912        .forward_request(session.connection_id, host_connection_id, request)
1913        .await?;
1914    response.send(payload)?;
1915    Ok(())
1916}
1917
1918/// Notify other participants that a new buffer has been created
1919async fn create_buffer_for_peer(
1920    request: proto::CreateBufferForPeer,
1921    session: Session,
1922) -> Result<()> {
1923    session
1924        .db()
1925        .await
1926        .check_user_is_project_host(
1927            ProjectId::from_proto(request.project_id),
1928            session.connection_id,
1929        )
1930        .await?;
1931    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1932    session
1933        .peer
1934        .forward_send(session.connection_id, peer_id.into(), request)?;
1935    Ok(())
1936}
1937
1938/// Notify other participants that a buffer has been updated. This is
1939/// allowed for guests as long as the update is limited to selections.
1940async fn update_buffer(
1941    request: proto::UpdateBuffer,
1942    response: Response<proto::UpdateBuffer>,
1943    session: Session,
1944) -> Result<()> {
1945    let project_id = ProjectId::from_proto(request.project_id);
1946    let mut guest_connection_ids;
1947    let mut host_connection_id = None;
1948
1949    let mut requires_write_permission = false;
1950
1951    for op in request.operations.iter() {
1952        match op.variant {
1953            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1954            Some(_) => requires_write_permission = true,
1955        }
1956    }
1957
1958    {
1959        let collaborators = session
1960            .db()
1961            .await
1962            .project_collaborators_for_buffer_update(
1963                project_id,
1964                session.connection_id,
1965                requires_write_permission,
1966            )
1967            .await?;
1968        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1969        for collaborator in collaborators.iter() {
1970            if collaborator.is_host {
1971                host_connection_id = Some(collaborator.connection_id);
1972            } else {
1973                guest_connection_ids.push(collaborator.connection_id);
1974            }
1975        }
1976    }
1977    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1978
1979    broadcast(
1980        Some(session.connection_id),
1981        guest_connection_ids,
1982        |connection_id| {
1983            session
1984                .peer
1985                .forward_send(session.connection_id, connection_id, request.clone())
1986        },
1987    );
1988    if host_connection_id != session.connection_id {
1989        session
1990            .peer
1991            .forward_request(session.connection_id, host_connection_id, request.clone())
1992            .await?;
1993    }
1994
1995    response.send(proto::Ack {})?;
1996    Ok(())
1997}
1998
1999/// Notify other participants that a project has been updated.
2000async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2001    request: T,
2002    session: Session,
2003) -> Result<()> {
2004    let project_id = ProjectId::from_proto(request.remote_entity_id());
2005    let project_connection_ids = session
2006        .db()
2007        .await
2008        .project_connection_ids(project_id, session.connection_id)
2009        .await?;
2010
2011    broadcast(
2012        Some(session.connection_id),
2013        project_connection_ids.iter().copied(),
2014        |connection_id| {
2015            session
2016                .peer
2017                .forward_send(session.connection_id, connection_id, request.clone())
2018        },
2019    );
2020    Ok(())
2021}
2022
2023/// Start following another user in a call.
2024async fn follow(
2025    request: proto::Follow,
2026    response: Response<proto::Follow>,
2027    session: Session,
2028) -> Result<()> {
2029    let room_id = RoomId::from_proto(request.room_id);
2030    let project_id = request.project_id.map(ProjectId::from_proto);
2031    let leader_id = request
2032        .leader_id
2033        .ok_or_else(|| anyhow!("invalid leader id"))?
2034        .into();
2035    let follower_id = session.connection_id;
2036
2037    session
2038        .db()
2039        .await
2040        .check_room_participants(room_id, leader_id, session.connection_id)
2041        .await?;
2042
2043    let response_payload = session
2044        .peer
2045        .forward_request(session.connection_id, leader_id, request)
2046        .await?;
2047    response.send(response_payload)?;
2048
2049    if let Some(project_id) = project_id {
2050        let room = session
2051            .db()
2052            .await
2053            .follow(room_id, project_id, leader_id, follower_id)
2054            .await?;
2055        room_updated(&room, &session.peer);
2056    }
2057
2058    Ok(())
2059}
2060
2061/// Stop following another user in a call.
2062async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2063    let room_id = RoomId::from_proto(request.room_id);
2064    let project_id = request.project_id.map(ProjectId::from_proto);
2065    let leader_id = request
2066        .leader_id
2067        .ok_or_else(|| anyhow!("invalid leader id"))?
2068        .into();
2069    let follower_id = session.connection_id;
2070
2071    session
2072        .db()
2073        .await
2074        .check_room_participants(room_id, leader_id, session.connection_id)
2075        .await?;
2076
2077    session
2078        .peer
2079        .forward_send(session.connection_id, leader_id, request)?;
2080
2081    if let Some(project_id) = project_id {
2082        let room = session
2083            .db()
2084            .await
2085            .unfollow(room_id, project_id, leader_id, follower_id)
2086            .await?;
2087        room_updated(&room, &session.peer);
2088    }
2089
2090    Ok(())
2091}
2092
2093/// Notify everyone following you of your current location.
2094async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2095    let room_id = RoomId::from_proto(request.room_id);
2096    let database = session.db.lock().await;
2097
2098    let connection_ids = if let Some(project_id) = request.project_id {
2099        let project_id = ProjectId::from_proto(project_id);
2100        database
2101            .project_connection_ids(project_id, session.connection_id)
2102            .await?
2103    } else {
2104        database
2105            .room_connection_ids(room_id, session.connection_id)
2106            .await?
2107    };
2108
2109    // For now, don't send view update messages back to that view's current leader.
2110    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2111        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2112        _ => None,
2113    });
2114
2115    for connection_id in connection_ids.iter().cloned() {
2116        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2117            session
2118                .peer
2119                .forward_send(session.connection_id, connection_id, request.clone())?;
2120        }
2121    }
2122    Ok(())
2123}
2124
2125/// Get public data about users.
2126async fn get_users(
2127    request: proto::GetUsers,
2128    response: Response<proto::GetUsers>,
2129    session: Session,
2130) -> Result<()> {
2131    let user_ids = request
2132        .user_ids
2133        .into_iter()
2134        .map(UserId::from_proto)
2135        .collect();
2136    let users = session
2137        .db()
2138        .await
2139        .get_users_by_ids(user_ids)
2140        .await?
2141        .into_iter()
2142        .map(|user| proto::User {
2143            id: user.id.to_proto(),
2144            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2145            github_login: user.github_login,
2146        })
2147        .collect();
2148    response.send(proto::UsersResponse { users })?;
2149    Ok(())
2150}
2151
2152/// Search for users (to invite) buy Github login
2153async fn fuzzy_search_users(
2154    request: proto::FuzzySearchUsers,
2155    response: Response<proto::FuzzySearchUsers>,
2156    session: Session,
2157) -> Result<()> {
2158    let query = request.query;
2159    let users = match query.len() {
2160        0 => vec![],
2161        1 | 2 => session
2162            .db()
2163            .await
2164            .get_user_by_github_login(&query)
2165            .await?
2166            .into_iter()
2167            .collect(),
2168        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2169    };
2170    let users = users
2171        .into_iter()
2172        .filter(|user| user.id != session.user_id)
2173        .map(|user| proto::User {
2174            id: user.id.to_proto(),
2175            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2176            github_login: user.github_login,
2177        })
2178        .collect();
2179    response.send(proto::UsersResponse { users })?;
2180    Ok(())
2181}
2182
2183/// Send a contact request to another user.
2184async fn request_contact(
2185    request: proto::RequestContact,
2186    response: Response<proto::RequestContact>,
2187    session: Session,
2188) -> Result<()> {
2189    let requester_id = session.user_id;
2190    let responder_id = UserId::from_proto(request.responder_id);
2191    if requester_id == responder_id {
2192        return Err(anyhow!("cannot add yourself as a contact"))?;
2193    }
2194
2195    let notifications = session
2196        .db()
2197        .await
2198        .send_contact_request(requester_id, responder_id)
2199        .await?;
2200
2201    // Update outgoing contact requests of requester
2202    let mut update = proto::UpdateContacts::default();
2203    update.outgoing_requests.push(responder_id.to_proto());
2204    for connection_id in session
2205        .connection_pool()
2206        .await
2207        .user_connection_ids(requester_id)
2208    {
2209        session.peer.send(connection_id, update.clone())?;
2210    }
2211
2212    // Update incoming contact requests of responder
2213    let mut update = proto::UpdateContacts::default();
2214    update
2215        .incoming_requests
2216        .push(proto::IncomingContactRequest {
2217            requester_id: requester_id.to_proto(),
2218        });
2219    let connection_pool = session.connection_pool().await;
2220    for connection_id in connection_pool.user_connection_ids(responder_id) {
2221        session.peer.send(connection_id, update.clone())?;
2222    }
2223
2224    send_notifications(&*connection_pool, &session.peer, notifications);
2225
2226    response.send(proto::Ack {})?;
2227    Ok(())
2228}
2229
2230/// Accept or decline a contact request
2231async fn respond_to_contact_request(
2232    request: proto::RespondToContactRequest,
2233    response: Response<proto::RespondToContactRequest>,
2234    session: Session,
2235) -> Result<()> {
2236    let responder_id = session.user_id;
2237    let requester_id = UserId::from_proto(request.requester_id);
2238    let db = session.db().await;
2239    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2240        db.dismiss_contact_notification(responder_id, requester_id)
2241            .await?;
2242    } else {
2243        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2244
2245        let notifications = db
2246            .respond_to_contact_request(responder_id, requester_id, accept)
2247            .await?;
2248        let requester_busy = db.is_user_busy(requester_id).await?;
2249        let responder_busy = db.is_user_busy(responder_id).await?;
2250
2251        let pool = session.connection_pool().await;
2252        // Update responder with new contact
2253        let mut update = proto::UpdateContacts::default();
2254        if accept {
2255            update
2256                .contacts
2257                .push(contact_for_user(requester_id, requester_busy, &pool));
2258        }
2259        update
2260            .remove_incoming_requests
2261            .push(requester_id.to_proto());
2262        for connection_id in pool.user_connection_ids(responder_id) {
2263            session.peer.send(connection_id, update.clone())?;
2264        }
2265
2266        // Update requester with new contact
2267        let mut update = proto::UpdateContacts::default();
2268        if accept {
2269            update
2270                .contacts
2271                .push(contact_for_user(responder_id, responder_busy, &pool));
2272        }
2273        update
2274            .remove_outgoing_requests
2275            .push(responder_id.to_proto());
2276
2277        for connection_id in pool.user_connection_ids(requester_id) {
2278            session.peer.send(connection_id, update.clone())?;
2279        }
2280
2281        send_notifications(&*pool, &session.peer, notifications);
2282    }
2283
2284    response.send(proto::Ack {})?;
2285    Ok(())
2286}
2287
2288/// Remove a contact.
2289async fn remove_contact(
2290    request: proto::RemoveContact,
2291    response: Response<proto::RemoveContact>,
2292    session: Session,
2293) -> Result<()> {
2294    let requester_id = session.user_id;
2295    let responder_id = UserId::from_proto(request.user_id);
2296    let db = session.db().await;
2297    let (contact_accepted, deleted_notification_id) =
2298        db.remove_contact(requester_id, responder_id).await?;
2299
2300    let pool = session.connection_pool().await;
2301    // Update outgoing contact requests of requester
2302    let mut update = proto::UpdateContacts::default();
2303    if contact_accepted {
2304        update.remove_contacts.push(responder_id.to_proto());
2305    } else {
2306        update
2307            .remove_outgoing_requests
2308            .push(responder_id.to_proto());
2309    }
2310    for connection_id in pool.user_connection_ids(requester_id) {
2311        session.peer.send(connection_id, update.clone())?;
2312    }
2313
2314    // Update incoming contact requests of responder
2315    let mut update = proto::UpdateContacts::default();
2316    if contact_accepted {
2317        update.remove_contacts.push(requester_id.to_proto());
2318    } else {
2319        update
2320            .remove_incoming_requests
2321            .push(requester_id.to_proto());
2322    }
2323    for connection_id in pool.user_connection_ids(responder_id) {
2324        session.peer.send(connection_id, update.clone())?;
2325        if let Some(notification_id) = deleted_notification_id {
2326            session.peer.send(
2327                connection_id,
2328                proto::DeleteNotification {
2329                    notification_id: notification_id.to_proto(),
2330                },
2331            )?;
2332        }
2333    }
2334
2335    response.send(proto::Ack {})?;
2336    Ok(())
2337}
2338
2339/// Creates a new channel.
2340async fn create_channel(
2341    request: proto::CreateChannel,
2342    response: Response<proto::CreateChannel>,
2343    session: Session,
2344) -> Result<()> {
2345    let db = session.db().await;
2346
2347    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2348    let (channel, owner, channel_members) = db
2349        .create_channel(&request.name, parent_id, session.user_id)
2350        .await?;
2351
2352    response.send(proto::CreateChannelResponse {
2353        channel: Some(channel.to_proto()),
2354        parent_id: request.parent_id,
2355    })?;
2356
2357    let connection_pool = session.connection_pool().await;
2358    if let Some(owner) = owner {
2359        let update = proto::UpdateUserChannels {
2360            channel_memberships: vec![proto::ChannelMembership {
2361                channel_id: owner.channel_id.to_proto(),
2362                role: owner.role.into(),
2363            }],
2364            ..Default::default()
2365        };
2366        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2367            session.peer.send(connection_id, update.clone())?;
2368        }
2369    }
2370
2371    for channel_member in channel_members {
2372        if !channel_member.role.can_see_channel(channel.visibility) {
2373            continue;
2374        }
2375
2376        let update = proto::UpdateChannels {
2377            channels: vec![channel.to_proto()],
2378            ..Default::default()
2379        };
2380        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2381            session.peer.send(connection_id, update.clone())?;
2382        }
2383    }
2384
2385    Ok(())
2386}
2387
2388/// Delete a channel
2389async fn delete_channel(
2390    request: proto::DeleteChannel,
2391    response: Response<proto::DeleteChannel>,
2392    session: Session,
2393) -> Result<()> {
2394    let db = session.db().await;
2395
2396    let channel_id = request.channel_id;
2397    let (removed_channels, member_ids) = db
2398        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2399        .await?;
2400    response.send(proto::Ack {})?;
2401
2402    // Notify members of removed channels
2403    let mut update = proto::UpdateChannels::default();
2404    update
2405        .delete_channels
2406        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2407
2408    let connection_pool = session.connection_pool().await;
2409    for member_id in member_ids {
2410        for connection_id in connection_pool.user_connection_ids(member_id) {
2411            session.peer.send(connection_id, update.clone())?;
2412        }
2413    }
2414
2415    Ok(())
2416}
2417
2418/// Invite someone to join a channel.
2419async fn invite_channel_member(
2420    request: proto::InviteChannelMember,
2421    response: Response<proto::InviteChannelMember>,
2422    session: Session,
2423) -> Result<()> {
2424    let db = session.db().await;
2425    let channel_id = ChannelId::from_proto(request.channel_id);
2426    let invitee_id = UserId::from_proto(request.user_id);
2427    let InviteMemberResult {
2428        channel,
2429        notifications,
2430    } = db
2431        .invite_channel_member(
2432            channel_id,
2433            invitee_id,
2434            session.user_id,
2435            request.role().into(),
2436        )
2437        .await?;
2438
2439    let update = proto::UpdateChannels {
2440        channel_invitations: vec![channel.to_proto()],
2441        ..Default::default()
2442    };
2443
2444    let connection_pool = session.connection_pool().await;
2445    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2446        session.peer.send(connection_id, update.clone())?;
2447    }
2448
2449    send_notifications(&*connection_pool, &session.peer, notifications);
2450
2451    response.send(proto::Ack {})?;
2452    Ok(())
2453}
2454
2455/// remove someone from a channel
2456async fn remove_channel_member(
2457    request: proto::RemoveChannelMember,
2458    response: Response<proto::RemoveChannelMember>,
2459    session: Session,
2460) -> Result<()> {
2461    let db = session.db().await;
2462    let channel_id = ChannelId::from_proto(request.channel_id);
2463    let member_id = UserId::from_proto(request.user_id);
2464
2465    let RemoveChannelMemberResult {
2466        membership_update,
2467        notification_id,
2468    } = db
2469        .remove_channel_member(channel_id, member_id, session.user_id)
2470        .await?;
2471
2472    let connection_pool = &session.connection_pool().await;
2473    notify_membership_updated(
2474        &connection_pool,
2475        membership_update,
2476        member_id,
2477        &session.peer,
2478    );
2479    for connection_id in connection_pool.user_connection_ids(member_id) {
2480        if let Some(notification_id) = notification_id {
2481            session
2482                .peer
2483                .send(
2484                    connection_id,
2485                    proto::DeleteNotification {
2486                        notification_id: notification_id.to_proto(),
2487                    },
2488                )
2489                .trace_err();
2490        }
2491    }
2492
2493    response.send(proto::Ack {})?;
2494    Ok(())
2495}
2496
2497/// Toggle the channel between public and private.
2498/// Care is taken to maintain the invariant that public channels only descend from public channels,
2499/// (though members-only channels can appear at any point in the hierarchy).
2500async fn set_channel_visibility(
2501    request: proto::SetChannelVisibility,
2502    response: Response<proto::SetChannelVisibility>,
2503    session: Session,
2504) -> Result<()> {
2505    let db = session.db().await;
2506    let channel_id = ChannelId::from_proto(request.channel_id);
2507    let visibility = request.visibility().into();
2508
2509    let (channel, channel_members) = db
2510        .set_channel_visibility(channel_id, visibility, session.user_id)
2511        .await?;
2512
2513    let connection_pool = session.connection_pool().await;
2514    for member in channel_members {
2515        let update = if member.role.can_see_channel(channel.visibility) {
2516            proto::UpdateChannels {
2517                channels: vec![channel.to_proto()],
2518                ..Default::default()
2519            }
2520        } else {
2521            proto::UpdateChannels {
2522                delete_channels: vec![channel.id.to_proto()],
2523                ..Default::default()
2524            }
2525        };
2526
2527        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2528            session.peer.send(connection_id, update.clone())?;
2529        }
2530    }
2531
2532    response.send(proto::Ack {})?;
2533    Ok(())
2534}
2535
2536/// Alter the role for a user in the channel.
2537async fn set_channel_member_role(
2538    request: proto::SetChannelMemberRole,
2539    response: Response<proto::SetChannelMemberRole>,
2540    session: Session,
2541) -> Result<()> {
2542    let db = session.db().await;
2543    let channel_id = ChannelId::from_proto(request.channel_id);
2544    let member_id = UserId::from_proto(request.user_id);
2545    let result = db
2546        .set_channel_member_role(
2547            channel_id,
2548            session.user_id,
2549            member_id,
2550            request.role().into(),
2551        )
2552        .await?;
2553
2554    match result {
2555        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2556            let connection_pool = session.connection_pool().await;
2557            notify_membership_updated(
2558                &connection_pool,
2559                membership_update,
2560                member_id,
2561                &session.peer,
2562            )
2563        }
2564        db::SetMemberRoleResult::InviteUpdated(channel) => {
2565            let update = proto::UpdateChannels {
2566                channel_invitations: vec![channel.to_proto()],
2567                ..Default::default()
2568            };
2569
2570            for connection_id in session
2571                .connection_pool()
2572                .await
2573                .user_connection_ids(member_id)
2574            {
2575                session.peer.send(connection_id, update.clone())?;
2576            }
2577        }
2578    }
2579
2580    response.send(proto::Ack {})?;
2581    Ok(())
2582}
2583
2584/// Change the name of a channel
2585async fn rename_channel(
2586    request: proto::RenameChannel,
2587    response: Response<proto::RenameChannel>,
2588    session: Session,
2589) -> Result<()> {
2590    let db = session.db().await;
2591    let channel_id = ChannelId::from_proto(request.channel_id);
2592    let (channel, channel_members) = db
2593        .rename_channel(channel_id, session.user_id, &request.name)
2594        .await?;
2595
2596    response.send(proto::RenameChannelResponse {
2597        channel: Some(channel.to_proto()),
2598    })?;
2599
2600    let connection_pool = session.connection_pool().await;
2601    for channel_member in channel_members {
2602        if !channel_member.role.can_see_channel(channel.visibility) {
2603            continue;
2604        }
2605        let update = proto::UpdateChannels {
2606            channels: vec![channel.to_proto()],
2607            ..Default::default()
2608        };
2609        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2610            session.peer.send(connection_id, update.clone())?;
2611        }
2612    }
2613
2614    Ok(())
2615}
2616
2617/// Move a channel to a new parent.
2618async fn move_channel(
2619    request: proto::MoveChannel,
2620    response: Response<proto::MoveChannel>,
2621    session: Session,
2622) -> Result<()> {
2623    let channel_id = ChannelId::from_proto(request.channel_id);
2624    let to = ChannelId::from_proto(request.to);
2625
2626    let (channels, channel_members) = session
2627        .db()
2628        .await
2629        .move_channel(channel_id, to, session.user_id)
2630        .await?;
2631
2632    let connection_pool = session.connection_pool().await;
2633    for member in channel_members {
2634        let channels = channels
2635            .iter()
2636            .filter_map(|channel| {
2637                if member.role.can_see_channel(channel.visibility) {
2638                    Some(channel.to_proto())
2639                } else {
2640                    None
2641                }
2642            })
2643            .collect::<Vec<_>>();
2644        if channels.is_empty() {
2645            continue;
2646        }
2647
2648        let update = proto::UpdateChannels {
2649            channels,
2650            ..Default::default()
2651        };
2652
2653        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2654            session.peer.send(connection_id, update.clone())?;
2655        }
2656    }
2657
2658    response.send(Ack {})?;
2659    Ok(())
2660}
2661
2662/// Get the list of channel members
2663async fn get_channel_members(
2664    request: proto::GetChannelMembers,
2665    response: Response<proto::GetChannelMembers>,
2666    session: Session,
2667) -> Result<()> {
2668    let db = session.db().await;
2669    let channel_id = ChannelId::from_proto(request.channel_id);
2670    let members = db
2671        .get_channel_participant_details(channel_id, session.user_id)
2672        .await?;
2673    response.send(proto::GetChannelMembersResponse { members })?;
2674    Ok(())
2675}
2676
2677/// Accept or decline a channel invitation.
2678async fn respond_to_channel_invite(
2679    request: proto::RespondToChannelInvite,
2680    response: Response<proto::RespondToChannelInvite>,
2681    session: Session,
2682) -> Result<()> {
2683    let db = session.db().await;
2684    let channel_id = ChannelId::from_proto(request.channel_id);
2685    let RespondToChannelInvite {
2686        membership_update,
2687        notifications,
2688    } = db
2689        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2690        .await?;
2691
2692    let connection_pool = session.connection_pool().await;
2693    if let Some(membership_update) = membership_update {
2694        notify_membership_updated(
2695            &connection_pool,
2696            membership_update,
2697            session.user_id,
2698            &session.peer,
2699        );
2700    } else {
2701        let update = proto::UpdateChannels {
2702            remove_channel_invitations: vec![channel_id.to_proto()],
2703            ..Default::default()
2704        };
2705
2706        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2707            session.peer.send(connection_id, update.clone())?;
2708        }
2709    };
2710
2711    send_notifications(&*connection_pool, &session.peer, notifications);
2712
2713    response.send(proto::Ack {})?;
2714
2715    Ok(())
2716}
2717
2718/// Join the channels' room
2719async fn join_channel(
2720    request: proto::JoinChannel,
2721    response: Response<proto::JoinChannel>,
2722    session: Session,
2723) -> Result<()> {
2724    let channel_id = ChannelId::from_proto(request.channel_id);
2725    join_channel_internal(channel_id, Box::new(response), session).await
2726}
2727
/// Abstracts over the response handles that can complete a channel join, so
/// `join_channel_internal` can reply with the same `proto::JoinRoomResponse`
/// payload regardless of which request started the join. It is implemented for
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls delegate through the fully-qualified `Response::<...>::send` path.
// NOTE(review): this presumably resolves to `Response`'s own inherent `send`,
// not this trait's method; keep the explicit path so the call cannot resolve
// back to the trait impl and recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2741
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. Leaves whatever room the session is currently in.
/// 2. Joins the channel through the database.
/// 3. Replies with the room, plus LiveKit credentials when a LiveKit client is
///    configured.
/// 4. Notifies the user of any membership change and broadcasts the room update.
/// 5. Pushes the room's participant list to the channel members.
/// 6. Refreshes the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` and `connection_pool` guards so both
    // are dropped before `session.connection_pool()` is called again below and
    // before `update_user_contacts` runs.
    // NOTE(review): presumably these are lock guards that must not be held
    // while re-acquiring; keep this statement order when editing.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a
        // room token with publishing allowed. If token generation fails, the
        // error is logged by `trace_err` and the `?` makes this closure yield
        // `None`, so the reply carries no LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before any broadcasts go out.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Broadcast the new room state to the room's participants.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Send the room's participant list to the channel's members. The pool guard
    // here is a temporary and is released at the end of this statement.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    // NOTE(review): presumably refreshes contacts so they see this user's new
    // room/busy status — confirm against `update_user_contacts`.
    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2817
2818/// Start editing the channel notes
2819async fn join_channel_buffer(
2820    request: proto::JoinChannelBuffer,
2821    response: Response<proto::JoinChannelBuffer>,
2822    session: Session,
2823) -> Result<()> {
2824    let db = session.db().await;
2825    let channel_id = ChannelId::from_proto(request.channel_id);
2826
2827    let open_response = db
2828        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2829        .await?;
2830
2831    let collaborators = open_response.collaborators.clone();
2832    response.send(open_response)?;
2833
2834    let update = UpdateChannelBufferCollaborators {
2835        channel_id: channel_id.to_proto(),
2836        collaborators: collaborators.clone(),
2837    };
2838    channel_buffer_updated(
2839        session.connection_id,
2840        collaborators
2841            .iter()
2842            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2843        &update,
2844        &session.peer,
2845    );
2846
2847    Ok(())
2848}
2849
2850/// Edit the channel notes
2851async fn update_channel_buffer(
2852    request: proto::UpdateChannelBuffer,
2853    session: Session,
2854) -> Result<()> {
2855    let db = session.db().await;
2856    let channel_id = ChannelId::from_proto(request.channel_id);
2857
2858    let (collaborators, non_collaborators, epoch, version) = db
2859        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2860        .await?;
2861
2862    channel_buffer_updated(
2863        session.connection_id,
2864        collaborators,
2865        &proto::UpdateChannelBuffer {
2866            channel_id: channel_id.to_proto(),
2867            operations: request.operations,
2868        },
2869        &session.peer,
2870    );
2871
2872    let pool = &*session.connection_pool().await;
2873
2874    broadcast(
2875        None,
2876        non_collaborators
2877            .iter()
2878            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2879        |peer_id| {
2880            session.peer.send(
2881                peer_id.into(),
2882                proto::UpdateChannels {
2883                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2884                        channel_id: channel_id.to_proto(),
2885                        epoch: epoch as u64,
2886                        version: version.clone(),
2887                    }],
2888                    ..Default::default()
2889                },
2890            )
2891        },
2892    );
2893
2894    Ok(())
2895}
2896
2897/// Rejoin the channel notes after a connection blip
2898async fn rejoin_channel_buffers(
2899    request: proto::RejoinChannelBuffers,
2900    response: Response<proto::RejoinChannelBuffers>,
2901    session: Session,
2902) -> Result<()> {
2903    let db = session.db().await;
2904    let buffers = db
2905        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2906        .await?;
2907
2908    for rejoined_buffer in &buffers {
2909        let collaborators_to_notify = rejoined_buffer
2910            .buffer
2911            .collaborators
2912            .iter()
2913            .filter_map(|c| Some(c.peer_id?.into()));
2914        channel_buffer_updated(
2915            session.connection_id,
2916            collaborators_to_notify,
2917            &proto::UpdateChannelBufferCollaborators {
2918                channel_id: rejoined_buffer.buffer.channel_id,
2919                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2920            },
2921            &session.peer,
2922        );
2923    }
2924
2925    response.send(proto::RejoinChannelBuffersResponse {
2926        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2927    })?;
2928
2929    Ok(())
2930}
2931
2932/// Stop editing the channel notes
2933async fn leave_channel_buffer(
2934    request: proto::LeaveChannelBuffer,
2935    response: Response<proto::LeaveChannelBuffer>,
2936    session: Session,
2937) -> Result<()> {
2938    let db = session.db().await;
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940
2941    let left_buffer = db
2942        .leave_channel_buffer(channel_id, session.connection_id)
2943        .await?;
2944
2945    response.send(Ack {})?;
2946
2947    channel_buffer_updated(
2948        session.connection_id,
2949        left_buffer.connections,
2950        &proto::UpdateChannelBufferCollaborators {
2951            channel_id: channel_id.to_proto(),
2952            collaborators: left_buffer.collaborators,
2953        },
2954        &session.peer,
2955    );
2956
2957    Ok(())
2958}
2959
2960fn channel_buffer_updated<T: EnvelopedMessage>(
2961    sender_id: ConnectionId,
2962    collaborators: impl IntoIterator<Item = ConnectionId>,
2963    message: &T,
2964    peer: &Peer,
2965) {
2966    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2967        peer.send(peer_id.into(), message.clone())
2968    });
2969}
2970
2971fn send_notifications(
2972    connection_pool: &ConnectionPool,
2973    peer: &Peer,
2974    notifications: db::NotificationBatch,
2975) {
2976    for (user_id, notification) in notifications {
2977        for connection_id in connection_pool.user_connection_ids(user_id) {
2978            if let Err(error) = peer.send(
2979                connection_id,
2980                proto::AddNotification {
2981                    notification: Some(notification.clone()),
2982                },
2983            ) {
2984                tracing::error!(
2985                    "failed to send notification to {:?} {}",
2986                    connection_id,
2987                    error
2988                );
2989            }
2990        }
2991    }
2992}
2993
2994/// Send a message to the channel
2995async fn send_channel_message(
2996    request: proto::SendChannelMessage,
2997    response: Response<proto::SendChannelMessage>,
2998    session: Session,
2999) -> Result<()> {
3000    // Validate the message body.
3001    let body = request.body.trim().to_string();
3002    if body.len() > MAX_MESSAGE_LEN {
3003        return Err(anyhow!("message is too long"))?;
3004    }
3005    if body.is_empty() {
3006        return Err(anyhow!("message can't be blank"))?;
3007    }
3008
3009    // TODO: adjust mentions if body is trimmed
3010
3011    let timestamp = OffsetDateTime::now_utc();
3012    let nonce = request
3013        .nonce
3014        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3015
3016    let channel_id = ChannelId::from_proto(request.channel_id);
3017    let CreatedChannelMessage {
3018        message_id,
3019        participant_connection_ids,
3020        channel_members,
3021        notifications,
3022    } = session
3023        .db()
3024        .await
3025        .create_channel_message(
3026            channel_id,
3027            session.user_id,
3028            &body,
3029            &request.mentions,
3030            timestamp,
3031            nonce.clone().into(),
3032            match request.reply_to_message_id {
3033                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3034                None => None,
3035            },
3036        )
3037        .await?;
3038    let message = proto::ChannelMessage {
3039        sender_id: session.user_id.to_proto(),
3040        id: message_id.to_proto(),
3041        body,
3042        mentions: request.mentions,
3043        timestamp: timestamp.unix_timestamp() as u64,
3044        nonce: Some(nonce),
3045        reply_to_message_id: request.reply_to_message_id,
3046    };
3047    broadcast(
3048        Some(session.connection_id),
3049        participant_connection_ids,
3050        |connection| {
3051            session.peer.send(
3052                connection,
3053                proto::ChannelMessageSent {
3054                    channel_id: channel_id.to_proto(),
3055                    message: Some(message.clone()),
3056                },
3057            )
3058        },
3059    );
3060    response.send(proto::SendChannelMessageResponse {
3061        message: Some(message),
3062    })?;
3063
3064    let pool = &*session.connection_pool().await;
3065    broadcast(
3066        None,
3067        channel_members
3068            .iter()
3069            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3070        |peer_id| {
3071            session.peer.send(
3072                peer_id.into(),
3073                proto::UpdateChannels {
3074                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3075                        channel_id: channel_id.to_proto(),
3076                        message_id: message_id.to_proto(),
3077                    }],
3078                    ..Default::default()
3079                },
3080            )
3081        },
3082    );
3083    send_notifications(pool, &session.peer, notifications);
3084
3085    Ok(())
3086}
3087
3088/// Delete a channel message
3089async fn remove_channel_message(
3090    request: proto::RemoveChannelMessage,
3091    response: Response<proto::RemoveChannelMessage>,
3092    session: Session,
3093) -> Result<()> {
3094    let channel_id = ChannelId::from_proto(request.channel_id);
3095    let message_id = MessageId::from_proto(request.message_id);
3096    let connection_ids = session
3097        .db()
3098        .await
3099        .remove_channel_message(channel_id, message_id, session.user_id)
3100        .await?;
3101    broadcast(Some(session.connection_id), connection_ids, |connection| {
3102        session.peer.send(connection, request.clone())
3103    });
3104    response.send(proto::Ack {})?;
3105    Ok(())
3106}
3107
3108/// Mark a channel message as read
3109async fn acknowledge_channel_message(
3110    request: proto::AckChannelMessage,
3111    session: Session,
3112) -> Result<()> {
3113    let channel_id = ChannelId::from_proto(request.channel_id);
3114    let message_id = MessageId::from_proto(request.message_id);
3115    let notifications = session
3116        .db()
3117        .await
3118        .observe_channel_message(channel_id, session.user_id, message_id)
3119        .await?;
3120    send_notifications(
3121        &*session.connection_pool().await,
3122        &session.peer,
3123        notifications,
3124    );
3125    Ok(())
3126}
3127
3128/// Mark a buffer version as synced
3129async fn acknowledge_buffer_version(
3130    request: proto::AckBufferOperation,
3131    session: Session,
3132) -> Result<()> {
3133    let buffer_id = BufferId::from_proto(request.buffer_id);
3134    session
3135        .db()
3136        .await
3137        .observe_buffer_version(
3138            buffer_id,
3139            session.user_id,
3140            request.epoch as i32,
3141            &request.version,
3142        )
3143        .await?;
3144    Ok(())
3145}
3146
3147/// Start receiving chat updates for a channel
3148async fn join_channel_chat(
3149    request: proto::JoinChannelChat,
3150    response: Response<proto::JoinChannelChat>,
3151    session: Session,
3152) -> Result<()> {
3153    let channel_id = ChannelId::from_proto(request.channel_id);
3154
3155    let db = session.db().await;
3156    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3157        .await?;
3158    let messages = db
3159        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3160        .await?;
3161    response.send(proto::JoinChannelChatResponse {
3162        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3163        messages,
3164    })?;
3165    Ok(())
3166}
3167
3168/// Stop receiving chat updates for a channel
3169async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3170    let channel_id = ChannelId::from_proto(request.channel_id);
3171    session
3172        .db()
3173        .await
3174        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3175        .await?;
3176    Ok(())
3177}
3178
3179/// Retrieve the chat history for a channel
3180async fn get_channel_messages(
3181    request: proto::GetChannelMessages,
3182    response: Response<proto::GetChannelMessages>,
3183    session: Session,
3184) -> Result<()> {
3185    let channel_id = ChannelId::from_proto(request.channel_id);
3186    let messages = session
3187        .db()
3188        .await
3189        .get_channel_messages(
3190            channel_id,
3191            session.user_id,
3192            MESSAGE_COUNT_PER_PAGE,
3193            Some(MessageId::from_proto(request.before_message_id)),
3194        )
3195        .await?;
3196    response.send(proto::GetChannelMessagesResponse {
3197        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3198        messages,
3199    })?;
3200    Ok(())
3201}
3202
3203/// Retrieve specific chat messages
3204async fn get_channel_messages_by_id(
3205    request: proto::GetChannelMessagesById,
3206    response: Response<proto::GetChannelMessagesById>,
3207    session: Session,
3208) -> Result<()> {
3209    let message_ids = request
3210        .message_ids
3211        .iter()
3212        .map(|id| MessageId::from_proto(*id))
3213        .collect::<Vec<_>>();
3214    let messages = session
3215        .db()
3216        .await
3217        .get_channel_messages_by_id(session.user_id, &message_ids)
3218        .await?;
3219    response.send(proto::GetChannelMessagesResponse {
3220        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3221        messages,
3222    })?;
3223    Ok(())
3224}
3225
3226/// Retrieve the current users notifications
3227async fn get_notifications(
3228    request: proto::GetNotifications,
3229    response: Response<proto::GetNotifications>,
3230    session: Session,
3231) -> Result<()> {
3232    let notifications = session
3233        .db()
3234        .await
3235        .get_notifications(
3236            session.user_id,
3237            NOTIFICATION_COUNT_PER_PAGE,
3238            request
3239                .before_id
3240                .map(|id| db::NotificationId::from_proto(id)),
3241        )
3242        .await?;
3243    response.send(proto::GetNotificationsResponse {
3244        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3245        notifications,
3246    })?;
3247    Ok(())
3248}
3249
3250/// Mark notifications as read
3251async fn mark_notification_as_read(
3252    request: proto::MarkNotificationRead,
3253    response: Response<proto::MarkNotificationRead>,
3254    session: Session,
3255) -> Result<()> {
3256    let database = &session.db().await;
3257    let notifications = database
3258        .mark_notification_as_read_by_id(
3259            session.user_id,
3260            NotificationId::from_proto(request.notification_id),
3261        )
3262        .await?;
3263    send_notifications(
3264        &*session.connection_pool().await,
3265        &session.peer,
3266        notifications,
3267    );
3268    response.send(proto::Ack {})?;
3269    Ok(())
3270}
3271
3272/// Get the current users information
3273async fn get_private_user_info(
3274    _request: proto::GetPrivateUserInfo,
3275    response: Response<proto::GetPrivateUserInfo>,
3276    session: Session,
3277) -> Result<()> {
3278    let db = session.db().await;
3279
3280    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3281    let user = db
3282        .get_user_by_id(session.user_id)
3283        .await?
3284        .ok_or_else(|| anyhow!("user not found"))?;
3285    let flags = db.get_user_flags(session.user_id).await?;
3286
3287    response.send(proto::GetPrivateUserInfoResponse {
3288        metrics_id,
3289        staff: user.admin,
3290        flags,
3291    })?;
3292    Ok(())
3293}
3294
3295fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3296    match message {
3297        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3298        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3299        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3300        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3301        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3302            code: frame.code.into(),
3303            reason: frame.reason,
3304        })),
3305    }
3306}
3307
3308fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3309    match message {
3310        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3311        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3312        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3313        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3314        AxumMessage::Close(frame) => {
3315            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3316                code: frame.code.into(),
3317                reason: frame.reason,
3318            }))
3319        }
3320    }
3321}
3322
3323fn notify_membership_updated(
3324    connection_pool: &ConnectionPool,
3325    result: MembershipUpdated,
3326    user_id: UserId,
3327    peer: &Peer,
3328) {
3329    let user_channels_update = proto::UpdateUserChannels {
3330        channel_memberships: result
3331            .new_channels
3332            .channel_memberships
3333            .iter()
3334            .map(|cm| proto::ChannelMembership {
3335                channel_id: cm.channel_id.to_proto(),
3336                role: cm.role.into(),
3337            })
3338            .collect(),
3339        ..Default::default()
3340    };
3341    let mut update = build_channels_update(result.new_channels, vec![]);
3342    update.delete_channels = result
3343        .removed_channels
3344        .into_iter()
3345        .map(|id| id.to_proto())
3346        .collect();
3347    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3348
3349    for connection_id in connection_pool.user_connection_ids(user_id) {
3350        peer.send(connection_id, user_channels_update.clone())
3351            .trace_err();
3352        peer.send(connection_id, update.clone()).trace_err();
3353    }
3354}
3355
3356fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3357    proto::UpdateUserChannels {
3358        channel_memberships: channels
3359            .channel_memberships
3360            .iter()
3361            .map(|m| proto::ChannelMembership {
3362                channel_id: m.channel_id.to_proto(),
3363                role: m.role.into(),
3364            })
3365            .collect(),
3366        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3367        observed_channel_message_id: channels.observed_channel_messages.clone(),
3368        ..Default::default()
3369    }
3370}
3371
3372fn build_channels_update(
3373    channels: ChannelsForUser,
3374    channel_invites: Vec<db::Channel>,
3375) -> proto::UpdateChannels {
3376    let mut update = proto::UpdateChannels::default();
3377
3378    for channel in channels.channels {
3379        update.channels.push(channel.to_proto());
3380    }
3381
3382    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3383    update.latest_channel_message_ids = channels.latest_channel_messages;
3384
3385    for (channel_id, participants) in channels.channel_participants {
3386        update
3387            .channel_participants
3388            .push(proto::ChannelParticipants {
3389                channel_id: channel_id.to_proto(),
3390                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3391            });
3392    }
3393
3394    for channel in channel_invites {
3395        update.channel_invitations.push(channel.to_proto());
3396    }
3397    for project in channels.hosted_projects {
3398        update.hosted_projects.push(project);
3399    }
3400
3401    update
3402}
3403
3404fn build_initial_contacts_update(
3405    contacts: Vec<db::Contact>,
3406    pool: &ConnectionPool,
3407) -> proto::UpdateContacts {
3408    let mut update = proto::UpdateContacts::default();
3409
3410    for contact in contacts {
3411        match contact {
3412            db::Contact::Accepted { user_id, busy } => {
3413                update.contacts.push(contact_for_user(user_id, busy, &pool));
3414            }
3415            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3416            db::Contact::Incoming { user_id } => {
3417                update
3418                    .incoming_requests
3419                    .push(proto::IncomingContactRequest {
3420                        requester_id: user_id.to_proto(),
3421                    })
3422            }
3423        }
3424    }
3425
3426    update
3427}
3428
3429fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3430    proto::Contact {
3431        user_id: user_id.to_proto(),
3432        online: pool.is_user_online(user_id),
3433        busy,
3434    }
3435}
3436
3437fn room_updated(room: &proto::Room, peer: &Peer) {
3438    broadcast(
3439        None,
3440        room.participants
3441            .iter()
3442            .filter_map(|participant| Some(participant.peer_id?.into())),
3443        |peer_id| {
3444            peer.send(
3445                peer_id.into(),
3446                proto::RoomUpdated {
3447                    room: Some(room.clone()),
3448                },
3449            )
3450        },
3451    );
3452}
3453
3454fn channel_updated(
3455    channel_id: ChannelId,
3456    room: &proto::Room,
3457    channel_members: &[UserId],
3458    peer: &Peer,
3459    pool: &ConnectionPool,
3460) {
3461    let participants = room
3462        .participants
3463        .iter()
3464        .map(|p| p.user_id)
3465        .collect::<Vec<_>>();
3466
3467    broadcast(
3468        None,
3469        channel_members
3470            .iter()
3471            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3472        |peer_id| {
3473            peer.send(
3474                peer_id.into(),
3475                proto::UpdateChannels {
3476                    channel_participants: vec![proto::ChannelParticipants {
3477                        channel_id: channel_id.to_proto(),
3478                        participant_user_ids: participants.clone(),
3479                    }],
3480                    ..Default::default()
3481                },
3482            )
3483        },
3484    );
3485}
3486
3487async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3488    let db = session.db().await;
3489
3490    let contacts = db.get_contacts(user_id).await?;
3491    let busy = db.is_user_busy(user_id).await?;
3492
3493    let pool = session.connection_pool().await;
3494    let updated_contact = contact_for_user(user_id, busy, &pool);
3495    for contact in contacts {
3496        if let db::Contact::Accepted {
3497            user_id: contact_user_id,
3498            ..
3499        } = contact
3500        {
3501            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3502                session
3503                    .peer
3504                    .send(
3505                        contact_conn_id,
3506                        proto::UpdateContacts {
3507                            contacts: vec![updated_contact.clone()],
3508                            remove_contacts: Default::default(),
3509                            incoming_requests: Default::default(),
3510                            remove_incoming_requests: Default::default(),
3511                            outgoing_requests: Default::default(),
3512                            remove_outgoing_requests: Default::default(),
3513                        },
3514                    )
3515                    .trace_err();
3516            }
3517        }
3518    }
3519    Ok(())
3520}
3521
3522async fn leave_room_for_session(session: &Session) -> Result<()> {
3523    let mut contacts_to_update = HashSet::default();
3524
3525    let room_id;
3526    let canceled_calls_to_user_ids;
3527    let live_kit_room;
3528    let delete_live_kit_room;
3529    let room;
3530    let channel_members;
3531    let channel_id;
3532
3533    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3534        contacts_to_update.insert(session.user_id);
3535
3536        for project in left_room.left_projects.values() {
3537            project_left(project, session);
3538        }
3539
3540        room_id = RoomId::from_proto(left_room.room.id);
3541        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3542        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3543        delete_live_kit_room = left_room.deleted;
3544        room = mem::take(&mut left_room.room);
3545        channel_members = mem::take(&mut left_room.channel_members);
3546        channel_id = left_room.channel_id;
3547
3548        room_updated(&room, &session.peer);
3549    } else {
3550        return Ok(());
3551    }
3552
3553    if let Some(channel_id) = channel_id {
3554        channel_updated(
3555            channel_id,
3556            &room,
3557            &channel_members,
3558            &session.peer,
3559            &*session.connection_pool().await,
3560        );
3561    }
3562
3563    {
3564        let pool = session.connection_pool().await;
3565        for canceled_user_id in canceled_calls_to_user_ids {
3566            for connection_id in pool.user_connection_ids(canceled_user_id) {
3567                session
3568                    .peer
3569                    .send(
3570                        connection_id,
3571                        proto::CallCanceled {
3572                            room_id: room_id.to_proto(),
3573                        },
3574                    )
3575                    .trace_err();
3576            }
3577            contacts_to_update.insert(canceled_user_id);
3578        }
3579    }
3580
3581    for contact_user_id in contacts_to_update {
3582        update_user_contacts(contact_user_id, &session).await?;
3583    }
3584
3585    if let Some(live_kit) = session.live_kit_client.as_ref() {
3586        live_kit
3587            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3588            .await
3589            .trace_err();
3590
3591        if delete_live_kit_room {
3592            live_kit.delete_room(live_kit_room).await.trace_err();
3593        }
3594    }
3595
3596    Ok(())
3597}
3598
3599async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3600    let left_channel_buffers = session
3601        .db()
3602        .await
3603        .leave_channel_buffers(session.connection_id)
3604        .await?;
3605
3606    for left_buffer in left_channel_buffers {
3607        channel_buffer_updated(
3608            session.connection_id,
3609            left_buffer.connections,
3610            &proto::UpdateChannelBufferCollaborators {
3611                channel_id: left_buffer.channel_id.to_proto(),
3612                collaborators: left_buffer.collaborators,
3613            },
3614            &session.peer,
3615        );
3616    }
3617
3618    Ok(())
3619}
3620
3621fn project_left(project: &db::LeftProject, session: &Session) {
3622    for connection_id in &project.connection_ids {
3623        if project.host_user_id == session.user_id {
3624            session
3625                .peer
3626                .send(
3627                    *connection_id,
3628                    proto::UnshareProject {
3629                        project_id: project.id.to_proto(),
3630                    },
3631                )
3632                .trace_err();
3633        } else {
3634            session
3635                .peer
3636                .send(
3637                    *connection_id,
3638                    proto::RemoveProjectCollaborator {
3639                        project_id: project.id.to_proto(),
3640                        peer_id: Some(session.connection_id.into()),
3641                    },
3642                )
3643                .trace_err();
3644        }
3645    }
3646}
3647
3648pub trait ResultExt {
3649    type Ok;
3650
3651    fn trace_err(self) -> Option<Self::Ok>;
3652}
3653
3654impl<T, E> ResultExt for Result<T, E>
3655where
3656    E: std::fmt::Debug,
3657{
3658    type Ok = T;
3659
3660    fn trace_err(self) -> Option<T> {
3661        match self {
3662            Ok(value) => Some(value),
3663            Err(error) => {
3664                tracing::error!("{:?}", error);
3665                None
3666            }
3667        }
3668    }
3669}