rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
   7        Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
   8        ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
   9        User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    future::Future,
  53    marker::PhantomData,
  54    mem,
  55    net::SocketAddr,
  56    ops::{Deref, DerefMut},
  57    rc::Rc,
  58    sync::{
  59        atomic::{AtomicBool, Ordering::SeqCst},
  60        Arc, OnceLock,
  61    },
  62    time::{Duration, Instant},
  63};
  64use time::OffsetDateTime;
  65use tokio::sync::{watch, Semaphore};
  66use tower::ServiceBuilder;
  67use tracing::{field, info_span, instrument, Instrument};
  68use util::{http::IsahcHttpClient, SemanticVersion};
  69
/// Reconnection grace period.
///
/// NOTE(review): the use site is not visible in this chunk; presumably this is how long a
/// disconnected client may take to reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting longer than that means
/// those pods are gone before we clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are used outside this chunk. From their names they appear
// to be the channel-message page size, the maximum channel-message length (unit — bytes or
// chars — not visible here) and the notification page size; confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the payload's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  81
  82struct Response<R> {
  83    peer: Arc<Peer>,
  84    receipt: Receipt<R>,
  85    responded: Arc<AtomicBool>,
  86}
  87
  88impl<R: RequestMessage> Response<R> {
  89    fn send(self, payload: R::Response) -> Result<()> {
  90        self.responded.store(true, SeqCst);
  91        self.peer.respond(self.receipt, payload)?;
  92        Ok(())
  93    }
  94}
  95
  96struct StreamingResponse<R: RequestMessage> {
  97    peer: Arc<Peer>,
  98    receipt: Receipt<R>,
  99}
 100
 101impl<R: RequestMessage> StreamingResponse<R> {
 102    fn send(&self, payload: R::Response) -> Result<()> {
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
/// State for one client connection.
///
/// A fresh `Session` is built per connection and a clone of it is handed to every message
/// handler invoked for that connection.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Async (tokio) mutex, so its guard may be held across `.await`; acquire it through
    // `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Same pool the `Server` owns; acquire it through `Session::connection_pool`, whose guard
    // type is `!Send`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    _executor: Executor,
}
 120
 121impl Session {
 122    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.db.lock().await;
 126        #[cfg(test)]
 127        tokio::task::yield_now().await;
 128        guard
 129    }
 130
 131    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 132        #[cfg(test)]
 133        tokio::task::yield_now().await;
 134        let guard = self.connection_pool.lock();
 135        ConnectionPoolGuard {
 136            guard,
 137            _not_send: PhantomData,
 138        }
 139    }
 140}
 141
 142impl Debug for Session {
 143    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 144        f.debug_struct("Session")
 145            .field("user_id", &self.user_id)
 146            .field("connection_id", &self.connection_id)
 147            .finish()
 148    }
 149}
 150
 151struct DbHandle(Arc<Database>);
 152
 153impl Deref for DbHandle {
 154    type Target = Database;
 155
 156    fn deref(&self) -> &Self::Target {
 157        self.0.as_ref()
 158    }
 159}
 160
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can change it through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message type, keyed by `TypeId` of the payload.
    handlers: HashMap<TypeId, MessageHandler>,
    // Teardown flag; each connection subscribes to it, refuses to start when it is already set,
    // and stops when it changes.
    teardown: watch::Sender<bool>,
}
 169
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so a guard cannot be kept alive
/// across an `.await` inside a future that is required to be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// It contains a [`ConnectionPoolGuard`], so the connection-pool lock is held for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard dereferences to, rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 181
 182pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 183where
 184    S: Serializer,
 185    T: Deref<Target = U>,
 186    U: Serialize,
 187{
 188    Serialize::serialize(value.deref(), serializer)
 189}
 190
 191impl Server {
    /// Creates a server identified by `id`, registers every RPC handler on it, and returns it
    /// wrapped in an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (`add_handler` rejects
    /// duplicates).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if `id.0` is outside the `u32`
            // range — confirm that cannot happen.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // The language-model handlers need API keys from the config, so they are closures
            // capturing their own clone of `app_state`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 323
    /// Schedules background cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task that waits `CLEANUP_TIMEOUT`, then asks the database for the room
    /// and channel ids returned by `stale_server_resource_ids` for this environment (presumably
    /// those still tied to servers other than `server_id` — confirm). Then:
    /// - for each channel, stale channel-buffer collaborators are cleared and the remaining
    ///   connections receive `UpdateChannelBufferCollaborators`;
    /// - for each room, stale participants are cleared, the room (and its channel, if any) is
    ///   rebroadcast, canceled calls are reported with `CallCanceled`, affected users' contacts
    ///   are sent `UpdateContacts`, and the LiveKit room is deleted if no participants remain;
    /// - finally `delete_stale_servers` is called.
    ///
    /// Each fallible step goes through `trace_err` (presumably log-and-discard), so one failure
    /// skips that step instead of aborting the whole cleanup. This function itself returns
    /// `Ok(())` immediately.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify and the
                        // LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose incoming calls were
                            // canceled need refreshed contact entries.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to every connection
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 469
 470    pub fn teardown(&self) {
 471        self.peer.teardown();
 472        self.connection_pool.lock().reset();
 473        let _ = self.teardown.send(true);
 474    }
 475
 476    #[cfg(test)]
 477    pub fn reset(&self, id: ServerId) {
 478        self.teardown();
 479        *self.id.lock() = id;
 480        self.peer.reset(id.0 as u32);
 481        let _ = self.teardown.send(false);
 482    }
 483
 484    #[cfg(test)]
 485    pub fn id(&self) -> ServerId {
 486        *self.id.lock()
 487    }
 488
 489    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 490    where
 491        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 492        Fut: 'static + Send + Future<Output = Result<()>>,
 493        M: EnvelopedMessage,
 494    {
 495        let prev_handler = self.handlers.insert(
 496            TypeId::of::<M>(),
 497            Box::new(move |envelope, session| {
 498                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 499                let received_at = envelope.received_at;
 500                    tracing::info!(
 501                        "message received"
 502                    );
 503                let start_time = Instant::now();
 504                let future = (handler)(*envelope, session);
 505                async move {
 506                    let result = future.await;
 507                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 508                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 509                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 510                    match result {
 511                        Err(error) => {
 512                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 513                        }
 514                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 515                    }
 516                }
 517                .boxed()
 518            }),
 519        );
 520        if prev_handler.is_some() {
 521            panic!("registered a handler for the same message twice");
 522        }
 523        self
 524    }
 525
 526    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 527    where
 528        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 529        Fut: 'static + Send + Future<Output = Result<()>>,
 530        M: EnvelopedMessage,
 531    {
 532        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 533        self
 534    }
 535
 536    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 537    where
 538        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 539        Fut: Send + Future<Output = Result<()>>,
 540        M: RequestMessage,
 541    {
 542        let handler = Arc::new(handler);
 543        self.add_handler(move |envelope, session| {
 544            let receipt = envelope.receipt();
 545            let handler = handler.clone();
 546            async move {
 547                let peer = session.peer.clone();
 548                let responded = Arc::new(AtomicBool::default());
 549                let response = Response {
 550                    peer: peer.clone(),
 551                    responded: responded.clone(),
 552                    receipt,
 553                };
 554                match (handler)(envelope.payload, response, session).await {
 555                    Ok(()) => {
 556                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 557                            Ok(())
 558                        } else {
 559                            Err(anyhow!("handler did not send a response"))?
 560                        }
 561                    }
 562                    Err(error) => {
 563                        let proto_err = match &error {
 564                            Error::Internal(err) => err.to_proto(),
 565                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 566                        };
 567                        peer.respond_with_error(receipt, proto_err)?;
 568                        Err(error)
 569                    }
 570                }
 571            }
 572        })
 573    }
 574
 575    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 576    where
 577        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 578        Fut: Send + Future<Output = Result<()>>,
 579        M: RequestMessage,
 580    {
 581        let handler = Arc::new(handler);
 582        self.add_handler(move |envelope, session| {
 583            let receipt = envelope.receipt();
 584            let handler = handler.clone();
 585            async move {
 586                let peer = session.peer.clone();
 587                let response = StreamingResponse {
 588                    peer: peer.clone(),
 589                    receipt,
 590                };
 591                match (handler)(envelope.payload, response, session).await {
 592                    Ok(()) => {
 593                        peer.end_stream(receipt)?;
 594                        Ok(())
 595                    }
 596                    Err(error) => {
 597                        let proto_err = match &error {
 598                            Error::Internal(err) => err.to_proto(),
 599                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 600                        };
 601                        peer.respond_with_error(receipt, proto_err)?;
 602                        Err(error)
 603                    }
 604                }
 605            }
 606        })
 607    }
 608
 609    #[allow(clippy::too_many_arguments)]
 610    pub fn handle_connection(
 611        self: &Arc<Self>,
 612        connection: Connection,
 613        address: String,
 614        user: User,
 615        zed_version: ZedVersion,
 616        impersonator: Option<User>,
 617        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 618        executor: Executor,
 619    ) -> impl Future<Output = ()> {
 620        let this = self.clone();
 621        let user_id = user.id;
 622        let login = user.github_login.clone();
 623        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
 624        if let Some(impersonator) = impersonator {
 625            span.record("impersonator", &impersonator.github_login);
 626        }
 627        let mut teardown = self.teardown.subscribe();
 628        async move {
 629            if *teardown.borrow() {
 630                tracing::error!("server is tearing down");
 631                return
 632            }
 633            let (connection_id, handle_io, mut incoming_rx) = this
 634                .peer
 635                .add_connection(connection, {
 636                    let executor = executor.clone();
 637                    move |duration| executor.sleep(duration)
 638                });
 639            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 640            tracing::info!("connection opened");
 641
 642            let http_client = match IsahcHttpClient::new() {
 643                Ok(http_client) => http_client,
 644                Err(error) => {
 645                    tracing::error!(?error, "failed to create HTTP client");
 646                    return;
 647                }
 648            };
 649
 650            let session = Session {
 651                user_id,
 652                connection_id,
 653                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 654                peer: this.peer.clone(),
 655                connection_pool: this.connection_pool.clone(),
 656                live_kit_client: this.app_state.live_kit_client.clone(),
 657                http_client,
 658                rate_limiter: this.app_state.rate_limiter.clone(),
 659                _executor: executor.clone(),
 660            };
 661
 662            if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
 663                tracing::error!(?error, "failed to send initial client update");
 664                return;
 665            }
 666
 667            let handle_io = handle_io.fuse();
 668            futures::pin_mut!(handle_io);
 669
 670            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 671            // This prevents deadlocks when e.g., client A performs a request to client B and
 672            // client B performs a request to client A. If both clients stop processing further
 673            // messages until their respective request completes, they won't have a chance to
 674            // respond to the other client's request and cause a deadlock.
 675            //
 676            // This arrangement ensures we will attempt to process earlier messages first, but fall
 677            // back to processing messages arrived later in the spirit of making progress.
 678            let mut foreground_message_handlers = FuturesUnordered::new();
 679            let concurrent_handlers = Arc::new(Semaphore::new(256));
 680            loop {
 681                let next_message = async {
 682                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 683                    let message = incoming_rx.next().await;
 684                    (permit, message)
 685                }.fuse();
 686                futures::pin_mut!(next_message);
 687                futures::select_biased! {
 688                    _ = teardown.changed().fuse() => return,
 689                    result = handle_io => {
 690                        if let Err(error) = result {
 691                            tracing::error!(?error, "error handling I/O");
 692                        }
 693                        break;
 694                    }
 695                    _ = foreground_message_handlers.next() => {}
 696                    next_message = next_message => {
 697                        let (permit, message) = next_message;
 698                        if let Some(message) = message {
 699                            let type_name = message.payload_type_name();
 700                            // note: we copy all the fields from the parent span so we can query them in the logs.
 701                            // (https://github.com/tokio-rs/tracing/issues/2670).
 702                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 703                            let span_enter = span.enter();
 704                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 705                                let is_background = message.is_background();
 706                                let handle_message = (handler)(message, session.clone());
 707                                drop(span_enter);
 708
 709                                let handle_message = async move {
 710                                    handle_message.await;
 711                                    drop(permit);
 712                                }.instrument(span);
 713                                if is_background {
 714                                    executor.spawn_detached(handle_message);
 715                                } else {
 716                                    foreground_message_handlers.push(handle_message);
 717                                }
 718                            } else {
 719                                tracing::error!("no message handler");
 720                            }
 721                        } else {
 722                            tracing::info!("connection closed");
 723                            break;
 724                        }
 725                    }
 726                }
 727            }
 728
 729            drop(foreground_message_handlers);
 730            tracing::info!("signing out");
 731            if let Err(error) = connection_lost(session, teardown, executor).await {
 732                tracing::error!(?error, "error signing out");
 733            }
 734
 735        }.instrument(span)
 736    }
 737
 738    async fn send_initial_client_update(
 739        &self,
 740        connection_id: ConnectionId,
 741        user: User,
 742        zed_version: ZedVersion,
 743        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 744        session: &Session,
 745    ) -> Result<()> {
 746        self.peer.send(
 747            connection_id,
 748            proto::Hello {
 749                peer_id: Some(connection_id.into()),
 750            },
 751        )?;
 752        tracing::info!("sent hello message");
 753
 754        if let Some(send_connection_id) = send_connection_id.take() {
 755            let _ = send_connection_id.send(connection_id);
 756        }
 757
 758        if !user.connected_once {
 759            self.peer.send(connection_id, proto::ShowContacts {})?;
 760            self.app_state
 761                .db
 762                .set_user_connected_once(user.id, true)
 763                .await?;
 764        }
 765
 766        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 767            self.app_state.db.get_contacts(user.id),
 768            self.app_state.db.get_channels_for_user(user.id),
 769            self.app_state.db.get_channel_invites_for_user(user.id),
 770        )
 771        .await?;
 772
 773        {
 774            let mut pool = self.connection_pool.lock();
 775            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 776            for membership in &channels_for_user.channel_memberships {
 777                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 778            }
 779            self.peer.send(
 780                connection_id,
 781                build_initial_contacts_update(contacts, &pool),
 782            )?;
 783            self.peer.send(
 784                connection_id,
 785                build_update_user_channels(&channels_for_user),
 786            )?;
 787            self.peer.send(
 788                connection_id,
 789                build_channels_update(channels_for_user, channel_invites),
 790            )?;
 791        }
 792
 793        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 794            self.peer.send(connection_id, incoming_call)?;
 795        }
 796
 797        update_user_contacts(user.id, &session).await?;
 798        Ok(())
 799    }
 800
 801    pub async fn invite_code_redeemed(
 802        self: &Arc<Self>,
 803        inviter_id: UserId,
 804        invitee_id: UserId,
 805    ) -> Result<()> {
 806        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 807            if let Some(code) = &user.invite_code {
 808                let pool = self.connection_pool.lock();
 809                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 810                for connection_id in pool.user_connection_ids(inviter_id) {
 811                    self.peer.send(
 812                        connection_id,
 813                        proto::UpdateContacts {
 814                            contacts: vec![invitee_contact.clone()],
 815                            ..Default::default()
 816                        },
 817                    )?;
 818                    self.peer.send(
 819                        connection_id,
 820                        proto::UpdateInviteInfo {
 821                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 822                            count: user.invite_count as u32,
 823                        },
 824                    )?;
 825                }
 826            }
 827        }
 828        Ok(())
 829    }
 830
 831    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 832        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 833            if let Some(invite_code) = &user.invite_code {
 834                let pool = self.connection_pool.lock();
 835                for connection_id in pool.user_connection_ids(user_id) {
 836                    self.peer.send(
 837                        connection_id,
 838                        proto::UpdateInviteInfo {
 839                            url: format!(
 840                                "{}{}",
 841                                self.app_state.config.invite_link_prefix, invite_code
 842                            ),
 843                            count: user.invite_count as u32,
 844                        },
 845                    )?;
 846                }
 847            }
 848        }
 849        Ok(())
 850    }
 851
 852    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 853        ServerSnapshot {
 854            connection_pool: ConnectionPoolGuard {
 855                guard: self.connection_pool.lock(),
 856                _not_send: PhantomData,
 857            },
 858            peer: &self.peer,
 859        }
 860    }
 861}
 862
 863impl<'a> Deref for ConnectionPoolGuard<'a> {
 864    type Target = ConnectionPool;
 865
 866    fn deref(&self) -> &Self::Target {
 867        &self.guard
 868    }
 869}
 870
 871impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 872    fn deref_mut(&mut self) -> &mut Self::Target {
 873        &mut self.guard
 874    }
 875}
 876
 877impl<'a> Drop for ConnectionPoolGuard<'a> {
 878    fn drop(&mut self) {
 879        #[cfg(test)]
 880        self.check_invariants();
 881    }
 882}
 883
 884fn broadcast<F>(
 885    sender_id: Option<ConnectionId>,
 886    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 887    mut f: F,
 888) where
 889    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 890{
 891    for receiver_id in receiver_ids {
 892        if Some(receiver_id) != sender_id {
 893            if let Err(error) = f(receiver_id) {
 894                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 895            }
 896        }
 897    }
 898}
 899
 900pub struct ProtocolVersion(u32);
 901
 902impl Header for ProtocolVersion {
 903    fn name() -> &'static HeaderName {
 904        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 905        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 906    }
 907
 908    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 909    where
 910        Self: Sized,
 911        I: Iterator<Item = &'i axum::http::HeaderValue>,
 912    {
 913        let version = values
 914            .next()
 915            .ok_or_else(axum::headers::Error::invalid)?
 916            .to_str()
 917            .map_err(|_| axum::headers::Error::invalid())?
 918            .parse()
 919            .map_err(|_| axum::headers::Error::invalid())?;
 920        Ok(Self(version))
 921    }
 922
 923    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 924        values.extend([self.0.to_string().parse().unwrap()]);
 925    }
 926}
 927
 928pub struct AppVersionHeader(SemanticVersion);
 929impl Header for AppVersionHeader {
 930    fn name() -> &'static HeaderName {
 931        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 932        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 933    }
 934
 935    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 936    where
 937        Self: Sized,
 938        I: Iterator<Item = &'i axum::http::HeaderValue>,
 939    {
 940        let version = values
 941            .next()
 942            .ok_or_else(axum::headers::Error::invalid)?
 943            .to_str()
 944            .map_err(|_| axum::headers::Error::invalid())?
 945            .parse()
 946            .map_err(|_| axum::headers::Error::invalid())?;
 947        Ok(Self(version))
 948    }
 949
 950    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 951        values.extend([self.0.to_string().parse().unwrap()]);
 952    }
 953}
 954
 955pub fn routes(server: Arc<Server>) -> Router<(), Body> {
 956    Router::new()
 957        .route("/rpc", get(handle_websocket_request))
 958        .layer(
 959            ServiceBuilder::new()
 960                .layer(Extension(server.app_state.clone()))
 961                .layer(middleware::from_fn(auth::validate_header)),
 962        )
 963        .route("/metrics", get(handle_metrics))
 964        .layer(Extension(server))
 965}
 966
 967pub async fn handle_websocket_request(
 968    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 969    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 970    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 971    Extension(server): Extension<Arc<Server>>,
 972    Extension(user): Extension<User>,
 973    Extension(impersonator): Extension<Impersonator>,
 974    ws: WebSocketUpgrade,
 975) -> axum::response::Response {
 976    if protocol_version != rpc::PROTOCOL_VERSION {
 977        return (
 978            StatusCode::UPGRADE_REQUIRED,
 979            "client must be upgraded".to_string(),
 980        )
 981            .into_response();
 982    }
 983
 984    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 985        return (
 986            StatusCode::UPGRADE_REQUIRED,
 987            "no version header found".to_string(),
 988        )
 989            .into_response();
 990    };
 991
 992    if !version.can_collaborate() {
 993        return (
 994            StatusCode::UPGRADE_REQUIRED,
 995            "client must be upgraded".to_string(),
 996        )
 997            .into_response();
 998    }
 999
1000    let socket_address = socket_address.to_string();
1001    ws.on_upgrade(move |socket| {
1002        let socket = socket
1003            .map_ok(to_tungstenite_message)
1004            .err_into()
1005            .with(|message| async move { Ok(to_axum_message(message)) });
1006        let connection = Connection::new(Box::pin(socket));
1007        async move {
1008            server
1009                .handle_connection(
1010                    connection,
1011                    socket_address,
1012                    user,
1013                    version,
1014                    impersonator.0,
1015                    None,
1016                    Executor::Production,
1017                )
1018                .await;
1019        }
1020    })
1021}
1022
1023pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1024    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1025    let connections_metric = CONNECTIONS_METRIC
1026        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1027
1028    let connections = server
1029        .connection_pool
1030        .lock()
1031        .connections()
1032        .filter(|connection| !connection.admin)
1033        .count();
1034    connections_metric.set(connections as _);
1035
1036    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1037    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1038        register_int_gauge!(
1039            "shared_projects",
1040            "number of open projects with one or more guests"
1041        )
1042        .unwrap()
1043    });
1044
1045    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1046    shared_projects_metric.set(shared_projects as _);
1047
1048    let encoder = prometheus::TextEncoder::new();
1049    let metric_families = prometheus::gather();
1050    let encoded_metrics = encoder
1051        .encode_to_string(&metric_families)
1052        .map_err(|err| anyhow!("{}", err))?;
1053    Ok(encoded_metrics)
1054}
1055
1056#[instrument(err, skip(executor))]
1057async fn connection_lost(
1058    session: Session,
1059    mut teardown: watch::Receiver<bool>,
1060    executor: Executor,
1061) -> Result<()> {
1062    session.peer.disconnect(session.connection_id);
1063    session
1064        .connection_pool()
1065        .await
1066        .remove_connection(session.connection_id)?;
1067
1068    session
1069        .db()
1070        .await
1071        .connection_lost(session.connection_id)
1072        .await
1073        .trace_err();
1074
1075    futures::select_biased! {
1076        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1077            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
1078            leave_room_for_session(&session).await.trace_err();
1079            leave_channel_buffers_for_session(&session)
1080                .await
1081                .trace_err();
1082
1083            if !session
1084                .connection_pool()
1085                .await
1086                .is_user_online(session.user_id)
1087            {
1088                let db = session.db().await;
1089                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
1090                    room_updated(&room, &session.peer);
1091                }
1092            }
1093
1094            update_user_contacts(session.user_id, &session).await?;
1095        }
1096        _ = teardown.changed().fuse() => {}
1097    }
1098
1099    Ok(())
1100}
1101
1102/// Acknowledges a ping from a client, used to keep the connection alive.
1103async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1104    response.send(proto::Ack {})?;
1105    Ok(())
1106}
1107
1108/// Creates a new room for calling (outside of channels)
1109async fn create_room(
1110    _request: proto::CreateRoom,
1111    response: Response<proto::CreateRoom>,
1112    session: Session,
1113) -> Result<()> {
1114    let live_kit_room = nanoid::nanoid!(30);
1115
1116    let live_kit_connection_info = {
1117        let live_kit_room = live_kit_room.clone();
1118        let live_kit = session.live_kit_client.as_ref();
1119
1120        util::async_maybe!({
1121            let live_kit = live_kit?;
1122
1123            let token = live_kit
1124                .room_token(&live_kit_room, &session.user_id.to_string())
1125                .trace_err()?;
1126
1127            Some(proto::LiveKitConnectionInfo {
1128                server_url: live_kit.url().into(),
1129                token,
1130                can_publish: true,
1131            })
1132        })
1133    }
1134    .await;
1135
1136    let room = session
1137        .db()
1138        .await
1139        .create_room(session.user_id, session.connection_id, &live_kit_room)
1140        .await?;
1141
1142    response.send(proto::CreateRoomResponse {
1143        room: Some(room.clone()),
1144        live_kit_connection_info,
1145    })?;
1146
1147    update_user_contacts(session.user_id, &session).await?;
1148    Ok(())
1149}
1150
1151/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1152async fn join_room(
1153    request: proto::JoinRoom,
1154    response: Response<proto::JoinRoom>,
1155    session: Session,
1156) -> Result<()> {
1157    let room_id = RoomId::from_proto(request.id);
1158
1159    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1160
1161    if let Some(channel_id) = channel_id {
1162        return join_channel_internal(channel_id, Box::new(response), session).await;
1163    }
1164
1165    let joined_room = {
1166        let room = session
1167            .db()
1168            .await
1169            .join_room(room_id, session.user_id, session.connection_id)
1170            .await?;
1171        room_updated(&room.room, &session.peer);
1172        room.into_inner()
1173    };
1174
1175    for connection_id in session
1176        .connection_pool()
1177        .await
1178        .user_connection_ids(session.user_id)
1179    {
1180        session
1181            .peer
1182            .send(
1183                connection_id,
1184                proto::CallCanceled {
1185                    room_id: room_id.to_proto(),
1186                },
1187            )
1188            .trace_err();
1189    }
1190
1191    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1192        if let Some(token) = live_kit
1193            .room_token(
1194                &joined_room.room.live_kit_room,
1195                &session.user_id.to_string(),
1196            )
1197            .trace_err()
1198        {
1199            Some(proto::LiveKitConnectionInfo {
1200                server_url: live_kit.url().into(),
1201                token,
1202                can_publish: true,
1203            })
1204        } else {
1205            None
1206        }
1207    } else {
1208        None
1209    };
1210
1211    response.send(proto::JoinRoomResponse {
1212        room: Some(joined_room.room),
1213        channel_id: None,
1214        live_kit_connection_info,
1215    })?;
1216
1217    update_user_contacts(session.user_id, &session).await?;
1218    Ok(())
1219}
1220
1221/// Rejoin room is used to reconnect to a room after connection errors.
1222async fn rejoin_room(
1223    request: proto::RejoinRoom,
1224    response: Response<proto::RejoinRoom>,
1225    session: Session,
1226) -> Result<()> {
1227    let room;
1228    let channel;
1229    {
1230        let mut rejoined_room = session
1231            .db()
1232            .await
1233            .rejoin_room(request, session.user_id, session.connection_id)
1234            .await?;
1235
1236        response.send(proto::RejoinRoomResponse {
1237            room: Some(rejoined_room.room.clone()),
1238            reshared_projects: rejoined_room
1239                .reshared_projects
1240                .iter()
1241                .map(|project| proto::ResharedProject {
1242                    id: project.id.to_proto(),
1243                    collaborators: project
1244                        .collaborators
1245                        .iter()
1246                        .map(|collaborator| collaborator.to_proto())
1247                        .collect(),
1248                })
1249                .collect(),
1250            rejoined_projects: rejoined_room
1251                .rejoined_projects
1252                .iter()
1253                .map(|rejoined_project| proto::RejoinedProject {
1254                    id: rejoined_project.id.to_proto(),
1255                    worktrees: rejoined_project
1256                        .worktrees
1257                        .iter()
1258                        .map(|worktree| proto::WorktreeMetadata {
1259                            id: worktree.id,
1260                            root_name: worktree.root_name.clone(),
1261                            visible: worktree.visible,
1262                            abs_path: worktree.abs_path.clone(),
1263                        })
1264                        .collect(),
1265                    collaborators: rejoined_project
1266                        .collaborators
1267                        .iter()
1268                        .map(|collaborator| collaborator.to_proto())
1269                        .collect(),
1270                    language_servers: rejoined_project.language_servers.clone(),
1271                })
1272                .collect(),
1273        })?;
1274        room_updated(&rejoined_room.room, &session.peer);
1275
1276        for project in &rejoined_room.reshared_projects {
1277            for collaborator in &project.collaborators {
1278                session
1279                    .peer
1280                    .send(
1281                        collaborator.connection_id,
1282                        proto::UpdateProjectCollaborator {
1283                            project_id: project.id.to_proto(),
1284                            old_peer_id: Some(project.old_connection_id.into()),
1285                            new_peer_id: Some(session.connection_id.into()),
1286                        },
1287                    )
1288                    .trace_err();
1289            }
1290
1291            broadcast(
1292                Some(session.connection_id),
1293                project
1294                    .collaborators
1295                    .iter()
1296                    .map(|collaborator| collaborator.connection_id),
1297                |connection_id| {
1298                    session.peer.forward_send(
1299                        session.connection_id,
1300                        connection_id,
1301                        proto::UpdateProject {
1302                            project_id: project.id.to_proto(),
1303                            worktrees: project.worktrees.clone(),
1304                        },
1305                    )
1306                },
1307            );
1308        }
1309
1310        for project in &rejoined_room.rejoined_projects {
1311            for collaborator in &project.collaborators {
1312                session
1313                    .peer
1314                    .send(
1315                        collaborator.connection_id,
1316                        proto::UpdateProjectCollaborator {
1317                            project_id: project.id.to_proto(),
1318                            old_peer_id: Some(project.old_connection_id.into()),
1319                            new_peer_id: Some(session.connection_id.into()),
1320                        },
1321                    )
1322                    .trace_err();
1323            }
1324        }
1325
1326        for project in &mut rejoined_room.rejoined_projects {
1327            for worktree in mem::take(&mut project.worktrees) {
1328                #[cfg(any(test, feature = "test-support"))]
1329                const MAX_CHUNK_SIZE: usize = 2;
1330                #[cfg(not(any(test, feature = "test-support")))]
1331                const MAX_CHUNK_SIZE: usize = 256;
1332
1333                // Stream this worktree's entries.
1334                let message = proto::UpdateWorktree {
1335                    project_id: project.id.to_proto(),
1336                    worktree_id: worktree.id,
1337                    abs_path: worktree.abs_path.clone(),
1338                    root_name: worktree.root_name,
1339                    updated_entries: worktree.updated_entries,
1340                    removed_entries: worktree.removed_entries,
1341                    scan_id: worktree.scan_id,
1342                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1343                    updated_repositories: worktree.updated_repositories,
1344                    removed_repositories: worktree.removed_repositories,
1345                };
1346                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1347                    session.peer.send(session.connection_id, update.clone())?;
1348                }
1349
1350                // Stream this worktree's diagnostics.
1351                for summary in worktree.diagnostic_summaries {
1352                    session.peer.send(
1353                        session.connection_id,
1354                        proto::UpdateDiagnosticSummary {
1355                            project_id: project.id.to_proto(),
1356                            worktree_id: worktree.id,
1357                            summary: Some(summary),
1358                        },
1359                    )?;
1360                }
1361
1362                for settings_file in worktree.settings_files {
1363                    session.peer.send(
1364                        session.connection_id,
1365                        proto::UpdateWorktreeSettings {
1366                            project_id: project.id.to_proto(),
1367                            worktree_id: worktree.id,
1368                            path: settings_file.path,
1369                            content: Some(settings_file.content),
1370                        },
1371                    )?;
1372                }
1373            }
1374
1375            for language_server in &project.language_servers {
1376                session.peer.send(
1377                    session.connection_id,
1378                    proto::UpdateLanguageServer {
1379                        project_id: project.id.to_proto(),
1380                        language_server_id: language_server.id,
1381                        variant: Some(
1382                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1383                                proto::LspDiskBasedDiagnosticsUpdated {},
1384                            ),
1385                        ),
1386                    },
1387                )?;
1388            }
1389        }
1390
1391        let rejoined_room = rejoined_room.into_inner();
1392
1393        room = rejoined_room.room;
1394        channel = rejoined_room.channel;
1395    }
1396
1397    if let Some(channel) = channel {
1398        channel_updated(
1399            &channel,
1400            &room,
1401            &session.peer,
1402            &*session.connection_pool().await,
1403        );
1404    }
1405
1406    update_user_contacts(session.user_id, &session).await?;
1407    Ok(())
1408}
1409
1410/// leave room disconnects from the room.
1411async fn leave_room(
1412    _: proto::LeaveRoom,
1413    response: Response<proto::LeaveRoom>,
1414    session: Session,
1415) -> Result<()> {
1416    leave_room_for_session(&session).await?;
1417    response.send(proto::Ack {})?;
1418    Ok(())
1419}
1420
1421/// Updates the permissions of someone else in the room.
1422async fn set_room_participant_role(
1423    request: proto::SetRoomParticipantRole,
1424    response: Response<proto::SetRoomParticipantRole>,
1425    session: Session,
1426) -> Result<()> {
1427    let user_id = UserId::from_proto(request.user_id);
1428    let role = ChannelRole::from(request.role());
1429
1430    let (live_kit_room, can_publish) = {
1431        let room = session
1432            .db()
1433            .await
1434            .set_room_participant_role(
1435                session.user_id,
1436                RoomId::from_proto(request.room_id),
1437                user_id,
1438                role,
1439            )
1440            .await?;
1441
1442        let live_kit_room = room.live_kit_room.clone();
1443        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1444        room_updated(&room, &session.peer);
1445        (live_kit_room, can_publish)
1446    };
1447
1448    if let Some(live_kit) = session.live_kit_client.as_ref() {
1449        live_kit
1450            .update_participant(
1451                live_kit_room.clone(),
1452                request.user_id.to_string(),
1453                live_kit_server::proto::ParticipantPermission {
1454                    can_subscribe: true,
1455                    can_publish,
1456                    can_publish_data: can_publish,
1457                    hidden: false,
1458                    recorder: false,
1459                },
1460            )
1461            .await
1462            .trace_err();
1463    }
1464
1465    response.send(proto::Ack {})?;
1466    Ok(())
1467}
1468
1469/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // The callee must be a contact of the caller.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the database and broadcast the updated room. The block
    // scopes the borrowed database result so it is dropped before the contacts
    // update and the ringing below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call message out of the borrowed tuple so it can
        // outlive this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Send the incoming call as a request to every connection of the callee,
    // polling all of the requests concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful reply acknowledges the call; failed replies are
    // logged and the remaining ones are still awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: mark the call as failed in the
    // database, broadcast the room again and update the callee's contacts.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1537
1538/// Cancel an outgoing call.
1539async fn cancel_call(
1540    request: proto::CancelCall,
1541    response: Response<proto::CancelCall>,
1542    session: Session,
1543) -> Result<()> {
1544    let called_user_id = UserId::from_proto(request.called_user_id);
1545    let room_id = RoomId::from_proto(request.room_id);
1546    {
1547        let room = session
1548            .db()
1549            .await
1550            .cancel_call(room_id, session.connection_id, called_user_id)
1551            .await?;
1552        room_updated(&room, &session.peer);
1553    }
1554
1555    for connection_id in session
1556        .connection_pool()
1557        .await
1558        .user_connection_ids(called_user_id)
1559    {
1560        session
1561            .peer
1562            .send(
1563                connection_id,
1564                proto::CallCanceled {
1565                    room_id: room_id.to_proto(),
1566                },
1567            )
1568            .trace_err();
1569    }
1570    response.send(proto::Ack {})?;
1571
1572    update_user_contacts(called_user_id, &session).await?;
1573    Ok(())
1574}
1575
1576/// Decline an incoming call.
1577async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1578    let room_id = RoomId::from_proto(message.room_id);
1579    {
1580        let room = session
1581            .db()
1582            .await
1583            .decline_call(Some(room_id), session.user_id)
1584            .await?
1585            .ok_or_else(|| anyhow!("failed to decline call"))?;
1586        room_updated(&room, &session.peer);
1587    }
1588
1589    for connection_id in session
1590        .connection_pool()
1591        .await
1592        .user_connection_ids(session.user_id)
1593    {
1594        session
1595            .peer
1596            .send(
1597                connection_id,
1598                proto::CallCanceled {
1599                    room_id: room_id.to_proto(),
1600                },
1601            )
1602            .trace_err();
1603    }
1604    update_user_contacts(session.user_id, &session).await?;
1605    Ok(())
1606}
1607
1608/// Updates other participants in the room with your current location.
1609async fn update_participant_location(
1610    request: proto::UpdateParticipantLocation,
1611    response: Response<proto::UpdateParticipantLocation>,
1612    session: Session,
1613) -> Result<()> {
1614    let room_id = RoomId::from_proto(request.room_id);
1615    let location = request
1616        .location
1617        .ok_or_else(|| anyhow!("invalid location"))?;
1618
1619    let db = session.db().await;
1620    let room = db
1621        .update_room_participant_location(room_id, session.connection_id, location)
1622        .await?;
1623
1624    room_updated(&room, &session.peer);
1625    response.send(proto::Ack {})?;
1626    Ok(())
1627}
1628
/// Share a project into the room.
///
/// Calls the database's `share_project` with the room id, this connection and
/// the request's worktrees. It replies to the sharer with the new project id,
/// then broadcasts the updated room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // `&*` borrows the value returned by the database; temporary lifetime
    // extension keeps it alive for the rest of this function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session.peer);

    Ok(())
}
1651
/// Unshare a project from the room.
///
/// The database unshares the project for this connection and returns the room
/// plus the guest connections. The original `UnshareProject` message is re-sent
/// to those guests, and the updated room is broadcast.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    // Borrowed database result, kept alive until the end of the function.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .unshare_project(project_id, session.connection_id)
        .await?;

    // The sender's own id is passed to `broadcast`, presumably so it is skipped.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |conn_id| session.peer.send(conn_id, message.clone()),
    );
    room_updated(&room, &session.peer);

    Ok(())
}
1671
/// Join someone else's shared project.
///
/// The database adds this connection to the project (in the context of a room)
/// and assigns a replica id. The reply, collaborator notifications and
/// initial project state are handled by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    // `&mut *` yields mutable access to the database result for the rest of the
    // function; `join_project_internal` drains the project's worktrees.
    let (project, replica_id) = &mut *session
        .db()
        .await
        .join_project_in_room(project_id, session.connection_id)
        .await?;

    join_project_internal(response, session, project, replica_id)
}
1690
/// A response handle that can answer with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Forwards to `Response::send` for `JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Forwards to `Response::send` for `JoinHostedProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1704
1705fn join_project_internal(
1706    response: impl JoinProjectInternalResponse,
1707    session: Session,
1708    project: &mut Project,
1709    replica_id: &ReplicaId,
1710) -> Result<()> {
1711    let collaborators = project
1712        .collaborators
1713        .iter()
1714        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1715        .map(|collaborator| collaborator.to_proto())
1716        .collect::<Vec<_>>();
1717    let project_id = project.id;
1718    let guest_user_id = session.user_id;
1719
1720    let worktrees = project
1721        .worktrees
1722        .iter()
1723        .map(|(id, worktree)| proto::WorktreeMetadata {
1724            id: *id,
1725            root_name: worktree.root_name.clone(),
1726            visible: worktree.visible,
1727            abs_path: worktree.abs_path.clone(),
1728        })
1729        .collect::<Vec<_>>();
1730
1731    for collaborator in &collaborators {
1732        session
1733            .peer
1734            .send(
1735                collaborator.peer_id.unwrap().into(),
1736                proto::AddProjectCollaborator {
1737                    project_id: project_id.to_proto(),
1738                    collaborator: Some(proto::Collaborator {
1739                        peer_id: Some(session.connection_id.into()),
1740                        replica_id: replica_id.0 as u32,
1741                        user_id: guest_user_id.to_proto(),
1742                    }),
1743                },
1744            )
1745            .trace_err();
1746    }
1747
1748    // First, we send the metadata associated with each worktree.
1749    response.send(proto::JoinProjectResponse {
1750        project_id: project.id.0 as u64,
1751        worktrees: worktrees.clone(),
1752        replica_id: replica_id.0 as u32,
1753        collaborators: collaborators.clone(),
1754        language_servers: project.language_servers.clone(),
1755        role: project.role.into(), // todo
1756    })?;
1757
1758    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1759        #[cfg(any(test, feature = "test-support"))]
1760        const MAX_CHUNK_SIZE: usize = 2;
1761        #[cfg(not(any(test, feature = "test-support")))]
1762        const MAX_CHUNK_SIZE: usize = 256;
1763
1764        // Stream this worktree's entries.
1765        let message = proto::UpdateWorktree {
1766            project_id: project_id.to_proto(),
1767            worktree_id,
1768            abs_path: worktree.abs_path.clone(),
1769            root_name: worktree.root_name,
1770            updated_entries: worktree.entries,
1771            removed_entries: Default::default(),
1772            scan_id: worktree.scan_id,
1773            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1774            updated_repositories: worktree.repository_entries.into_values().collect(),
1775            removed_repositories: Default::default(),
1776        };
1777        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1778            session.peer.send(session.connection_id, update.clone())?;
1779        }
1780
1781        // Stream this worktree's diagnostics.
1782        for summary in worktree.diagnostic_summaries {
1783            session.peer.send(
1784                session.connection_id,
1785                proto::UpdateDiagnosticSummary {
1786                    project_id: project_id.to_proto(),
1787                    worktree_id: worktree.id,
1788                    summary: Some(summary),
1789                },
1790            )?;
1791        }
1792
1793        for settings_file in worktree.settings_files {
1794            session.peer.send(
1795                session.connection_id,
1796                proto::UpdateWorktreeSettings {
1797                    project_id: project_id.to_proto(),
1798                    worktree_id: worktree.id,
1799                    path: settings_file.path,
1800                    content: Some(settings_file.content),
1801                },
1802            )?;
1803        }
1804    }
1805
1806    for language_server in &project.language_servers {
1807        session.peer.send(
1808            session.connection_id,
1809            proto::UpdateLanguageServer {
1810                project_id: project_id.to_proto(),
1811                language_server_id: language_server.id,
1812                variant: Some(
1813                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1814                        proto::LspDiskBasedDiagnosticsUpdated {},
1815                    ),
1816                ),
1817            },
1818        )?;
1819    }
1820
1821    Ok(())
1822}
1823
/// Leave someone else's shared project.
///
/// Hosted projects are left through `leave_hosted_project` and only trigger
/// `project_left`. Other projects are left through `leave_project`, which also
/// yields the room, so the updated room is broadcast as well.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    // The database guard is held for the rest of the function.
    let db = session.db().await;
    if db.is_hosted_project(project_id).await? {
        let project = db.leave_hosted_project(project_id, sender_id).await?;
        project_left(&project, &session);
        return Ok(());
    }

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        host_user_id = ?project.host_user_id,
        host_connection_id = ?project.host_connection_id,
        "leave project"
    );

    project_left(&project, &session);
    room_updated(&room, &session.peer);

    Ok(())
}
1848
1849async fn join_hosted_project(
1850    request: proto::JoinHostedProject,
1851    response: Response<proto::JoinHostedProject>,
1852    session: Session,
1853) -> Result<()> {
1854    let (mut project, replica_id) = session
1855        .db()
1856        .await
1857        .join_hosted_project(
1858            ProjectId(request.project_id as i32),
1859            session.user_id,
1860            session.connection_id,
1861        )
1862        .await?;
1863
1864    join_project_internal(response, session, &mut project, &replica_id)
1865}
1866
/// Updates other participants with changes to the project.
///
/// The database applies the new worktree list for this connection and returns
/// the room plus the guest connections. The request is forwarded to the guests,
/// the updated room is broadcast, and the request is acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Borrowed database result, kept alive until the end of the function.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;

    Ok(())
}
1893
1894/// Updates other participants with changes to the worktree
1895async fn update_worktree(
1896    request: proto::UpdateWorktree,
1897    response: Response<proto::UpdateWorktree>,
1898    session: Session,
1899) -> Result<()> {
1900    let guest_connection_ids = session
1901        .db()
1902        .await
1903        .update_worktree(&request, session.connection_id)
1904        .await?;
1905
1906    broadcast(
1907        Some(session.connection_id),
1908        guest_connection_ids.iter().copied(),
1909        |connection_id| {
1910            session
1911                .peer
1912                .forward_send(session.connection_id, connection_id, request.clone())
1913        },
1914    );
1915    response.send(proto::Ack {})?;
1916    Ok(())
1917}
1918
1919/// Updates other participants with changes to the diagnostics
1920async fn update_diagnostic_summary(
1921    message: proto::UpdateDiagnosticSummary,
1922    session: Session,
1923) -> Result<()> {
1924    let guest_connection_ids = session
1925        .db()
1926        .await
1927        .update_diagnostic_summary(&message, session.connection_id)
1928        .await?;
1929
1930    broadcast(
1931        Some(session.connection_id),
1932        guest_connection_ids.iter().copied(),
1933        |connection_id| {
1934            session
1935                .peer
1936                .forward_send(session.connection_id, connection_id, message.clone())
1937        },
1938    );
1939
1940    Ok(())
1941}
1942
1943/// Updates other participants with changes to the worktree settings
1944async fn update_worktree_settings(
1945    message: proto::UpdateWorktreeSettings,
1946    session: Session,
1947) -> Result<()> {
1948    let guest_connection_ids = session
1949        .db()
1950        .await
1951        .update_worktree_settings(&message, session.connection_id)
1952        .await?;
1953
1954    broadcast(
1955        Some(session.connection_id),
1956        guest_connection_ids.iter().copied(),
1957        |connection_id| {
1958            session
1959                .peer
1960                .forward_send(session.connection_id, connection_id, message.clone())
1961        },
1962    );
1963
1964    Ok(())
1965}
1966
1967/// Notify other participants that a  language server has started.
1968async fn start_language_server(
1969    request: proto::StartLanguageServer,
1970    session: Session,
1971) -> Result<()> {
1972    let guest_connection_ids = session
1973        .db()
1974        .await
1975        .start_language_server(&request, session.connection_id)
1976        .await?;
1977
1978    broadcast(
1979        Some(session.connection_id),
1980        guest_connection_ids.iter().copied(),
1981        |connection_id| {
1982            session
1983                .peer
1984                .forward_send(session.connection_id, connection_id, request.clone())
1985        },
1986    );
1987    Ok(())
1988}
1989
1990/// Notify other participants that a language server has changed.
1991async fn update_language_server(
1992    request: proto::UpdateLanguageServer,
1993    session: Session,
1994) -> Result<()> {
1995    let project_id = ProjectId::from_proto(request.project_id);
1996    let project_connection_ids = session
1997        .db()
1998        .await
1999        .project_connection_ids(project_id, session.connection_id)
2000        .await?;
2001    broadcast(
2002        Some(session.connection_id),
2003        project_connection_ids.iter().copied(),
2004        |connection_id| {
2005            session
2006                .peer
2007                .forward_send(session.connection_id, connection_id, request.clone())
2008        },
2009    );
2010    Ok(())
2011}
2012
2013/// forward a project request to the host. These requests should be read only
2014/// as guests are allowed to send them.
2015async fn forward_read_only_project_request<T>(
2016    request: T,
2017    response: Response<T>,
2018    session: Session,
2019) -> Result<()>
2020where
2021    T: EntityMessage + RequestMessage,
2022{
2023    let project_id = ProjectId::from_proto(request.remote_entity_id());
2024    let host_connection_id = session
2025        .db()
2026        .await
2027        .host_for_read_only_project_request(project_id, session.connection_id)
2028        .await?;
2029    let payload = session
2030        .peer
2031        .forward_request(session.connection_id, host_connection_id, request)
2032        .await?;
2033    response.send(payload)?;
2034    Ok(())
2035}
2036
2037/// forward a project request to the host. These requests are disallowed
2038/// for guests.
2039async fn forward_mutating_project_request<T>(
2040    request: T,
2041    response: Response<T>,
2042    session: Session,
2043) -> Result<()>
2044where
2045    T: EntityMessage + RequestMessage,
2046{
2047    let project_id = ProjectId::from_proto(request.remote_entity_id());
2048    let host_connection_id = session
2049        .db()
2050        .await
2051        .host_for_mutating_project_request(project_id, session.connection_id)
2052        .await?;
2053    let payload = session
2054        .peer
2055        .forward_request(session.connection_id, host_connection_id, request)
2056        .await?;
2057    response.send(payload)?;
2058    Ok(())
2059}
2060
2061/// Notify other participants that a new buffer has been created
2062async fn create_buffer_for_peer(
2063    request: proto::CreateBufferForPeer,
2064    session: Session,
2065) -> Result<()> {
2066    session
2067        .db()
2068        .await
2069        .check_user_is_project_host(
2070            ProjectId::from_proto(request.project_id),
2071            session.connection_id,
2072        )
2073        .await?;
2074    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2075    session
2076        .peer
2077        .forward_send(session.connection_id, peer_id.into(), request)?;
2078    Ok(())
2079}
2080
2081/// Notify other participants that a buffer has been updated. This is
2082/// allowed for guests as long as the update is limited to selections.
2083async fn update_buffer(
2084    request: proto::UpdateBuffer,
2085    response: Response<proto::UpdateBuffer>,
2086    session: Session,
2087) -> Result<()> {
2088    let project_id = ProjectId::from_proto(request.project_id);
2089    let mut guest_connection_ids;
2090    let mut host_connection_id = None;
2091
2092    let mut requires_write_permission = false;
2093
2094    for op in request.operations.iter() {
2095        match op.variant {
2096            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2097            Some(_) => requires_write_permission = true,
2098        }
2099    }
2100
2101    {
2102        let collaborators = session
2103            .db()
2104            .await
2105            .project_collaborators_for_buffer_update(
2106                project_id,
2107                session.connection_id,
2108                requires_write_permission,
2109            )
2110            .await?;
2111        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2112        for collaborator in collaborators.iter() {
2113            if collaborator.is_host {
2114                host_connection_id = Some(collaborator.connection_id);
2115            } else {
2116                guest_connection_ids.push(collaborator.connection_id);
2117            }
2118        }
2119    }
2120    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2121
2122    broadcast(
2123        Some(session.connection_id),
2124        guest_connection_ids,
2125        |connection_id| {
2126            session
2127                .peer
2128                .forward_send(session.connection_id, connection_id, request.clone())
2129        },
2130    );
2131    if host_connection_id != session.connection_id {
2132        session
2133            .peer
2134            .forward_request(session.connection_id, host_connection_id, request.clone())
2135            .await?;
2136    }
2137
2138    response.send(proto::Ack {})?;
2139    Ok(())
2140}
2141
2142/// Notify other participants that a project has been updated.
2143async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2144    request: T,
2145    session: Session,
2146) -> Result<()> {
2147    let project_id = ProjectId::from_proto(request.remote_entity_id());
2148    let project_connection_ids = session
2149        .db()
2150        .await
2151        .project_connection_ids(project_id, session.connection_id)
2152        .await?;
2153
2154    broadcast(
2155        Some(session.connection_id),
2156        project_connection_ids.iter().copied(),
2157        |connection_id| {
2158            session
2159                .peer
2160                .forward_send(session.connection_id, connection_id, request.clone())
2161        },
2162    );
2163    Ok(())
2164}
2165
2166/// Start following another user in a call.
2167async fn follow(
2168    request: proto::Follow,
2169    response: Response<proto::Follow>,
2170    session: Session,
2171) -> Result<()> {
2172    let room_id = RoomId::from_proto(request.room_id);
2173    let project_id = request.project_id.map(ProjectId::from_proto);
2174    let leader_id = request
2175        .leader_id
2176        .ok_or_else(|| anyhow!("invalid leader id"))?
2177        .into();
2178    let follower_id = session.connection_id;
2179
2180    session
2181        .db()
2182        .await
2183        .check_room_participants(room_id, leader_id, session.connection_id)
2184        .await?;
2185
2186    let response_payload = session
2187        .peer
2188        .forward_request(session.connection_id, leader_id, request)
2189        .await?;
2190    response.send(response_payload)?;
2191
2192    if let Some(project_id) = project_id {
2193        let room = session
2194            .db()
2195            .await
2196            .follow(room_id, project_id, leader_id, follower_id)
2197            .await?;
2198        room_updated(&room, &session.peer);
2199    }
2200
2201    Ok(())
2202}
2203
2204/// Stop following another user in a call.
2205async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2206    let room_id = RoomId::from_proto(request.room_id);
2207    let project_id = request.project_id.map(ProjectId::from_proto);
2208    let leader_id = request
2209        .leader_id
2210        .ok_or_else(|| anyhow!("invalid leader id"))?
2211        .into();
2212    let follower_id = session.connection_id;
2213
2214    session
2215        .db()
2216        .await
2217        .check_room_participants(room_id, leader_id, session.connection_id)
2218        .await?;
2219
2220    session
2221        .peer
2222        .forward_send(session.connection_id, leader_id, request)?;
2223
2224    if let Some(project_id) = project_id {
2225        let room = session
2226            .db()
2227            .await
2228            .unfollow(room_id, project_id, leader_id, follower_id)
2229            .await?;
2230        room_updated(&room, &session.peer);
2231    }
2232
2233    Ok(())
2234}
2235
2236/// Notify everyone following you of your current location.
2237async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2238    let room_id = RoomId::from_proto(request.room_id);
2239    let database = session.db.lock().await;
2240
2241    let connection_ids = if let Some(project_id) = request.project_id {
2242        let project_id = ProjectId::from_proto(project_id);
2243        database
2244            .project_connection_ids(project_id, session.connection_id)
2245            .await?
2246    } else {
2247        database
2248            .room_connection_ids(room_id, session.connection_id)
2249            .await?
2250    };
2251
2252    // For now, don't send view update messages back to that view's current leader.
2253    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2254        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2255        _ => None,
2256    });
2257
2258    for connection_id in connection_ids.iter().cloned() {
2259        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2260            session
2261                .peer
2262                .forward_send(session.connection_id, connection_id, request.clone())?;
2263        }
2264    }
2265    Ok(())
2266}
2267
2268/// Get public data about users.
2269async fn get_users(
2270    request: proto::GetUsers,
2271    response: Response<proto::GetUsers>,
2272    session: Session,
2273) -> Result<()> {
2274    let user_ids = request
2275        .user_ids
2276        .into_iter()
2277        .map(UserId::from_proto)
2278        .collect();
2279    let users = session
2280        .db()
2281        .await
2282        .get_users_by_ids(user_ids)
2283        .await?
2284        .into_iter()
2285        .map(|user| proto::User {
2286            id: user.id.to_proto(),
2287            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2288            github_login: user.github_login,
2289        })
2290        .collect();
2291    response.send(proto::UsersResponse { users })?;
2292    Ok(())
2293}
2294
2295/// Search for users (to invite) buy Github login
2296async fn fuzzy_search_users(
2297    request: proto::FuzzySearchUsers,
2298    response: Response<proto::FuzzySearchUsers>,
2299    session: Session,
2300) -> Result<()> {
2301    let query = request.query;
2302    let users = match query.len() {
2303        0 => vec![],
2304        1 | 2 => session
2305            .db()
2306            .await
2307            .get_user_by_github_login(&query)
2308            .await?
2309            .into_iter()
2310            .collect(),
2311        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2312    };
2313    let users = users
2314        .into_iter()
2315        .filter(|user| user.id != session.user_id)
2316        .map(|user| proto::User {
2317            id: user.id.to_proto(),
2318            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2319            github_login: user.github_login,
2320        })
2321        .collect();
2322    response.send(proto::UsersResponse { users })?;
2323    Ok(())
2324}
2325
2326/// Send a contact request to another user.
2327async fn request_contact(
2328    request: proto::RequestContact,
2329    response: Response<proto::RequestContact>,
2330    session: Session,
2331) -> Result<()> {
2332    let requester_id = session.user_id;
2333    let responder_id = UserId::from_proto(request.responder_id);
2334    if requester_id == responder_id {
2335        return Err(anyhow!("cannot add yourself as a contact"))?;
2336    }
2337
2338    let notifications = session
2339        .db()
2340        .await
2341        .send_contact_request(requester_id, responder_id)
2342        .await?;
2343
2344    // Update outgoing contact requests of requester
2345    let mut update = proto::UpdateContacts::default();
2346    update.outgoing_requests.push(responder_id.to_proto());
2347    for connection_id in session
2348        .connection_pool()
2349        .await
2350        .user_connection_ids(requester_id)
2351    {
2352        session.peer.send(connection_id, update.clone())?;
2353    }
2354
2355    // Update incoming contact requests of responder
2356    let mut update = proto::UpdateContacts::default();
2357    update
2358        .incoming_requests
2359        .push(proto::IncomingContactRequest {
2360            requester_id: requester_id.to_proto(),
2361        });
2362    let connection_pool = session.connection_pool().await;
2363    for connection_id in connection_pool.user_connection_ids(responder_id) {
2364        session.peer.send(connection_id, update.clone())?;
2365    }
2366
2367    send_notifications(&connection_pool, &session.peer, notifications);
2368
2369    response.send(proto::Ack {})?;
2370    Ok(())
2371}
2372
/// Accept or decline a contact request, or dismiss its notification.
///
/// `Dismiss` only dismisses the contact notification. Any other response
/// accepts (`Accept`) or declines the request in the database. Both users'
/// clients then receive an `UpdateContacts`: it contains the new contact when
/// accepted, and always removes the pending request. Finally, any resulting
/// notifications are sent. The request is acknowledged in every case.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id;
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is held for the rest of the function.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any non-dismiss response other than `Accept` is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is included in each side's new contact entry.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2430
2431/// Remove a contact.
2432async fn remove_contact(
2433    request: proto::RemoveContact,
2434    response: Response<proto::RemoveContact>,
2435    session: Session,
2436) -> Result<()> {
2437    let requester_id = session.user_id;
2438    let responder_id = UserId::from_proto(request.user_id);
2439    let db = session.db().await;
2440    let (contact_accepted, deleted_notification_id) =
2441        db.remove_contact(requester_id, responder_id).await?;
2442
2443    let pool = session.connection_pool().await;
2444    // Update outgoing contact requests of requester
2445    let mut update = proto::UpdateContacts::default();
2446    if contact_accepted {
2447        update.remove_contacts.push(responder_id.to_proto());
2448    } else {
2449        update
2450            .remove_outgoing_requests
2451            .push(responder_id.to_proto());
2452    }
2453    for connection_id in pool.user_connection_ids(requester_id) {
2454        session.peer.send(connection_id, update.clone())?;
2455    }
2456
2457    // Update incoming contact requests of responder
2458    let mut update = proto::UpdateContacts::default();
2459    if contact_accepted {
2460        update.remove_contacts.push(requester_id.to_proto());
2461    } else {
2462        update
2463            .remove_incoming_requests
2464            .push(requester_id.to_proto());
2465    }
2466    for connection_id in pool.user_connection_ids(responder_id) {
2467        session.peer.send(connection_id, update.clone())?;
2468        if let Some(notification_id) = deleted_notification_id {
2469            session.peer.send(
2470                connection_id,
2471                proto::DeleteNotification {
2472                    notification_id: notification_id.to_proto(),
2473                },
2474            )?;
2475        }
2476    }
2477
2478    response.send(proto::Ack {})?;
2479    Ok(())
2480}
2481
2482/// Creates a new channel.
2483async fn create_channel(
2484    request: proto::CreateChannel,
2485    response: Response<proto::CreateChannel>,
2486    session: Session,
2487) -> Result<()> {
2488    let db = session.db().await;
2489
2490    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2491    let (channel, membership) = db
2492        .create_channel(&request.name, parent_id, session.user_id)
2493        .await?;
2494
2495    let root_id = channel.root_id();
2496    let channel = Channel::from_model(channel);
2497
2498    response.send(proto::CreateChannelResponse {
2499        channel: Some(channel.to_proto()),
2500        parent_id: request.parent_id,
2501    })?;
2502
2503    let mut connection_pool = session.connection_pool().await;
2504    if let Some(membership) = membership {
2505        connection_pool.subscribe_to_channel(
2506            membership.user_id,
2507            membership.channel_id,
2508            membership.role,
2509        );
2510        let update = proto::UpdateUserChannels {
2511            channel_memberships: vec![proto::ChannelMembership {
2512                channel_id: membership.channel_id.to_proto(),
2513                role: membership.role.into(),
2514            }],
2515            ..Default::default()
2516        };
2517        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2518            session.peer.send(connection_id, update.clone())?;
2519        }
2520    }
2521
2522    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2523        if !role.can_see_channel(channel.visibility) {
2524            continue;
2525        }
2526
2527        let update = proto::UpdateChannels {
2528            channels: vec![channel.to_proto()],
2529            ..Default::default()
2530        };
2531        session.peer.send(connection_id, update.clone())?;
2532    }
2533
2534    Ok(())
2535}
2536
2537/// Delete a channel
2538async fn delete_channel(
2539    request: proto::DeleteChannel,
2540    response: Response<proto::DeleteChannel>,
2541    session: Session,
2542) -> Result<()> {
2543    let db = session.db().await;
2544
2545    let channel_id = request.channel_id;
2546    let (root_channel, removed_channels) = db
2547        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2548        .await?;
2549    response.send(proto::Ack {})?;
2550
2551    // Notify members of removed channels
2552    let mut update = proto::UpdateChannels::default();
2553    update
2554        .delete_channels
2555        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2556
2557    let connection_pool = session.connection_pool().await;
2558    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2559        session.peer.send(connection_id, update.clone())?;
2560    }
2561
2562    Ok(())
2563}
2564
2565/// Invite someone to join a channel.
2566async fn invite_channel_member(
2567    request: proto::InviteChannelMember,
2568    response: Response<proto::InviteChannelMember>,
2569    session: Session,
2570) -> Result<()> {
2571    let db = session.db().await;
2572    let channel_id = ChannelId::from_proto(request.channel_id);
2573    let invitee_id = UserId::from_proto(request.user_id);
2574    let InviteMemberResult {
2575        channel,
2576        notifications,
2577    } = db
2578        .invite_channel_member(
2579            channel_id,
2580            invitee_id,
2581            session.user_id,
2582            request.role().into(),
2583        )
2584        .await?;
2585
2586    let update = proto::UpdateChannels {
2587        channel_invitations: vec![channel.to_proto()],
2588        ..Default::default()
2589    };
2590
2591    let connection_pool = session.connection_pool().await;
2592    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2593        session.peer.send(connection_id, update.clone())?;
2594    }
2595
2596    send_notifications(&connection_pool, &session.peer, notifications);
2597
2598    response.send(proto::Ack {})?;
2599    Ok(())
2600}
2601
2602/// remove someone from a channel
2603async fn remove_channel_member(
2604    request: proto::RemoveChannelMember,
2605    response: Response<proto::RemoveChannelMember>,
2606    session: Session,
2607) -> Result<()> {
2608    let db = session.db().await;
2609    let channel_id = ChannelId::from_proto(request.channel_id);
2610    let member_id = UserId::from_proto(request.user_id);
2611
2612    let RemoveChannelMemberResult {
2613        membership_update,
2614        notification_id,
2615    } = db
2616        .remove_channel_member(channel_id, member_id, session.user_id)
2617        .await?;
2618
2619    let mut connection_pool = session.connection_pool().await;
2620    notify_membership_updated(
2621        &mut connection_pool,
2622        membership_update,
2623        member_id,
2624        &session.peer,
2625    );
2626    for connection_id in connection_pool.user_connection_ids(member_id) {
2627        if let Some(notification_id) = notification_id {
2628            session
2629                .peer
2630                .send(
2631                    connection_id,
2632                    proto::DeleteNotification {
2633                        notification_id: notification_id.to_proto(),
2634                    },
2635                )
2636                .trace_err();
2637        }
2638    }
2639
2640    response.send(proto::Ack {})?;
2641    Ok(())
2642}
2643
2644/// Toggle the channel between public and private.
2645/// Care is taken to maintain the invariant that public channels only descend from public channels,
2646/// (though members-only channels can appear at any point in the hierarchy).
2647async fn set_channel_visibility(
2648    request: proto::SetChannelVisibility,
2649    response: Response<proto::SetChannelVisibility>,
2650    session: Session,
2651) -> Result<()> {
2652    let db = session.db().await;
2653    let channel_id = ChannelId::from_proto(request.channel_id);
2654    let visibility = request.visibility().into();
2655
2656    let channel_model = db
2657        .set_channel_visibility(channel_id, visibility, session.user_id)
2658        .await?;
2659    let root_id = channel_model.root_id();
2660    let channel = Channel::from_model(channel_model);
2661
2662    let mut connection_pool = session.connection_pool().await;
2663    for (user_id, role) in connection_pool
2664        .channel_user_ids(root_id)
2665        .collect::<Vec<_>>()
2666        .into_iter()
2667    {
2668        let update = if role.can_see_channel(channel.visibility) {
2669            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2670            proto::UpdateChannels {
2671                channels: vec![channel.to_proto()],
2672                ..Default::default()
2673            }
2674        } else {
2675            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2676            proto::UpdateChannels {
2677                delete_channels: vec![channel.id.to_proto()],
2678                ..Default::default()
2679            }
2680        };
2681
2682        for connection_id in connection_pool.user_connection_ids(user_id) {
2683            session.peer.send(connection_id, update.clone())?;
2684        }
2685    }
2686
2687    response.send(proto::Ack {})?;
2688    Ok(())
2689}
2690
2691/// Alter the role for a user in the channel.
2692async fn set_channel_member_role(
2693    request: proto::SetChannelMemberRole,
2694    response: Response<proto::SetChannelMemberRole>,
2695    session: Session,
2696) -> Result<()> {
2697    let db = session.db().await;
2698    let channel_id = ChannelId::from_proto(request.channel_id);
2699    let member_id = UserId::from_proto(request.user_id);
2700    let result = db
2701        .set_channel_member_role(
2702            channel_id,
2703            session.user_id,
2704            member_id,
2705            request.role().into(),
2706        )
2707        .await?;
2708
2709    match result {
2710        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2711            let mut connection_pool = session.connection_pool().await;
2712            notify_membership_updated(
2713                &mut connection_pool,
2714                membership_update,
2715                member_id,
2716                &session.peer,
2717            )
2718        }
2719        db::SetMemberRoleResult::InviteUpdated(channel) => {
2720            let update = proto::UpdateChannels {
2721                channel_invitations: vec![channel.to_proto()],
2722                ..Default::default()
2723            };
2724
2725            for connection_id in session
2726                .connection_pool()
2727                .await
2728                .user_connection_ids(member_id)
2729            {
2730                session.peer.send(connection_id, update.clone())?;
2731            }
2732        }
2733    }
2734
2735    response.send(proto::Ack {})?;
2736    Ok(())
2737}
2738
2739/// Change the name of a channel
2740async fn rename_channel(
2741    request: proto::RenameChannel,
2742    response: Response<proto::RenameChannel>,
2743    session: Session,
2744) -> Result<()> {
2745    let db = session.db().await;
2746    let channel_id = ChannelId::from_proto(request.channel_id);
2747    let channel_model = db
2748        .rename_channel(channel_id, session.user_id, &request.name)
2749        .await?;
2750    let root_id = channel_model.root_id();
2751    let channel = Channel::from_model(channel_model);
2752
2753    response.send(proto::RenameChannelResponse {
2754        channel: Some(channel.to_proto()),
2755    })?;
2756
2757    let connection_pool = session.connection_pool().await;
2758    let update = proto::UpdateChannels {
2759        channels: vec![channel.to_proto()],
2760        ..Default::default()
2761    };
2762    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2763        if role.can_see_channel(channel.visibility) {
2764            session.peer.send(connection_id, update.clone())?;
2765        }
2766    }
2767
2768    Ok(())
2769}
2770
2771/// Move a channel to a new parent.
2772async fn move_channel(
2773    request: proto::MoveChannel,
2774    response: Response<proto::MoveChannel>,
2775    session: Session,
2776) -> Result<()> {
2777    let channel_id = ChannelId::from_proto(request.channel_id);
2778    let to = ChannelId::from_proto(request.to);
2779
2780    let (root_id, channels) = session
2781        .db()
2782        .await
2783        .move_channel(channel_id, to, session.user_id)
2784        .await?;
2785
2786    let connection_pool = session.connection_pool().await;
2787    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2788        let channels = channels
2789            .iter()
2790            .filter_map(|channel| {
2791                if role.can_see_channel(channel.visibility) {
2792                    Some(channel.to_proto())
2793                } else {
2794                    None
2795                }
2796            })
2797            .collect::<Vec<_>>();
2798        if channels.is_empty() {
2799            continue;
2800        }
2801
2802        let update = proto::UpdateChannels {
2803            channels,
2804            ..Default::default()
2805        };
2806
2807        session.peer.send(connection_id, update.clone())?;
2808    }
2809
2810    response.send(Ack {})?;
2811    Ok(())
2812}
2813
2814/// Get the list of channel members
2815async fn get_channel_members(
2816    request: proto::GetChannelMembers,
2817    response: Response<proto::GetChannelMembers>,
2818    session: Session,
2819) -> Result<()> {
2820    let db = session.db().await;
2821    let channel_id = ChannelId::from_proto(request.channel_id);
2822    let members = db
2823        .get_channel_participant_details(channel_id, session.user_id)
2824        .await?;
2825    response.send(proto::GetChannelMembersResponse { members })?;
2826    Ok(())
2827}
2828
2829/// Accept or decline a channel invitation.
2830async fn respond_to_channel_invite(
2831    request: proto::RespondToChannelInvite,
2832    response: Response<proto::RespondToChannelInvite>,
2833    session: Session,
2834) -> Result<()> {
2835    let db = session.db().await;
2836    let channel_id = ChannelId::from_proto(request.channel_id);
2837    let RespondToChannelInvite {
2838        membership_update,
2839        notifications,
2840    } = db
2841        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2842        .await?;
2843
2844    let mut connection_pool = session.connection_pool().await;
2845    if let Some(membership_update) = membership_update {
2846        notify_membership_updated(
2847            &mut connection_pool,
2848            membership_update,
2849            session.user_id,
2850            &session.peer,
2851        );
2852    } else {
2853        let update = proto::UpdateChannels {
2854            remove_channel_invitations: vec![channel_id.to_proto()],
2855            ..Default::default()
2856        };
2857
2858        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2859            session.peer.send(connection_id, update.clone())?;
2860        }
2861    };
2862
2863    send_notifications(&connection_pool, &session.peer, notifications);
2864
2865    response.send(proto::Ack {})?;
2866
2867    Ok(())
2868}
2869
2870/// Join the channels' room
2871async fn join_channel(
2872    request: proto::JoinChannel,
2873    response: Response<proto::JoinChannel>,
2874    session: Session,
2875) -> Result<()> {
2876    let channel_id = ChannelId::from_proto(request.channel_id);
2877    join_channel_internal(channel_id, Box::new(response), session).await
2878}
2879
/// A responder that can answer with a `proto::JoinRoomResponse`.
///
/// Both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` implement it, which
/// lets `join_channel_internal` reply to either kind of request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call so it resolves to `Response`'s own `send` (inherent methods take
        // precedence) rather than recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call so it resolves to `Response`'s own `send` (inherent methods take
        // precedence) rather than recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2893
/// Shared implementation for joining a channel's room, replying through any
/// `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. Leave whatever room this session is currently in.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    update, and the caller's role.
/// 3. If the server has a LiveKit client, build connection info. A `Guest` role gets a guest
///    token with `can_publish = false`; any other role gets a room token with
///    `can_publish = true`. A token error goes through `trace_err` and results in no LiveKit
///    info instead of a failed join.
/// 4. Reply with the room, the channel id (if any) and the LiveKit info.
/// 5. Broadcast the membership update (if any) and the updated room.
/// 6. Broadcast the updated channel, then refresh the caller's contacts.
///
/// # Errors
///
/// Returns an error if any of the following fails: leaving the previous room, the database
/// join, sending the reply, or updating contacts. It also errors when the joined room has no
/// channel; note that the reply has already been sent by then.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // `db` and `connection_pool` are bound inside this block. Their guards are therefore
    // dropped before `session.connection_pool()` is acquired again further below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // Inside this closure `?` applies to `Option`: a token error short-circuits to
        // `None` (no LiveKit info) rather than failing the whole join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel is only detected here, after the reply above was already sent.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2973
2974/// Start editing the channel notes
2975async fn join_channel_buffer(
2976    request: proto::JoinChannelBuffer,
2977    response: Response<proto::JoinChannelBuffer>,
2978    session: Session,
2979) -> Result<()> {
2980    let db = session.db().await;
2981    let channel_id = ChannelId::from_proto(request.channel_id);
2982
2983    let open_response = db
2984        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2985        .await?;
2986
2987    let collaborators = open_response.collaborators.clone();
2988    response.send(open_response)?;
2989
2990    let update = UpdateChannelBufferCollaborators {
2991        channel_id: channel_id.to_proto(),
2992        collaborators: collaborators.clone(),
2993    };
2994    channel_buffer_updated(
2995        session.connection_id,
2996        collaborators
2997            .iter()
2998            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2999        &update,
3000        &session.peer,
3001    );
3002
3003    Ok(())
3004}
3005
3006/// Edit the channel notes
3007async fn update_channel_buffer(
3008    request: proto::UpdateChannelBuffer,
3009    session: Session,
3010) -> Result<()> {
3011    let db = session.db().await;
3012    let channel_id = ChannelId::from_proto(request.channel_id);
3013
3014    let (collaborators, non_collaborators, epoch, version) = db
3015        .update_channel_buffer(channel_id, session.user_id, &request.operations)
3016        .await?;
3017
3018    channel_buffer_updated(
3019        session.connection_id,
3020        collaborators,
3021        &proto::UpdateChannelBuffer {
3022            channel_id: channel_id.to_proto(),
3023            operations: request.operations,
3024        },
3025        &session.peer,
3026    );
3027
3028    let pool = &*session.connection_pool().await;
3029
3030    broadcast(
3031        None,
3032        non_collaborators
3033            .iter()
3034            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3035        |peer_id| {
3036            session.peer.send(
3037                peer_id,
3038                proto::UpdateChannels {
3039                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3040                        channel_id: channel_id.to_proto(),
3041                        epoch: epoch as u64,
3042                        version: version.clone(),
3043                    }],
3044                    ..Default::default()
3045                },
3046            )
3047        },
3048    );
3049
3050    Ok(())
3051}
3052
3053/// Rejoin the channel notes after a connection blip
3054async fn rejoin_channel_buffers(
3055    request: proto::RejoinChannelBuffers,
3056    response: Response<proto::RejoinChannelBuffers>,
3057    session: Session,
3058) -> Result<()> {
3059    let db = session.db().await;
3060    let buffers = db
3061        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
3062        .await?;
3063
3064    for rejoined_buffer in &buffers {
3065        let collaborators_to_notify = rejoined_buffer
3066            .buffer
3067            .collaborators
3068            .iter()
3069            .filter_map(|c| Some(c.peer_id?.into()));
3070        channel_buffer_updated(
3071            session.connection_id,
3072            collaborators_to_notify,
3073            &proto::UpdateChannelBufferCollaborators {
3074                channel_id: rejoined_buffer.buffer.channel_id,
3075                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3076            },
3077            &session.peer,
3078        );
3079    }
3080
3081    response.send(proto::RejoinChannelBuffersResponse {
3082        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3083    })?;
3084
3085    Ok(())
3086}
3087
3088/// Stop editing the channel notes
3089async fn leave_channel_buffer(
3090    request: proto::LeaveChannelBuffer,
3091    response: Response<proto::LeaveChannelBuffer>,
3092    session: Session,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097    let left_buffer = db
3098        .leave_channel_buffer(channel_id, session.connection_id)
3099        .await?;
3100
3101    response.send(Ack {})?;
3102
3103    channel_buffer_updated(
3104        session.connection_id,
3105        left_buffer.connections,
3106        &proto::UpdateChannelBufferCollaborators {
3107            channel_id: channel_id.to_proto(),
3108            collaborators: left_buffer.collaborators,
3109        },
3110        &session.peer,
3111    );
3112
3113    Ok(())
3114}
3115
3116fn channel_buffer_updated<T: EnvelopedMessage>(
3117    sender_id: ConnectionId,
3118    collaborators: impl IntoIterator<Item = ConnectionId>,
3119    message: &T,
3120    peer: &Peer,
3121) {
3122    broadcast(Some(sender_id), collaborators, |peer_id| {
3123        peer.send(peer_id, message.clone())
3124    });
3125}
3126
3127fn send_notifications(
3128    connection_pool: &ConnectionPool,
3129    peer: &Peer,
3130    notifications: db::NotificationBatch,
3131) {
3132    for (user_id, notification) in notifications {
3133        for connection_id in connection_pool.user_connection_ids(user_id) {
3134            if let Err(error) = peer.send(
3135                connection_id,
3136                proto::AddNotification {
3137                    notification: Some(notification.clone()),
3138                },
3139            ) {
3140                tracing::error!(
3141                    "failed to send notification to {:?} {}",
3142                    connection_id,
3143                    error
3144                );
3145            }
3146        }
3147    }
3148}
3149
3150/// Send a message to the channel
3151async fn send_channel_message(
3152    request: proto::SendChannelMessage,
3153    response: Response<proto::SendChannelMessage>,
3154    session: Session,
3155) -> Result<()> {
3156    // Validate the message body.
3157    let body = request.body.trim().to_string();
3158    if body.len() > MAX_MESSAGE_LEN {
3159        return Err(anyhow!("message is too long"))?;
3160    }
3161    if body.is_empty() {
3162        return Err(anyhow!("message can't be blank"))?;
3163    }
3164
3165    // TODO: adjust mentions if body is trimmed
3166
3167    let timestamp = OffsetDateTime::now_utc();
3168    let nonce = request
3169        .nonce
3170        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3171
3172    let channel_id = ChannelId::from_proto(request.channel_id);
3173    let CreatedChannelMessage {
3174        message_id,
3175        participant_connection_ids,
3176        channel_members,
3177        notifications,
3178    } = session
3179        .db()
3180        .await
3181        .create_channel_message(
3182            channel_id,
3183            session.user_id,
3184            &body,
3185            &request.mentions,
3186            timestamp,
3187            nonce.clone().into(),
3188            match request.reply_to_message_id {
3189                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3190                None => None,
3191            },
3192        )
3193        .await?;
3194    let message = proto::ChannelMessage {
3195        sender_id: session.user_id.to_proto(),
3196        id: message_id.to_proto(),
3197        body,
3198        mentions: request.mentions,
3199        timestamp: timestamp.unix_timestamp() as u64,
3200        nonce: Some(nonce),
3201        reply_to_message_id: request.reply_to_message_id,
3202    };
3203    broadcast(
3204        Some(session.connection_id),
3205        participant_connection_ids,
3206        |connection| {
3207            session.peer.send(
3208                connection,
3209                proto::ChannelMessageSent {
3210                    channel_id: channel_id.to_proto(),
3211                    message: Some(message.clone()),
3212                },
3213            )
3214        },
3215    );
3216    response.send(proto::SendChannelMessageResponse {
3217        message: Some(message),
3218    })?;
3219
3220    let pool = &*session.connection_pool().await;
3221    broadcast(
3222        None,
3223        channel_members
3224            .iter()
3225            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3226        |peer_id| {
3227            session.peer.send(
3228                peer_id,
3229                proto::UpdateChannels {
3230                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3231                        channel_id: channel_id.to_proto(),
3232                        message_id: message_id.to_proto(),
3233                    }],
3234                    ..Default::default()
3235                },
3236            )
3237        },
3238    );
3239    send_notifications(pool, &session.peer, notifications);
3240
3241    Ok(())
3242}
3243
3244/// Delete a channel message
3245async fn remove_channel_message(
3246    request: proto::RemoveChannelMessage,
3247    response: Response<proto::RemoveChannelMessage>,
3248    session: Session,
3249) -> Result<()> {
3250    let channel_id = ChannelId::from_proto(request.channel_id);
3251    let message_id = MessageId::from_proto(request.message_id);
3252    let connection_ids = session
3253        .db()
3254        .await
3255        .remove_channel_message(channel_id, message_id, session.user_id)
3256        .await?;
3257    broadcast(Some(session.connection_id), connection_ids, |connection| {
3258        session.peer.send(connection, request.clone())
3259    });
3260    response.send(proto::Ack {})?;
3261    Ok(())
3262}
3263
3264/// Mark a channel message as read
3265async fn acknowledge_channel_message(
3266    request: proto::AckChannelMessage,
3267    session: Session,
3268) -> Result<()> {
3269    let channel_id = ChannelId::from_proto(request.channel_id);
3270    let message_id = MessageId::from_proto(request.message_id);
3271    let notifications = session
3272        .db()
3273        .await
3274        .observe_channel_message(channel_id, session.user_id, message_id)
3275        .await?;
3276    send_notifications(
3277        &*session.connection_pool().await,
3278        &session.peer,
3279        notifications,
3280    );
3281    Ok(())
3282}
3283
3284/// Mark a buffer version as synced
3285async fn acknowledge_buffer_version(
3286    request: proto::AckBufferOperation,
3287    session: Session,
3288) -> Result<()> {
3289    let buffer_id = BufferId::from_proto(request.buffer_id);
3290    session
3291        .db()
3292        .await
3293        .observe_buffer_version(
3294            buffer_id,
3295            session.user_id,
3296            request.epoch as i32,
3297            &request.version,
3298        )
3299        .await?;
3300    Ok(())
3301}
3302
3303struct CompleteWithLanguageModelRateLimit;
3304
3305impl RateLimit for CompleteWithLanguageModelRateLimit {
3306    fn capacity() -> usize {
3307        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3308            .ok()
3309            .and_then(|v| v.parse().ok())
3310            .unwrap_or(120) // Picked arbitrarily
3311    }
3312
3313    fn refill_duration() -> chrono::Duration {
3314        chrono::Duration::hours(1)
3315    }
3316
3317    fn db_name() -> &'static str {
3318        "complete-with-language-model"
3319    }
3320}
3321
3322async fn complete_with_language_model(
3323    request: proto::CompleteWithLanguageModel,
3324    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3325    session: Session,
3326    open_ai_api_key: Option<Arc<str>>,
3327    google_ai_api_key: Option<Arc<str>>,
3328) -> Result<()> {
3329    authorize_access_to_language_models(&session).await?;
3330    session
3331        .rate_limiter
3332        .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
3333        .await?;
3334
3335    if request.model.starts_with("gpt") {
3336        let api_key =
3337            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3338        complete_with_open_ai(request, response, session, api_key).await?;
3339    } else if request.model.starts_with("gemini") {
3340        let api_key = google_ai_api_key
3341            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3342        complete_with_google_ai(request, response, session, api_key).await?;
3343    }
3344
3345    Ok(())
3346}
3347
3348async fn complete_with_open_ai(
3349    request: proto::CompleteWithLanguageModel,
3350    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3351    session: Session,
3352    api_key: Arc<str>,
3353) -> Result<()> {
3354    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3355
3356    let mut completion_stream = open_ai::stream_completion(
3357        &session.http_client,
3358        OPEN_AI_API_URL,
3359        &api_key,
3360        crate::ai::language_model_request_to_open_ai(request)?,
3361    )
3362    .await
3363    .context("open_ai::stream_completion request failed")?;
3364
3365    while let Some(event) = completion_stream.next().await {
3366        let event = event?;
3367        response.send(proto::LanguageModelResponse {
3368            choices: event
3369                .choices
3370                .into_iter()
3371                .map(|choice| proto::LanguageModelChoiceDelta {
3372                    index: choice.index,
3373                    delta: Some(proto::LanguageModelResponseMessage {
3374                        role: choice.delta.role.map(|role| match role {
3375                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3376                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3377                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3378                        } as i32),
3379                        content: choice.delta.content,
3380                    }),
3381                    finish_reason: choice.finish_reason,
3382                })
3383                .collect(),
3384        })?;
3385    }
3386
3387    Ok(())
3388}
3389
3390async fn complete_with_google_ai(
3391    request: proto::CompleteWithLanguageModel,
3392    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3393    session: Session,
3394    api_key: Arc<str>,
3395) -> Result<()> {
3396    let mut stream = google_ai::stream_generate_content(
3397        &session.http_client,
3398        google_ai::API_URL,
3399        api_key.as_ref(),
3400        crate::ai::language_model_request_to_google_ai(request)?,
3401    )
3402    .await
3403    .context("google_ai::stream_generate_content request failed")?;
3404
3405    while let Some(event) = stream.next().await {
3406        let event = event?;
3407        response.send(proto::LanguageModelResponse {
3408            choices: event
3409                .candidates
3410                .unwrap_or_default()
3411                .into_iter()
3412                .map(|candidate| proto::LanguageModelChoiceDelta {
3413                    index: candidate.index as u32,
3414                    delta: Some(proto::LanguageModelResponseMessage {
3415                        role: Some(match candidate.content.role {
3416                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3417                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3418                        } as i32),
3419                        content: Some(
3420                            candidate
3421                                .content
3422                                .parts
3423                                .into_iter()
3424                                .filter_map(|part| match part {
3425                                    google_ai::Part::TextPart(part) => Some(part.text),
3426                                    google_ai::Part::InlineDataPart(_) => None,
3427                                })
3428                                .collect(),
3429                        ),
3430                    }),
3431                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3432                })
3433                .collect(),
3434        })?;
3435    }
3436
3437    Ok(())
3438}
3439
3440struct CountTokensWithLanguageModelRateLimit;
3441
3442impl RateLimit for CountTokensWithLanguageModelRateLimit {
3443    fn capacity() -> usize {
3444        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3445            .ok()
3446            .and_then(|v| v.parse().ok())
3447            .unwrap_or(600) // Picked arbitrarily
3448    }
3449
3450    fn refill_duration() -> chrono::Duration {
3451        chrono::Duration::hours(1)
3452    }
3453
3454    fn db_name() -> &'static str {
3455        "count-tokens-with-language-model"
3456    }
3457}
3458
3459async fn count_tokens_with_language_model(
3460    request: proto::CountTokensWithLanguageModel,
3461    response: Response<proto::CountTokensWithLanguageModel>,
3462    session: Session,
3463    google_ai_api_key: Option<Arc<str>>,
3464) -> Result<()> {
3465    authorize_access_to_language_models(&session).await?;
3466
3467    if !request.model.starts_with("gemini") {
3468        return Err(anyhow!(
3469            "counting tokens for model: {:?} is not supported",
3470            request.model
3471        ))?;
3472    }
3473
3474    session
3475        .rate_limiter
3476        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
3477        .await?;
3478
3479    let api_key = google_ai_api_key
3480        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3481    let tokens_response = google_ai::count_tokens(
3482        &session.http_client,
3483        google_ai::API_URL,
3484        &api_key,
3485        crate::ai::count_tokens_request_to_google_ai(request)?,
3486    )
3487    .await?;
3488    response.send(proto::CountTokensResponse {
3489        token_count: tokens_response.total_tokens as u32,
3490    })?;
3491    Ok(())
3492}
3493
3494async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
3495    let db = session.db().await;
3496    let flags = db.get_user_flags(session.user_id).await?;
3497    if flags.iter().any(|flag| flag == "language-models") {
3498        Ok(())
3499    } else {
3500        Err(anyhow!("permission denied"))?
3501    }
3502}
3503
3504/// Start receiving chat updates for a channel
3505async fn join_channel_chat(
3506    request: proto::JoinChannelChat,
3507    response: Response<proto::JoinChannelChat>,
3508    session: Session,
3509) -> Result<()> {
3510    let channel_id = ChannelId::from_proto(request.channel_id);
3511
3512    let db = session.db().await;
3513    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3514        .await?;
3515    let messages = db
3516        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3517        .await?;
3518    response.send(proto::JoinChannelChatResponse {
3519        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3520        messages,
3521    })?;
3522    Ok(())
3523}
3524
3525/// Stop receiving chat updates for a channel
3526async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3527    let channel_id = ChannelId::from_proto(request.channel_id);
3528    session
3529        .db()
3530        .await
3531        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3532        .await?;
3533    Ok(())
3534}
3535
3536/// Retrieve the chat history for a channel
3537async fn get_channel_messages(
3538    request: proto::GetChannelMessages,
3539    response: Response<proto::GetChannelMessages>,
3540    session: Session,
3541) -> Result<()> {
3542    let channel_id = ChannelId::from_proto(request.channel_id);
3543    let messages = session
3544        .db()
3545        .await
3546        .get_channel_messages(
3547            channel_id,
3548            session.user_id,
3549            MESSAGE_COUNT_PER_PAGE,
3550            Some(MessageId::from_proto(request.before_message_id)),
3551        )
3552        .await?;
3553    response.send(proto::GetChannelMessagesResponse {
3554        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3555        messages,
3556    })?;
3557    Ok(())
3558}
3559
3560/// Retrieve specific chat messages
3561async fn get_channel_messages_by_id(
3562    request: proto::GetChannelMessagesById,
3563    response: Response<proto::GetChannelMessagesById>,
3564    session: Session,
3565) -> Result<()> {
3566    let message_ids = request
3567        .message_ids
3568        .iter()
3569        .map(|id| MessageId::from_proto(*id))
3570        .collect::<Vec<_>>();
3571    let messages = session
3572        .db()
3573        .await
3574        .get_channel_messages_by_id(session.user_id, &message_ids)
3575        .await?;
3576    response.send(proto::GetChannelMessagesResponse {
3577        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3578        messages,
3579    })?;
3580    Ok(())
3581}
3582
3583/// Retrieve the current users notifications
3584async fn get_notifications(
3585    request: proto::GetNotifications,
3586    response: Response<proto::GetNotifications>,
3587    session: Session,
3588) -> Result<()> {
3589    let notifications = session
3590        .db()
3591        .await
3592        .get_notifications(
3593            session.user_id,
3594            NOTIFICATION_COUNT_PER_PAGE,
3595            request
3596                .before_id
3597                .map(|id| db::NotificationId::from_proto(id)),
3598        )
3599        .await?;
3600    response.send(proto::GetNotificationsResponse {
3601        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3602        notifications,
3603    })?;
3604    Ok(())
3605}
3606
3607/// Mark notifications as read
3608async fn mark_notification_as_read(
3609    request: proto::MarkNotificationRead,
3610    response: Response<proto::MarkNotificationRead>,
3611    session: Session,
3612) -> Result<()> {
3613    let database = &session.db().await;
3614    let notifications = database
3615        .mark_notification_as_read_by_id(
3616            session.user_id,
3617            NotificationId::from_proto(request.notification_id),
3618        )
3619        .await?;
3620    send_notifications(
3621        &*session.connection_pool().await,
3622        &session.peer,
3623        notifications,
3624    );
3625    response.send(proto::Ack {})?;
3626    Ok(())
3627}
3628
3629/// Get the current users information
3630async fn get_private_user_info(
3631    _request: proto::GetPrivateUserInfo,
3632    response: Response<proto::GetPrivateUserInfo>,
3633    session: Session,
3634) -> Result<()> {
3635    let db = session.db().await;
3636
3637    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3638    let user = db
3639        .get_user_by_id(session.user_id)
3640        .await?
3641        .ok_or_else(|| anyhow!("user not found"))?;
3642    let flags = db.get_user_flags(session.user_id).await?;
3643
3644    response.send(proto::GetPrivateUserInfoResponse {
3645        metrics_id,
3646        staff: user.admin,
3647        flags,
3648    })?;
3649    Ok(())
3650}
3651
3652fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3653    match message {
3654        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3655        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3656        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3657        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3658        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3659            code: frame.code.into(),
3660            reason: frame.reason,
3661        })),
3662    }
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666    match message {
3667        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3668        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3669        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3670        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3671        AxumMessage::Close(frame) => {
3672            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673                code: frame.code.into(),
3674                reason: frame.reason,
3675            }))
3676        }
3677    }
3678}
3679
3680fn notify_membership_updated(
3681    connection_pool: &mut ConnectionPool,
3682    result: MembershipUpdated,
3683    user_id: UserId,
3684    peer: &Peer,
3685) {
3686    for membership in &result.new_channels.channel_memberships {
3687        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3688    }
3689    for channel_id in &result.removed_channels {
3690        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3691    }
3692
3693    let user_channels_update = proto::UpdateUserChannels {
3694        channel_memberships: result
3695            .new_channels
3696            .channel_memberships
3697            .iter()
3698            .map(|cm| proto::ChannelMembership {
3699                channel_id: cm.channel_id.to_proto(),
3700                role: cm.role.into(),
3701            })
3702            .collect(),
3703        ..Default::default()
3704    };
3705
3706    let mut update = build_channels_update(result.new_channels, vec![]);
3707    update.delete_channels = result
3708        .removed_channels
3709        .into_iter()
3710        .map(|id| id.to_proto())
3711        .collect();
3712    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3713
3714    for connection_id in connection_pool.user_connection_ids(user_id) {
3715        peer.send(connection_id, user_channels_update.clone())
3716            .trace_err();
3717        peer.send(connection_id, update.clone()).trace_err();
3718    }
3719}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722    proto::UpdateUserChannels {
3723        channel_memberships: channels
3724            .channel_memberships
3725            .iter()
3726            .map(|m| proto::ChannelMembership {
3727                channel_id: m.channel_id.to_proto(),
3728                role: m.role.into(),
3729            })
3730            .collect(),
3731        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732        observed_channel_message_id: channels.observed_channel_messages.clone(),
3733    }
3734}
3735
3736fn build_channels_update(
3737    channels: ChannelsForUser,
3738    channel_invites: Vec<db::Channel>,
3739) -> proto::UpdateChannels {
3740    let mut update = proto::UpdateChannels::default();
3741
3742    for channel in channels.channels {
3743        update.channels.push(channel.to_proto());
3744    }
3745
3746    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3747    update.latest_channel_message_ids = channels.latest_channel_messages;
3748
3749    for (channel_id, participants) in channels.channel_participants {
3750        update
3751            .channel_participants
3752            .push(proto::ChannelParticipants {
3753                channel_id: channel_id.to_proto(),
3754                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3755            });
3756    }
3757
3758    for channel in channel_invites {
3759        update.channel_invitations.push(channel.to_proto());
3760    }
3761    for project in channels.hosted_projects {
3762        update.hosted_projects.push(project);
3763    }
3764
3765    update
3766}
3767
3768fn build_initial_contacts_update(
3769    contacts: Vec<db::Contact>,
3770    pool: &ConnectionPool,
3771) -> proto::UpdateContacts {
3772    let mut update = proto::UpdateContacts::default();
3773
3774    for contact in contacts {
3775        match contact {
3776            db::Contact::Accepted { user_id, busy } => {
3777                update.contacts.push(contact_for_user(user_id, busy, &pool));
3778            }
3779            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3780            db::Contact::Incoming { user_id } => {
3781                update
3782                    .incoming_requests
3783                    .push(proto::IncomingContactRequest {
3784                        requester_id: user_id.to_proto(),
3785                    })
3786            }
3787        }
3788    }
3789
3790    update
3791}
3792
3793fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3794    proto::Contact {
3795        user_id: user_id.to_proto(),
3796        online: pool.is_user_online(user_id),
3797        busy,
3798    }
3799}
3800
3801fn room_updated(room: &proto::Room, peer: &Peer) {
3802    broadcast(
3803        None,
3804        room.participants
3805            .iter()
3806            .filter_map(|participant| Some(participant.peer_id?.into())),
3807        |peer_id| {
3808            peer.send(
3809                peer_id,
3810                proto::RoomUpdated {
3811                    room: Some(room.clone()),
3812                },
3813            )
3814        },
3815    );
3816}
3817
3818fn channel_updated(
3819    channel: &db::channel::Model,
3820    room: &proto::Room,
3821    peer: &Peer,
3822    pool: &ConnectionPool,
3823) {
3824    let participants = room
3825        .participants
3826        .iter()
3827        .map(|p| p.user_id)
3828        .collect::<Vec<_>>();
3829
3830    broadcast(
3831        None,
3832        pool.channel_connection_ids(channel.root_id())
3833            .filter_map(|(channel_id, role)| {
3834                role.can_see_channel(channel.visibility).then(|| channel_id)
3835            }),
3836        |peer_id| {
3837            peer.send(
3838                peer_id,
3839                proto::UpdateChannels {
3840                    channel_participants: vec![proto::ChannelParticipants {
3841                        channel_id: channel.id.to_proto(),
3842                        participant_user_ids: participants.clone(),
3843                    }],
3844                    ..Default::default()
3845                },
3846            )
3847        },
3848    );
3849}
3850
3851async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3852    let db = session.db().await;
3853
3854    let contacts = db.get_contacts(user_id).await?;
3855    let busy = db.is_user_busy(user_id).await?;
3856
3857    let pool = session.connection_pool().await;
3858    let updated_contact = contact_for_user(user_id, busy, &pool);
3859    for contact in contacts {
3860        if let db::Contact::Accepted {
3861            user_id: contact_user_id,
3862            ..
3863        } = contact
3864        {
3865            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3866                session
3867                    .peer
3868                    .send(
3869                        contact_conn_id,
3870                        proto::UpdateContacts {
3871                            contacts: vec![updated_contact.clone()],
3872                            remove_contacts: Default::default(),
3873                            incoming_requests: Default::default(),
3874                            remove_incoming_requests: Default::default(),
3875                            outgoing_requests: Default::default(),
3876                            remove_outgoing_requests: Default::default(),
3877                        },
3878                    )
3879                    .trace_err();
3880            }
3881        }
3882    }
3883    Ok(())
3884}
3885
/// Removes the session's connection from the room it is currently in and
/// propagates the consequences to other peers.
///
/// If the database reports that the connection was not in a room, this does
/// nothing. Otherwise it:
/// - notifies the collaborators of every project the connection left;
/// - broadcasts the updated room and, for channel rooms, the channel's
///   participant list;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes the contact status of this user and of those canceled callees;
/// - removes the user from the LiveKit room, and deletes the room when
///   `left_room.deleted` is set.
///
/// LiveKit failures are logged via `trace_err` rather than propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact status must be re-broadcast once the room is left.
    let mut contacts_to_update = HashSet::default();

    // Values extracted from the left room. Each is assigned exactly once in
    // the `if let` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room`
        // before the room itself, so the `room` broadcast below carries a
        // default (empty) `live_kit_room`. Confirm clients don't rely on it.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in any room; nothing to clean up.
        return Ok(());
    }

    // Channel rooms also refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The connection-pool guard is scoped to this block, so it is released
    // before `update_user_contacts` (which acquires the pool itself) runs.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3959
3960async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3961    let left_channel_buffers = session
3962        .db()
3963        .await
3964        .leave_channel_buffers(session.connection_id)
3965        .await?;
3966
3967    for left_buffer in left_channel_buffers {
3968        channel_buffer_updated(
3969            session.connection_id,
3970            left_buffer.connections,
3971            &proto::UpdateChannelBufferCollaborators {
3972                channel_id: left_buffer.channel_id.to_proto(),
3973                collaborators: left_buffer.collaborators,
3974            },
3975            &session.peer,
3976        );
3977    }
3978
3979    Ok(())
3980}
3981
3982fn project_left(project: &db::LeftProject, session: &Session) {
3983    for connection_id in &project.connection_ids {
3984        if project.host_user_id == Some(session.user_id) {
3985            session
3986                .peer
3987                .send(
3988                    *connection_id,
3989                    proto::UnshareProject {
3990                        project_id: project.id.to_proto(),
3991                    },
3992                )
3993                .trace_err();
3994        } else {
3995            session
3996                .peer
3997                .send(
3998                    *connection_id,
3999                    proto::RemoveProjectCollaborator {
4000                        project_id: project.id.to_proto(),
4001                        peer_id: Some(session.connection_id.into()),
4002                    },
4003                )
4004                .trace_err();
4005        }
4006    }
4007}
4008
4009pub trait ResultExt {
4010    type Ok;
4011
4012    fn trace_err(self) -> Option<Self::Ok>;
4013}
4014
4015impl<T, E> ResultExt for Result<T, E>
4016where
4017    E: std::fmt::Debug,
4018{
4019    type Ok = T;
4020
4021    fn trace_err(self) -> Option<T> {
4022        match self {
4023            Ok(value) => Some(value),
4024            Err(error) => {
4025                tracing::error!("{:?}", error);
4026                None
4027            }
4028        }
4029    }
4030}