rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
   7        Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
   8        ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
   9        UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    future::Future,
  53    marker::PhantomData,
  54    mem,
  55    net::SocketAddr,
  56    ops::{Deref, DerefMut},
  57    rc::Rc,
  58    sync::{
  59        atomic::{AtomicBool, Ordering::SeqCst},
  60        Arc, OnceLock,
  61    },
  62    time::{Duration, Instant},
  63};
  64use time::OffsetDateTime;
  65use tokio::sync::{watch, Semaphore};
  66use tower::ServiceBuilder;
  67use tracing::{field, info_span, instrument, Instrument};
  68use util::{http::IsahcHttpClient, SemanticVersion};
  69
/// Timeout associated with client reconnection.
// NOTE(review): not referenced in the visible part of this file — confirm the exact usage.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before it begins clearing stale rooms and channel buffers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in the visible part of this file;
// presumably they bound channel-message paging, message body length, and notification
// paging respectively — verify against the handlers that use them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  78
/// Type-erased message handler as stored in `Server::handlers` (keyed by the message's
/// `TypeId`): it receives the boxed envelope plus the connection's `Session` and returns a
/// boxed future that resolves once handling has finished.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  81
  82struct Response<R> {
  83    peer: Arc<Peer>,
  84    receipt: Receipt<R>,
  85    responded: Arc<AtomicBool>,
  86}
  87
  88impl<R: RequestMessage> Response<R> {
  89    fn send(self, payload: R::Response) -> Result<()> {
  90        self.responded.store(true, SeqCst);
  91        self.peer.respond(self.receipt, payload)?;
  92        Ok(())
  93    }
  94}
  95
  96struct StreamingResponse<R: RequestMessage> {
  97    peer: Arc<Peer>,
  98    receipt: Receipt<R>,
  99}
 100
 101impl<R: RequestMessage> StreamingResponse<R> {
 102    fn send(&self, payload: R::Response) -> Result<()> {
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone)]
 109struct Session {
 110    user_id: UserId,
 111    connection_id: ConnectionId,
 112    db: Arc<tokio::sync::Mutex<DbHandle>>,
 113    peer: Arc<Peer>,
 114    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 115    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 116    http_client: IsahcHttpClient,
 117    rate_limiter: Arc<RateLimiter>,
 118    _executor: Executor,
 119}
 120
 121impl Session {
 122    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 123        #[cfg(test)]
 124        tokio::task::yield_now().await;
 125        let guard = self.db.lock().await;
 126        #[cfg(test)]
 127        tokio::task::yield_now().await;
 128        guard
 129    }
 130
 131    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 132        #[cfg(test)]
 133        tokio::task::yield_now().await;
 134        let guard = self.connection_pool.lock();
 135        ConnectionPoolGuard {
 136            guard,
 137            _not_send: PhantomData,
 138        }
 139    }
 140}
 141
 142impl Debug for Session {
 143    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 144        f.debug_struct("Session")
 145            .field("user_id", &self.user_id)
 146            .field("connection_id", &self.connection_id)
 147            .finish()
 148    }
 149}
 150
 151struct DbHandle(Arc<Database>);
 152
 153impl Deref for DbHandle {
 154    type Target = Database;
 155
 156    fn deref(&self) -> &Self::Target {
 157        self.0.as_ref()
 158    }
 159}
 160
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because test builds can replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown()` sends `true`; `handle_connection` subscribes to it.
    teardown: watch::Sender<bool>,
}
 169
/// Wrapper around the lock guard for the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 174
/// Serializable view of the server: the peer plus the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the guard's deref target is what gets
    // written out rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 181
 182pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 183where
 184    S: Serializer,
 185    T: Deref<Target = U>,
 186    U: Serialize,
 187{
 188    Serialize::serialize(value.deref(), serializer)
 189}
 190
 191impl Server {
    /// Creates a server with the given id and application state, registers every RPC
    /// message/request handler on it, and returns it wrapped in an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers end up registered for the same message type (the check
    /// lives in `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; connections obtain receivers via `subscribe`.
            teardown: watch::channel(false).0,
        };

        // `add_request_handler` is used for messages that expect a reply, and
        // `add_message_handler` for those that do not.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // The two language-model handlers capture their own clone of `app_state` so
            // they can pass the configured API keys along on every call.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 324
    /// Spawns a detached background task that, once `CLEANUP_TIMEOUT` has elapsed, clears
    /// stale channel-buffer collaborators and stale room participants, notifies the
    /// affected connections, and finally deletes stale server records.
    ///
    /// The method itself only schedules that task and then returns `Ok(())`; failures
    /// inside the task are passed through `trace_err` and the affected step is skipped.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators, then send the refreshed
                    // collaborator list to every connection the database reports.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their empty/false defaults when refreshing the room
                        // fails, which turns the follow-up steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of every affected user to each
                        // connection of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale resource ids could not be retrieved.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 470
 471    pub fn teardown(&self) {
 472        self.peer.teardown();
 473        self.connection_pool.lock().reset();
 474        let _ = self.teardown.send(true);
 475    }
 476
 477    #[cfg(test)]
 478    pub fn reset(&self, id: ServerId) {
 479        self.teardown();
 480        *self.id.lock() = id;
 481        self.peer.reset(id.0 as u32);
 482        let _ = self.teardown.send(false);
 483    }
 484
 485    #[cfg(test)]
 486    pub fn id(&self) -> ServerId {
 487        *self.id.lock()
 488    }
 489
 490    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 491    where
 492        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 493        Fut: 'static + Send + Future<Output = Result<()>>,
 494        M: EnvelopedMessage,
 495    {
 496        let prev_handler = self.handlers.insert(
 497            TypeId::of::<M>(),
 498            Box::new(move |envelope, session| {
 499                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 500                let received_at = envelope.received_at;
 501                    tracing::info!(
 502                        "message received"
 503                    );
 504                let start_time = Instant::now();
 505                let future = (handler)(*envelope, session);
 506                async move {
 507                    let result = future.await;
 508                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 509                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 510                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 511                    match result {
 512                        Err(error) => {
 513                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 514                        }
 515                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 516                    }
 517                }
 518                .boxed()
 519            }),
 520        );
 521        if prev_handler.is_some() {
 522            panic!("registered a handler for the same message twice");
 523        }
 524        self
 525    }
 526
 527    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 528    where
 529        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 530        Fut: 'static + Send + Future<Output = Result<()>>,
 531        M: EnvelopedMessage,
 532    {
 533        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 534        self
 535    }
 536
 537    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 538    where
 539        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 540        Fut: Send + Future<Output = Result<()>>,
 541        M: RequestMessage,
 542    {
 543        let handler = Arc::new(handler);
 544        self.add_handler(move |envelope, session| {
 545            let receipt = envelope.receipt();
 546            let handler = handler.clone();
 547            async move {
 548                let peer = session.peer.clone();
 549                let responded = Arc::new(AtomicBool::default());
 550                let response = Response {
 551                    peer: peer.clone(),
 552                    responded: responded.clone(),
 553                    receipt,
 554                };
 555                match (handler)(envelope.payload, response, session).await {
 556                    Ok(()) => {
 557                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 558                            Ok(())
 559                        } else {
 560                            Err(anyhow!("handler did not send a response"))?
 561                        }
 562                    }
 563                    Err(error) => {
 564                        let proto_err = match &error {
 565                            Error::Internal(err) => err.to_proto(),
 566                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 567                        };
 568                        peer.respond_with_error(receipt, proto_err)?;
 569                        Err(error)
 570                    }
 571                }
 572            }
 573        })
 574    }
 575
 576    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 577    where
 578        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 579        Fut: Send + Future<Output = Result<()>>,
 580        M: RequestMessage,
 581    {
 582        let handler = Arc::new(handler);
 583        self.add_handler(move |envelope, session| {
 584            let receipt = envelope.receipt();
 585            let handler = handler.clone();
 586            async move {
 587                let peer = session.peer.clone();
 588                let response = StreamingResponse {
 589                    peer: peer.clone(),
 590                    receipt,
 591                };
 592                match (handler)(envelope.payload, response, session).await {
 593                    Ok(()) => {
 594                        peer.end_stream(receipt)?;
 595                        Ok(())
 596                    }
 597                    Err(error) => {
 598                        let proto_err = match &error {
 599                            Error::Internal(err) => err.to_proto(),
 600                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 601                        };
 602                        peer.respond_with_error(receipt, proto_err)?;
 603                        Err(error)
 604                    }
 605                }
 606            }
 607        })
 608    }
 609
 610    #[allow(clippy::too_many_arguments)]
 611    pub fn handle_connection(
 612        self: &Arc<Self>,
 613        connection: Connection,
 614        address: String,
 615        user: User,
 616        zed_version: ZedVersion,
 617        impersonator: Option<User>,
 618        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 619        executor: Executor,
 620    ) -> impl Future<Output = ()> {
 621        let this = self.clone();
 622        let user_id = user.id;
 623        let login = user.github_login.clone();
 624        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
 625        if let Some(impersonator) = impersonator {
 626            span.record("impersonator", &impersonator.github_login);
 627        }
 628        let mut teardown = self.teardown.subscribe();
 629        async move {
 630            if *teardown.borrow() {
 631                tracing::error!("server is tearing down");
 632                return
 633            }
 634            let (connection_id, handle_io, mut incoming_rx) = this
 635                .peer
 636                .add_connection(connection, {
 637                    let executor = executor.clone();
 638                    move |duration| executor.sleep(duration)
 639                });
 640            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 641            tracing::info!("connection opened");
 642
 643            let http_client = match IsahcHttpClient::new() {
 644                Ok(http_client) => http_client,
 645                Err(error) => {
 646                    tracing::error!(?error, "failed to create HTTP client");
 647                    return;
 648                }
 649            };
 650
 651            let session = Session {
 652                user_id,
 653                connection_id,
 654                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 655                peer: this.peer.clone(),
 656                connection_pool: this.connection_pool.clone(),
 657                live_kit_client: this.app_state.live_kit_client.clone(),
 658                http_client,
 659                rate_limiter: this.app_state.rate_limiter.clone(),
 660                _executor: executor.clone(),
 661            };
 662
 663            if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
 664                tracing::error!(?error, "failed to send initial client update");
 665                return;
 666            }
 667
 668            let handle_io = handle_io.fuse();
 669            futures::pin_mut!(handle_io);
 670
 671            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 672            // This prevents deadlocks when e.g., client A performs a request to client B and
 673            // client B performs a request to client A. If both clients stop processing further
 674            // messages until their respective request completes, they won't have a chance to
 675            // respond to the other client's request and cause a deadlock.
 676            //
 677            // This arrangement ensures we will attempt to process earlier messages first, but fall
 678            // back to processing messages arrived later in the spirit of making progress.
 679            let mut foreground_message_handlers = FuturesUnordered::new();
 680            let concurrent_handlers = Arc::new(Semaphore::new(256));
 681            loop {
 682                let next_message = async {
 683                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 684                    let message = incoming_rx.next().await;
 685                    (permit, message)
 686                }.fuse();
 687                futures::pin_mut!(next_message);
 688                futures::select_biased! {
 689                    _ = teardown.changed().fuse() => return,
 690                    result = handle_io => {
 691                        if let Err(error) = result {
 692                            tracing::error!(?error, "error handling I/O");
 693                        }
 694                        break;
 695                    }
 696                    _ = foreground_message_handlers.next() => {}
 697                    next_message = next_message => {
 698                        let (permit, message) = next_message;
 699                        if let Some(message) = message {
 700                            let type_name = message.payload_type_name();
 701                            // note: we copy all the fields from the parent span so we can query them in the logs.
 702                            // (https://github.com/tokio-rs/tracing/issues/2670).
 703                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 704                            let span_enter = span.enter();
 705                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 706                                let is_background = message.is_background();
 707                                let handle_message = (handler)(message, session.clone());
 708                                drop(span_enter);
 709
 710                                let handle_message = async move {
 711                                    handle_message.await;
 712                                    drop(permit);
 713                                }.instrument(span);
 714                                if is_background {
 715                                    executor.spawn_detached(handle_message);
 716                                } else {
 717                                    foreground_message_handlers.push(handle_message);
 718                                }
 719                            } else {
 720                                tracing::error!("no message handler");
 721                            }
 722                        } else {
 723                            tracing::info!("connection closed");
 724                            break;
 725                        }
 726                    }
 727                }
 728            }
 729
 730            drop(foreground_message_handlers);
 731            tracing::info!("signing out");
 732            if let Err(error) = connection_lost(session, teardown, executor).await {
 733                tracing::error!(?error, "error signing out");
 734            }
 735
 736        }.instrument(span)
 737    }
 738
 739    async fn send_initial_client_update(
 740        &self,
 741        connection_id: ConnectionId,
 742        user: User,
 743        zed_version: ZedVersion,
 744        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 745        session: &Session,
 746    ) -> Result<()> {
 747        self.peer.send(
 748            connection_id,
 749            proto::Hello {
 750                peer_id: Some(connection_id.into()),
 751            },
 752        )?;
 753        tracing::info!("sent hello message");
 754
 755        if let Some(send_connection_id) = send_connection_id.take() {
 756            let _ = send_connection_id.send(connection_id);
 757        }
 758
 759        if !user.connected_once {
 760            self.peer.send(connection_id, proto::ShowContacts {})?;
 761            self.app_state
 762                .db
 763                .set_user_connected_once(user.id, true)
 764                .await?;
 765        }
 766
 767        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 768            self.app_state.db.get_contacts(user.id),
 769            self.app_state.db.get_channels_for_user(user.id),
 770            self.app_state.db.get_channel_invites_for_user(user.id),
 771        )
 772        .await?;
 773
 774        {
 775            let mut pool = self.connection_pool.lock();
 776            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 777            for membership in &channels_for_user.channel_memberships {
 778                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 779            }
 780            self.peer.send(
 781                connection_id,
 782                build_initial_contacts_update(contacts, &pool),
 783            )?;
 784            self.peer.send(
 785                connection_id,
 786                build_update_user_channels(&channels_for_user),
 787            )?;
 788            self.peer.send(
 789                connection_id,
 790                build_channels_update(channels_for_user, channel_invites),
 791            )?;
 792        }
 793
 794        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 795            self.peer.send(connection_id, incoming_call)?;
 796        }
 797
 798        update_user_contacts(user.id, &session).await?;
 799        Ok(())
 800    }
 801
 802    pub async fn invite_code_redeemed(
 803        self: &Arc<Self>,
 804        inviter_id: UserId,
 805        invitee_id: UserId,
 806    ) -> Result<()> {
 807        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 808            if let Some(code) = &user.invite_code {
 809                let pool = self.connection_pool.lock();
 810                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 811                for connection_id in pool.user_connection_ids(inviter_id) {
 812                    self.peer.send(
 813                        connection_id,
 814                        proto::UpdateContacts {
 815                            contacts: vec![invitee_contact.clone()],
 816                            ..Default::default()
 817                        },
 818                    )?;
 819                    self.peer.send(
 820                        connection_id,
 821                        proto::UpdateInviteInfo {
 822                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 823                            count: user.invite_count as u32,
 824                        },
 825                    )?;
 826                }
 827            }
 828        }
 829        Ok(())
 830    }
 831
 832    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 833        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 834            if let Some(invite_code) = &user.invite_code {
 835                let pool = self.connection_pool.lock();
 836                for connection_id in pool.user_connection_ids(user_id) {
 837                    self.peer.send(
 838                        connection_id,
 839                        proto::UpdateInviteInfo {
 840                            url: format!(
 841                                "{}{}",
 842                                self.app_state.config.invite_link_prefix, invite_code
 843                            ),
 844                            count: user.invite_count as u32,
 845                        },
 846                    )?;
 847                }
 848            }
 849        }
 850        Ok(())
 851    }
 852
 853    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 854        ServerSnapshot {
 855            connection_pool: ConnectionPoolGuard {
 856                guard: self.connection_pool.lock(),
 857                _not_send: PhantomData,
 858            },
 859            peer: &self.peer,
 860        }
 861    }
 862}
 863
 864impl<'a> Deref for ConnectionPoolGuard<'a> {
 865    type Target = ConnectionPool;
 866
 867    fn deref(&self) -> &Self::Target {
 868        &self.guard
 869    }
 870}
 871
 872impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 873    fn deref_mut(&mut self) -> &mut Self::Target {
 874        &mut self.guard
 875    }
 876}
 877
 878impl<'a> Drop for ConnectionPoolGuard<'a> {
 879    fn drop(&mut self) {
 880        #[cfg(test)]
 881        self.check_invariants();
 882    }
 883}
 884
 885fn broadcast<F>(
 886    sender_id: Option<ConnectionId>,
 887    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 888    mut f: F,
 889) where
 890    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 891{
 892    for receiver_id in receiver_ids {
 893        if Some(receiver_id) != sender_id {
 894            if let Err(error) = f(receiver_id) {
 895                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 896            }
 897        }
 898    }
 899}
 900
 901pub struct ProtocolVersion(u32);
 902
 903impl Header for ProtocolVersion {
 904    fn name() -> &'static HeaderName {
 905        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 906        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 907    }
 908
 909    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 910    where
 911        Self: Sized,
 912        I: Iterator<Item = &'i axum::http::HeaderValue>,
 913    {
 914        let version = values
 915            .next()
 916            .ok_or_else(axum::headers::Error::invalid)?
 917            .to_str()
 918            .map_err(|_| axum::headers::Error::invalid())?
 919            .parse()
 920            .map_err(|_| axum::headers::Error::invalid())?;
 921        Ok(Self(version))
 922    }
 923
 924    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 925        values.extend([self.0.to_string().parse().unwrap()]);
 926    }
 927}
 928
 929pub struct AppVersionHeader(SemanticVersion);
 930impl Header for AppVersionHeader {
 931    fn name() -> &'static HeaderName {
 932        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
 933        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
 934    }
 935
 936    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 937    where
 938        Self: Sized,
 939        I: Iterator<Item = &'i axum::http::HeaderValue>,
 940    {
 941        let version = values
 942            .next()
 943            .ok_or_else(axum::headers::Error::invalid)?
 944            .to_str()
 945            .map_err(|_| axum::headers::Error::invalid())?
 946            .parse()
 947            .map_err(|_| axum::headers::Error::invalid())?;
 948        Ok(Self(version))
 949    }
 950
 951    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 952        values.extend([self.0.to_string().parse().unwrap()]);
 953    }
 954}
 955
 956pub fn routes(server: Arc<Server>) -> Router<(), Body> {
 957    Router::new()
 958        .route("/rpc", get(handle_websocket_request))
 959        .layer(
 960            ServiceBuilder::new()
 961                .layer(Extension(server.app_state.clone()))
 962                .layer(middleware::from_fn(auth::validate_header)),
 963        )
 964        .route("/metrics", get(handle_metrics))
 965        .layer(Extension(server))
 966}
 967
 968pub async fn handle_websocket_request(
 969    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 970    app_version_header: Option<TypedHeader<AppVersionHeader>>,
 971    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 972    Extension(server): Extension<Arc<Server>>,
 973    Extension(user): Extension<User>,
 974    Extension(impersonator): Extension<Impersonator>,
 975    ws: WebSocketUpgrade,
 976) -> axum::response::Response {
 977    if protocol_version != rpc::PROTOCOL_VERSION {
 978        return (
 979            StatusCode::UPGRADE_REQUIRED,
 980            "client must be upgraded".to_string(),
 981        )
 982            .into_response();
 983    }
 984
 985    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
 986        return (
 987            StatusCode::UPGRADE_REQUIRED,
 988            "no version header found".to_string(),
 989        )
 990            .into_response();
 991    };
 992
 993    if !version.can_collaborate() {
 994        return (
 995            StatusCode::UPGRADE_REQUIRED,
 996            "client must be upgraded".to_string(),
 997        )
 998            .into_response();
 999    }
1000
1001    let socket_address = socket_address.to_string();
1002    ws.on_upgrade(move |socket| {
1003        let socket = socket
1004            .map_ok(to_tungstenite_message)
1005            .err_into()
1006            .with(|message| async move { Ok(to_axum_message(message)) });
1007        let connection = Connection::new(Box::pin(socket));
1008        async move {
1009            server
1010                .handle_connection(
1011                    connection,
1012                    socket_address,
1013                    user,
1014                    version,
1015                    impersonator.0,
1016                    None,
1017                    Executor::Production,
1018                )
1019                .await;
1020        }
1021    })
1022}
1023
1024pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1025    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1026    let connections_metric = CONNECTIONS_METRIC
1027        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1028
1029    let connections = server
1030        .connection_pool
1031        .lock()
1032        .connections()
1033        .filter(|connection| !connection.admin)
1034        .count();
1035    connections_metric.set(connections as _);
1036
1037    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1038    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1039        register_int_gauge!(
1040            "shared_projects",
1041            "number of open projects with one or more guests"
1042        )
1043        .unwrap()
1044    });
1045
1046    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1047    shared_projects_metric.set(shared_projects as _);
1048
1049    let encoder = prometheus::TextEncoder::new();
1050    let metric_families = prometheus::gather();
1051    let encoded_metrics = encoder
1052        .encode_to_string(&metric_families)
1053        .map_err(|err| anyhow!("{}", err))?;
1054    Ok(encoded_metrics)
1055}
1056
1057#[instrument(err, skip(executor))]
1058async fn connection_lost(
1059    session: Session,
1060    mut teardown: watch::Receiver<bool>,
1061    executor: Executor,
1062) -> Result<()> {
1063    session.peer.disconnect(session.connection_id);
1064    session
1065        .connection_pool()
1066        .await
1067        .remove_connection(session.connection_id)?;
1068
1069    session
1070        .db()
1071        .await
1072        .connection_lost(session.connection_id)
1073        .await
1074        .trace_err();
1075
1076    futures::select_biased! {
1077        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1078            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
1079            leave_room_for_session(&session).await.trace_err();
1080            leave_channel_buffers_for_session(&session)
1081                .await
1082                .trace_err();
1083
1084            if !session
1085                .connection_pool()
1086                .await
1087                .is_user_online(session.user_id)
1088            {
1089                let db = session.db().await;
1090                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
1091                    room_updated(&room, &session.peer);
1092                }
1093            }
1094
1095            update_user_contacts(session.user_id, &session).await?;
1096        }
1097        _ = teardown.changed().fuse() => {}
1098    }
1099
1100    Ok(())
1101}
1102
1103/// Acknowledges a ping from a client, used to keep the connection alive.
1104async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1105    response.send(proto::Ack {})?;
1106    Ok(())
1107}
1108
1109/// Creates a new room for calling (outside of channels)
1110async fn create_room(
1111    _request: proto::CreateRoom,
1112    response: Response<proto::CreateRoom>,
1113    session: Session,
1114) -> Result<()> {
1115    let live_kit_room = nanoid::nanoid!(30);
1116
1117    let live_kit_connection_info = {
1118        let live_kit_room = live_kit_room.clone();
1119        let live_kit = session.live_kit_client.as_ref();
1120
1121        util::async_maybe!({
1122            let live_kit = live_kit?;
1123
1124            let token = live_kit
1125                .room_token(&live_kit_room, &session.user_id.to_string())
1126                .trace_err()?;
1127
1128            Some(proto::LiveKitConnectionInfo {
1129                server_url: live_kit.url().into(),
1130                token,
1131                can_publish: true,
1132            })
1133        })
1134    }
1135    .await;
1136
1137    let room = session
1138        .db()
1139        .await
1140        .create_room(session.user_id, session.connection_id, &live_kit_room)
1141        .await?;
1142
1143    response.send(proto::CreateRoomResponse {
1144        room: Some(room.clone()),
1145        live_kit_connection_info,
1146    })?;
1147
1148    update_user_contacts(session.user_id, &session).await?;
1149    Ok(())
1150}
1151
1152/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1153async fn join_room(
1154    request: proto::JoinRoom,
1155    response: Response<proto::JoinRoom>,
1156    session: Session,
1157) -> Result<()> {
1158    let room_id = RoomId::from_proto(request.id);
1159
1160    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1161
1162    if let Some(channel_id) = channel_id {
1163        return join_channel_internal(channel_id, Box::new(response), session).await;
1164    }
1165
1166    let joined_room = {
1167        let room = session
1168            .db()
1169            .await
1170            .join_room(room_id, session.user_id, session.connection_id)
1171            .await?;
1172        room_updated(&room.room, &session.peer);
1173        room.into_inner()
1174    };
1175
1176    for connection_id in session
1177        .connection_pool()
1178        .await
1179        .user_connection_ids(session.user_id)
1180    {
1181        session
1182            .peer
1183            .send(
1184                connection_id,
1185                proto::CallCanceled {
1186                    room_id: room_id.to_proto(),
1187                },
1188            )
1189            .trace_err();
1190    }
1191
1192    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1193        if let Some(token) = live_kit
1194            .room_token(
1195                &joined_room.room.live_kit_room,
1196                &session.user_id.to_string(),
1197            )
1198            .trace_err()
1199        {
1200            Some(proto::LiveKitConnectionInfo {
1201                server_url: live_kit.url().into(),
1202                token,
1203                can_publish: true,
1204            })
1205        } else {
1206            None
1207        }
1208    } else {
1209        None
1210    };
1211
1212    response.send(proto::JoinRoomResponse {
1213        room: Some(joined_room.room),
1214        channel_id: None,
1215        live_kit_connection_info,
1216    })?;
1217
1218    update_user_contacts(session.user_id, &session).await?;
1219    Ok(())
1220}
1221
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this user and connection. The
/// function then:
/// 1. responds with the room plus the reshared and rejoined projects;
/// 2. announces the room update;
/// 3. tells the collaborators of each reshared and rejoined project that the
///    old peer id has been replaced by this connection's id;
/// 4. forwards an `UpdateProject` to the collaborators of each reshared project;
/// 5. streams worktree entries, diagnostic summaries, settings files and
///    language server updates for each rejoined project back to this connection;
/// 6. announces the channel update when the room has a channel, and refreshes
///    the user's contacts.
///
/// Sends to other collaborators go through `trace_err` and do not abort the
/// request; sends to this connection propagate their error with `?`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the scope below so that `rejoined_room` is gone before
    // the channel update and contact refresh at the end.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: replace the old peer id with the new connection id
        // for every collaborator, then forward the project's worktree list.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // `broadcast` skips this session's own connection id.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: the same peer id replacement notice.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Send the rejoined projects' state back to this connection.
        for project in &mut rejoined_room.rejoined_projects {
            // `mem::take` moves the worktrees out, leaving the project's list empty.
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a chunk size of 2; other builds use 256.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                // NOTE(review): if `split_worktree_update` yields owned messages,
                // the `clone()` below is unnecessary — confirm before removing.
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // One `DiskBasedDiagnosticsUpdated` message per language server of the project.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Extract the owned values needed after this scope ends.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1410
1411/// leave room disconnects from the room.
1412async fn leave_room(
1413    _: proto::LeaveRoom,
1414    response: Response<proto::LeaveRoom>,
1415    session: Session,
1416) -> Result<()> {
1417    leave_room_for_session(&session).await?;
1418    response.send(proto::Ack {})?;
1419    Ok(())
1420}
1421
1422/// Updates the permissions of someone else in the room.
1423async fn set_room_participant_role(
1424    request: proto::SetRoomParticipantRole,
1425    response: Response<proto::SetRoomParticipantRole>,
1426    session: Session,
1427) -> Result<()> {
1428    let user_id = UserId::from_proto(request.user_id);
1429    let role = ChannelRole::from(request.role());
1430
1431    let (live_kit_room, can_publish) = {
1432        let room = session
1433            .db()
1434            .await
1435            .set_room_participant_role(
1436                session.user_id,
1437                RoomId::from_proto(request.room_id),
1438                user_id,
1439                role,
1440            )
1441            .await?;
1442
1443        let live_kit_room = room.live_kit_room.clone();
1444        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1445        room_updated(&room, &session.peer);
1446        (live_kit_room, can_publish)
1447    };
1448
1449    if let Some(live_kit) = session.live_kit_client.as_ref() {
1450        live_kit
1451            .update_participant(
1452                live_kit_room.clone(),
1453                request.user_id.to_string(),
1454                live_kit_server::proto::ParticipantPermission {
1455                    can_subscribe: true,
1456                    can_publish,
1457                    can_publish_data: can_publish,
1458                    hidden: false,
1459                    recorder: false,
1460                },
1461            )
1462            .await
1463            .trace_err();
1464    }
1465
1466    response.send(proto::Ack {})?;
1467    Ok(())
1468}
1469
1470/// Call someone else into the current room
1471async fn call(
1472    request: proto::Call,
1473    response: Response<proto::Call>,
1474    session: Session,
1475) -> Result<()> {
1476    let room_id = RoomId::from_proto(request.room_id);
1477    let calling_user_id = session.user_id;
1478    let calling_connection_id = session.connection_id;
1479    let called_user_id = UserId::from_proto(request.called_user_id);
1480    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1481    if !session
1482        .db()
1483        .await
1484        .has_contact(calling_user_id, called_user_id)
1485        .await?
1486    {
1487        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1488    }
1489
1490    let incoming_call = {
1491        let (room, incoming_call) = &mut *session
1492            .db()
1493            .await
1494            .call(
1495                room_id,
1496                calling_user_id,
1497                calling_connection_id,
1498                called_user_id,
1499                initial_project_id,
1500            )
1501            .await?;
1502        room_updated(&room, &session.peer);
1503        mem::take(incoming_call)
1504    };
1505    update_user_contacts(called_user_id, &session).await?;
1506
1507    let mut calls = session
1508        .connection_pool()
1509        .await
1510        .user_connection_ids(called_user_id)
1511        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1512        .collect::<FuturesUnordered<_>>();
1513
1514    while let Some(call_response) = calls.next().await {
1515        match call_response.as_ref() {
1516            Ok(_) => {
1517                response.send(proto::Ack {})?;
1518                return Ok(());
1519            }
1520            Err(_) => {
1521                call_response.trace_err();
1522            }
1523        }
1524    }
1525
1526    {
1527        let room = session
1528            .db()
1529            .await
1530            .call_failed(room_id, called_user_id)
1531            .await?;
1532        room_updated(&room, &session.peer);
1533    }
1534    update_user_contacts(called_user_id, &session).await?;
1535
1536    Err(anyhow!("failed to ring user"))?
1537}
1538
1539/// Cancel an outgoing call.
1540async fn cancel_call(
1541    request: proto::CancelCall,
1542    response: Response<proto::CancelCall>,
1543    session: Session,
1544) -> Result<()> {
1545    let called_user_id = UserId::from_proto(request.called_user_id);
1546    let room_id = RoomId::from_proto(request.room_id);
1547    {
1548        let room = session
1549            .db()
1550            .await
1551            .cancel_call(room_id, session.connection_id, called_user_id)
1552            .await?;
1553        room_updated(&room, &session.peer);
1554    }
1555
1556    for connection_id in session
1557        .connection_pool()
1558        .await
1559        .user_connection_ids(called_user_id)
1560    {
1561        session
1562            .peer
1563            .send(
1564                connection_id,
1565                proto::CallCanceled {
1566                    room_id: room_id.to_proto(),
1567                },
1568            )
1569            .trace_err();
1570    }
1571    response.send(proto::Ack {})?;
1572
1573    update_user_contacts(called_user_id, &session).await?;
1574    Ok(())
1575}
1576
1577/// Decline an incoming call.
1578async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1579    let room_id = RoomId::from_proto(message.room_id);
1580    {
1581        let room = session
1582            .db()
1583            .await
1584            .decline_call(Some(room_id), session.user_id)
1585            .await?
1586            .ok_or_else(|| anyhow!("failed to decline call"))?;
1587        room_updated(&room, &session.peer);
1588    }
1589
1590    for connection_id in session
1591        .connection_pool()
1592        .await
1593        .user_connection_ids(session.user_id)
1594    {
1595        session
1596            .peer
1597            .send(
1598                connection_id,
1599                proto::CallCanceled {
1600                    room_id: room_id.to_proto(),
1601                },
1602            )
1603            .trace_err();
1604    }
1605    update_user_contacts(session.user_id, &session).await?;
1606    Ok(())
1607}
1608
1609/// Updates other participants in the room with your current location.
1610async fn update_participant_location(
1611    request: proto::UpdateParticipantLocation,
1612    response: Response<proto::UpdateParticipantLocation>,
1613    session: Session,
1614) -> Result<()> {
1615    let room_id = RoomId::from_proto(request.room_id);
1616    let location = request
1617        .location
1618        .ok_or_else(|| anyhow!("invalid location"))?;
1619
1620    let db = session.db().await;
1621    let room = db
1622        .update_room_participant_location(room_id, session.connection_id, location)
1623        .await?;
1624
1625    room_updated(&room, &session.peer);
1626    response.send(proto::Ack {})?;
1627    Ok(())
1628}
1629
1630/// Share a project into the room.
1631async fn share_project(
1632    request: proto::ShareProject,
1633    response: Response<proto::ShareProject>,
1634    session: Session,
1635) -> Result<()> {
1636    let (project_id, room) = &*session
1637        .db()
1638        .await
1639        .share_project(
1640            RoomId::from_proto(request.room_id),
1641            session.connection_id,
1642            &request.worktrees,
1643        )
1644        .await?;
1645    response.send(proto::ShareProjectResponse {
1646        project_id: project_id.to_proto(),
1647    })?;
1648    room_updated(&room, &session.peer);
1649
1650    Ok(())
1651}
1652
1653/// Unshare a project from the room.
1654async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1655    let project_id = ProjectId::from_proto(message.project_id);
1656
1657    let (room, guest_connection_ids) = &*session
1658        .db()
1659        .await
1660        .unshare_project(project_id, session.connection_id)
1661        .await?;
1662
1663    broadcast(
1664        Some(session.connection_id),
1665        guest_connection_ids.iter().copied(),
1666        |conn_id| session.peer.send(conn_id, message.clone()),
1667    );
1668    room_updated(&room, &session.peer);
1669
1670    Ok(())
1671}
1672
1673/// Join someone elses shared project.
1674async fn join_project(
1675    request: proto::JoinProject,
1676    response: Response<proto::JoinProject>,
1677    session: Session,
1678) -> Result<()> {
1679    let project_id = ProjectId::from_proto(request.project_id);
1680
1681    tracing::info!(%project_id, "join project");
1682
1683    let (project, replica_id) = &mut *session
1684        .db()
1685        .await
1686        .join_project_in_room(project_id, session.connection_id)
1687        .await?;
1688
1689    join_project_internal(response, session, project, replica_id)
1690}
1691
1692trait JoinProjectInternalResponse {
1693    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1694}
1695impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1696    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1697        Response::<proto::JoinProject>::send(self, result)
1698    }
1699}
1700impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1701    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1702        Response::<proto::JoinHostedProject>::send(self, result)
1703    }
1704}
1705
1706fn join_project_internal(
1707    response: impl JoinProjectInternalResponse,
1708    session: Session,
1709    project: &mut Project,
1710    replica_id: &ReplicaId,
1711) -> Result<()> {
1712    let collaborators = project
1713        .collaborators
1714        .iter()
1715        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1716        .map(|collaborator| collaborator.to_proto())
1717        .collect::<Vec<_>>();
1718    let project_id = project.id;
1719    let guest_user_id = session.user_id;
1720
1721    let worktrees = project
1722        .worktrees
1723        .iter()
1724        .map(|(id, worktree)| proto::WorktreeMetadata {
1725            id: *id,
1726            root_name: worktree.root_name.clone(),
1727            visible: worktree.visible,
1728            abs_path: worktree.abs_path.clone(),
1729        })
1730        .collect::<Vec<_>>();
1731
1732    for collaborator in &collaborators {
1733        session
1734            .peer
1735            .send(
1736                collaborator.peer_id.unwrap().into(),
1737                proto::AddProjectCollaborator {
1738                    project_id: project_id.to_proto(),
1739                    collaborator: Some(proto::Collaborator {
1740                        peer_id: Some(session.connection_id.into()),
1741                        replica_id: replica_id.0 as u32,
1742                        user_id: guest_user_id.to_proto(),
1743                    }),
1744                },
1745            )
1746            .trace_err();
1747    }
1748
1749    // First, we send the metadata associated with each worktree.
1750    response.send(proto::JoinProjectResponse {
1751        project_id: project.id.0 as u64,
1752        worktrees: worktrees.clone(),
1753        replica_id: replica_id.0 as u32,
1754        collaborators: collaborators.clone(),
1755        language_servers: project.language_servers.clone(),
1756        role: project.role.into(), // todo
1757    })?;
1758
1759    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1760        #[cfg(any(test, feature = "test-support"))]
1761        const MAX_CHUNK_SIZE: usize = 2;
1762        #[cfg(not(any(test, feature = "test-support")))]
1763        const MAX_CHUNK_SIZE: usize = 256;
1764
1765        // Stream this worktree's entries.
1766        let message = proto::UpdateWorktree {
1767            project_id: project_id.to_proto(),
1768            worktree_id,
1769            abs_path: worktree.abs_path.clone(),
1770            root_name: worktree.root_name,
1771            updated_entries: worktree.entries,
1772            removed_entries: Default::default(),
1773            scan_id: worktree.scan_id,
1774            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1775            updated_repositories: worktree.repository_entries.into_values().collect(),
1776            removed_repositories: Default::default(),
1777        };
1778        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1779            session.peer.send(session.connection_id, update.clone())?;
1780        }
1781
1782        // Stream this worktree's diagnostics.
1783        for summary in worktree.diagnostic_summaries {
1784            session.peer.send(
1785                session.connection_id,
1786                proto::UpdateDiagnosticSummary {
1787                    project_id: project_id.to_proto(),
1788                    worktree_id: worktree.id,
1789                    summary: Some(summary),
1790                },
1791            )?;
1792        }
1793
1794        for settings_file in worktree.settings_files {
1795            session.peer.send(
1796                session.connection_id,
1797                proto::UpdateWorktreeSettings {
1798                    project_id: project_id.to_proto(),
1799                    worktree_id: worktree.id,
1800                    path: settings_file.path,
1801                    content: Some(settings_file.content),
1802                },
1803            )?;
1804        }
1805    }
1806
1807    for language_server in &project.language_servers {
1808        session.peer.send(
1809            session.connection_id,
1810            proto::UpdateLanguageServer {
1811                project_id: project_id.to_proto(),
1812                language_server_id: language_server.id,
1813                variant: Some(
1814                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1815                        proto::LspDiskBasedDiagnosticsUpdated {},
1816                    ),
1817                ),
1818            },
1819        )?;
1820    }
1821
1822    Ok(())
1823}
1824
1825/// Leave someone elses shared project.
1826async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1827    let sender_id = session.connection_id;
1828    let project_id = ProjectId::from_proto(request.project_id);
1829    let db = session.db().await;
1830    if db.is_hosted_project(project_id).await? {
1831        let project = db.leave_hosted_project(project_id, sender_id).await?;
1832        project_left(&project, &session);
1833        return Ok(());
1834    }
1835
1836    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1837    tracing::info!(
1838        %project_id,
1839        host_user_id = ?project.host_user_id,
1840        host_connection_id = ?project.host_connection_id,
1841        "leave project"
1842    );
1843
1844    project_left(&project, &session);
1845    room_updated(&room, &session.peer);
1846
1847    Ok(())
1848}
1849
1850async fn join_hosted_project(
1851    request: proto::JoinHostedProject,
1852    response: Response<proto::JoinHostedProject>,
1853    session: Session,
1854) -> Result<()> {
1855    let (mut project, replica_id) = session
1856        .db()
1857        .await
1858        .join_hosted_project(
1859            ProjectId(request.project_id as i32),
1860            session.user_id,
1861            session.connection_id,
1862        )
1863        .await?;
1864
1865    join_project_internal(response, session, &mut project, &replica_id)
1866}
1867
1868/// Updates other participants with changes to the project
1869async fn update_project(
1870    request: proto::UpdateProject,
1871    response: Response<proto::UpdateProject>,
1872    session: Session,
1873) -> Result<()> {
1874    let project_id = ProjectId::from_proto(request.project_id);
1875    let (room, guest_connection_ids) = &*session
1876        .db()
1877        .await
1878        .update_project(project_id, session.connection_id, &request.worktrees)
1879        .await?;
1880    broadcast(
1881        Some(session.connection_id),
1882        guest_connection_ids.iter().copied(),
1883        |connection_id| {
1884            session
1885                .peer
1886                .forward_send(session.connection_id, connection_id, request.clone())
1887        },
1888    );
1889    room_updated(&room, &session.peer);
1890    response.send(proto::Ack {})?;
1891
1892    Ok(())
1893}
1894
1895/// Updates other participants with changes to the worktree
1896async fn update_worktree(
1897    request: proto::UpdateWorktree,
1898    response: Response<proto::UpdateWorktree>,
1899    session: Session,
1900) -> Result<()> {
1901    let guest_connection_ids = session
1902        .db()
1903        .await
1904        .update_worktree(&request, session.connection_id)
1905        .await?;
1906
1907    broadcast(
1908        Some(session.connection_id),
1909        guest_connection_ids.iter().copied(),
1910        |connection_id| {
1911            session
1912                .peer
1913                .forward_send(session.connection_id, connection_id, request.clone())
1914        },
1915    );
1916    response.send(proto::Ack {})?;
1917    Ok(())
1918}
1919
1920/// Updates other participants with changes to the diagnostics
1921async fn update_diagnostic_summary(
1922    message: proto::UpdateDiagnosticSummary,
1923    session: Session,
1924) -> Result<()> {
1925    let guest_connection_ids = session
1926        .db()
1927        .await
1928        .update_diagnostic_summary(&message, session.connection_id)
1929        .await?;
1930
1931    broadcast(
1932        Some(session.connection_id),
1933        guest_connection_ids.iter().copied(),
1934        |connection_id| {
1935            session
1936                .peer
1937                .forward_send(session.connection_id, connection_id, message.clone())
1938        },
1939    );
1940
1941    Ok(())
1942}
1943
1944/// Updates other participants with changes to the worktree settings
1945async fn update_worktree_settings(
1946    message: proto::UpdateWorktreeSettings,
1947    session: Session,
1948) -> Result<()> {
1949    let guest_connection_ids = session
1950        .db()
1951        .await
1952        .update_worktree_settings(&message, session.connection_id)
1953        .await?;
1954
1955    broadcast(
1956        Some(session.connection_id),
1957        guest_connection_ids.iter().copied(),
1958        |connection_id| {
1959            session
1960                .peer
1961                .forward_send(session.connection_id, connection_id, message.clone())
1962        },
1963    );
1964
1965    Ok(())
1966}
1967
1968/// Notify other participants that a  language server has started.
1969async fn start_language_server(
1970    request: proto::StartLanguageServer,
1971    session: Session,
1972) -> Result<()> {
1973    let guest_connection_ids = session
1974        .db()
1975        .await
1976        .start_language_server(&request, session.connection_id)
1977        .await?;
1978
1979    broadcast(
1980        Some(session.connection_id),
1981        guest_connection_ids.iter().copied(),
1982        |connection_id| {
1983            session
1984                .peer
1985                .forward_send(session.connection_id, connection_id, request.clone())
1986        },
1987    );
1988    Ok(())
1989}
1990
1991/// Notify other participants that a language server has changed.
1992async fn update_language_server(
1993    request: proto::UpdateLanguageServer,
1994    session: Session,
1995) -> Result<()> {
1996    let project_id = ProjectId::from_proto(request.project_id);
1997    let project_connection_ids = session
1998        .db()
1999        .await
2000        .project_connection_ids(project_id, session.connection_id)
2001        .await?;
2002    broadcast(
2003        Some(session.connection_id),
2004        project_connection_ids.iter().copied(),
2005        |connection_id| {
2006            session
2007                .peer
2008                .forward_send(session.connection_id, connection_id, request.clone())
2009        },
2010    );
2011    Ok(())
2012}
2013
2014/// forward a project request to the host. These requests should be read only
2015/// as guests are allowed to send them.
2016async fn forward_read_only_project_request<T>(
2017    request: T,
2018    response: Response<T>,
2019    session: Session,
2020) -> Result<()>
2021where
2022    T: EntityMessage + RequestMessage,
2023{
2024    let project_id = ProjectId::from_proto(request.remote_entity_id());
2025    let host_connection_id = session
2026        .db()
2027        .await
2028        .host_for_read_only_project_request(project_id, session.connection_id)
2029        .await?;
2030    let payload = session
2031        .peer
2032        .forward_request(session.connection_id, host_connection_id, request)
2033        .await?;
2034    response.send(payload)?;
2035    Ok(())
2036}
2037
2038/// forward a project request to the host. These requests are disallowed
2039/// for guests.
2040async fn forward_mutating_project_request<T>(
2041    request: T,
2042    response: Response<T>,
2043    session: Session,
2044) -> Result<()>
2045where
2046    T: EntityMessage + RequestMessage,
2047{
2048    let project_id = ProjectId::from_proto(request.remote_entity_id());
2049    let host_connection_id = session
2050        .db()
2051        .await
2052        .host_for_mutating_project_request(project_id, session.connection_id)
2053        .await?;
2054    let payload = session
2055        .peer
2056        .forward_request(session.connection_id, host_connection_id, request)
2057        .await?;
2058    response.send(payload)?;
2059    Ok(())
2060}
2061
2062/// Notify other participants that a new buffer has been created
2063async fn create_buffer_for_peer(
2064    request: proto::CreateBufferForPeer,
2065    session: Session,
2066) -> Result<()> {
2067    session
2068        .db()
2069        .await
2070        .check_user_is_project_host(
2071            ProjectId::from_proto(request.project_id),
2072            session.connection_id,
2073        )
2074        .await?;
2075    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2076    session
2077        .peer
2078        .forward_send(session.connection_id, peer_id.into(), request)?;
2079    Ok(())
2080}
2081
2082/// Notify other participants that a buffer has been updated. This is
2083/// allowed for guests as long as the update is limited to selections.
2084async fn update_buffer(
2085    request: proto::UpdateBuffer,
2086    response: Response<proto::UpdateBuffer>,
2087    session: Session,
2088) -> Result<()> {
2089    let project_id = ProjectId::from_proto(request.project_id);
2090    let mut guest_connection_ids;
2091    let mut host_connection_id = None;
2092
2093    let mut requires_write_permission = false;
2094
2095    for op in request.operations.iter() {
2096        match op.variant {
2097            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2098            Some(_) => requires_write_permission = true,
2099        }
2100    }
2101
2102    {
2103        let collaborators = session
2104            .db()
2105            .await
2106            .project_collaborators_for_buffer_update(
2107                project_id,
2108                session.connection_id,
2109                requires_write_permission,
2110            )
2111            .await?;
2112        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2113        for collaborator in collaborators.iter() {
2114            if collaborator.is_host {
2115                host_connection_id = Some(collaborator.connection_id);
2116            } else {
2117                guest_connection_ids.push(collaborator.connection_id);
2118            }
2119        }
2120    }
2121    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2122
2123    broadcast(
2124        Some(session.connection_id),
2125        guest_connection_ids,
2126        |connection_id| {
2127            session
2128                .peer
2129                .forward_send(session.connection_id, connection_id, request.clone())
2130        },
2131    );
2132    if host_connection_id != session.connection_id {
2133        session
2134            .peer
2135            .forward_request(session.connection_id, host_connection_id, request.clone())
2136            .await?;
2137    }
2138
2139    response.send(proto::Ack {})?;
2140    Ok(())
2141}
2142
2143/// Notify other participants that a project has been updated.
2144async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2145    request: T,
2146    session: Session,
2147) -> Result<()> {
2148    let project_id = ProjectId::from_proto(request.remote_entity_id());
2149    let project_connection_ids = session
2150        .db()
2151        .await
2152        .project_connection_ids(project_id, session.connection_id)
2153        .await?;
2154
2155    broadcast(
2156        Some(session.connection_id),
2157        project_connection_ids.iter().copied(),
2158        |connection_id| {
2159            session
2160                .peer
2161                .forward_send(session.connection_id, connection_id, request.clone())
2162        },
2163    );
2164    Ok(())
2165}
2166
2167/// Start following another user in a call.
2168async fn follow(
2169    request: proto::Follow,
2170    response: Response<proto::Follow>,
2171    session: Session,
2172) -> Result<()> {
2173    let room_id = RoomId::from_proto(request.room_id);
2174    let project_id = request.project_id.map(ProjectId::from_proto);
2175    let leader_id = request
2176        .leader_id
2177        .ok_or_else(|| anyhow!("invalid leader id"))?
2178        .into();
2179    let follower_id = session.connection_id;
2180
2181    session
2182        .db()
2183        .await
2184        .check_room_participants(room_id, leader_id, session.connection_id)
2185        .await?;
2186
2187    let response_payload = session
2188        .peer
2189        .forward_request(session.connection_id, leader_id, request)
2190        .await?;
2191    response.send(response_payload)?;
2192
2193    if let Some(project_id) = project_id {
2194        let room = session
2195            .db()
2196            .await
2197            .follow(room_id, project_id, leader_id, follower_id)
2198            .await?;
2199        room_updated(&room, &session.peer);
2200    }
2201
2202    Ok(())
2203}
2204
2205/// Stop following another user in a call.
2206async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2207    let room_id = RoomId::from_proto(request.room_id);
2208    let project_id = request.project_id.map(ProjectId::from_proto);
2209    let leader_id = request
2210        .leader_id
2211        .ok_or_else(|| anyhow!("invalid leader id"))?
2212        .into();
2213    let follower_id = session.connection_id;
2214
2215    session
2216        .db()
2217        .await
2218        .check_room_participants(room_id, leader_id, session.connection_id)
2219        .await?;
2220
2221    session
2222        .peer
2223        .forward_send(session.connection_id, leader_id, request)?;
2224
2225    if let Some(project_id) = project_id {
2226        let room = session
2227            .db()
2228            .await
2229            .unfollow(room_id, project_id, leader_id, follower_id)
2230            .await?;
2231        room_updated(&room, &session.peer);
2232    }
2233
2234    Ok(())
2235}
2236
2237/// Notify everyone following you of your current location.
2238async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2239    let room_id = RoomId::from_proto(request.room_id);
2240    let database = session.db.lock().await;
2241
2242    let connection_ids = if let Some(project_id) = request.project_id {
2243        let project_id = ProjectId::from_proto(project_id);
2244        database
2245            .project_connection_ids(project_id, session.connection_id)
2246            .await?
2247    } else {
2248        database
2249            .room_connection_ids(room_id, session.connection_id)
2250            .await?
2251    };
2252
2253    // For now, don't send view update messages back to that view's current leader.
2254    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2255        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2256        _ => None,
2257    });
2258
2259    for connection_id in connection_ids.iter().cloned() {
2260        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2261            session
2262                .peer
2263                .forward_send(session.connection_id, connection_id, request.clone())?;
2264        }
2265    }
2266    Ok(())
2267}
2268
2269/// Get public data about users.
2270async fn get_users(
2271    request: proto::GetUsers,
2272    response: Response<proto::GetUsers>,
2273    session: Session,
2274) -> Result<()> {
2275    let user_ids = request
2276        .user_ids
2277        .into_iter()
2278        .map(UserId::from_proto)
2279        .collect();
2280    let users = session
2281        .db()
2282        .await
2283        .get_users_by_ids(user_ids)
2284        .await?
2285        .into_iter()
2286        .map(|user| proto::User {
2287            id: user.id.to_proto(),
2288            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2289            github_login: user.github_login,
2290        })
2291        .collect();
2292    response.send(proto::UsersResponse { users })?;
2293    Ok(())
2294}
2295
2296/// Search for users (to invite) buy Github login
2297async fn fuzzy_search_users(
2298    request: proto::FuzzySearchUsers,
2299    response: Response<proto::FuzzySearchUsers>,
2300    session: Session,
2301) -> Result<()> {
2302    let query = request.query;
2303    let users = match query.len() {
2304        0 => vec![],
2305        1 | 2 => session
2306            .db()
2307            .await
2308            .get_user_by_github_login(&query)
2309            .await?
2310            .into_iter()
2311            .collect(),
2312        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2313    };
2314    let users = users
2315        .into_iter()
2316        .filter(|user| user.id != session.user_id)
2317        .map(|user| proto::User {
2318            id: user.id.to_proto(),
2319            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2320            github_login: user.github_login,
2321        })
2322        .collect();
2323    response.send(proto::UsersResponse { users })?;
2324    Ok(())
2325}
2326
2327/// Send a contact request to another user.
2328async fn request_contact(
2329    request: proto::RequestContact,
2330    response: Response<proto::RequestContact>,
2331    session: Session,
2332) -> Result<()> {
2333    let requester_id = session.user_id;
2334    let responder_id = UserId::from_proto(request.responder_id);
2335    if requester_id == responder_id {
2336        return Err(anyhow!("cannot add yourself as a contact"))?;
2337    }
2338
2339    let notifications = session
2340        .db()
2341        .await
2342        .send_contact_request(requester_id, responder_id)
2343        .await?;
2344
2345    // Update outgoing contact requests of requester
2346    let mut update = proto::UpdateContacts::default();
2347    update.outgoing_requests.push(responder_id.to_proto());
2348    for connection_id in session
2349        .connection_pool()
2350        .await
2351        .user_connection_ids(requester_id)
2352    {
2353        session.peer.send(connection_id, update.clone())?;
2354    }
2355
2356    // Update incoming contact requests of responder
2357    let mut update = proto::UpdateContacts::default();
2358    update
2359        .incoming_requests
2360        .push(proto::IncomingContactRequest {
2361            requester_id: requester_id.to_proto(),
2362        });
2363    let connection_pool = session.connection_pool().await;
2364    for connection_id in connection_pool.user_connection_ids(responder_id) {
2365        session.peer.send(connection_id, update.clone())?;
2366    }
2367
2368    send_notifications(&connection_pool, &session.peer, notifications);
2369
2370    response.send(proto::Ack {})?;
2371    Ok(())
2372}
2373
2374/// Accept or decline a contact request
2375async fn respond_to_contact_request(
2376    request: proto::RespondToContactRequest,
2377    response: Response<proto::RespondToContactRequest>,
2378    session: Session,
2379) -> Result<()> {
2380    let responder_id = session.user_id;
2381    let requester_id = UserId::from_proto(request.requester_id);
2382    let db = session.db().await;
2383    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2384        db.dismiss_contact_notification(responder_id, requester_id)
2385            .await?;
2386    } else {
2387        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2388
2389        let notifications = db
2390            .respond_to_contact_request(responder_id, requester_id, accept)
2391            .await?;
2392        let requester_busy = db.is_user_busy(requester_id).await?;
2393        let responder_busy = db.is_user_busy(responder_id).await?;
2394
2395        let pool = session.connection_pool().await;
2396        // Update responder with new contact
2397        let mut update = proto::UpdateContacts::default();
2398        if accept {
2399            update
2400                .contacts
2401                .push(contact_for_user(requester_id, requester_busy, &pool));
2402        }
2403        update
2404            .remove_incoming_requests
2405            .push(requester_id.to_proto());
2406        for connection_id in pool.user_connection_ids(responder_id) {
2407            session.peer.send(connection_id, update.clone())?;
2408        }
2409
2410        // Update requester with new contact
2411        let mut update = proto::UpdateContacts::default();
2412        if accept {
2413            update
2414                .contacts
2415                .push(contact_for_user(responder_id, responder_busy, &pool));
2416        }
2417        update
2418            .remove_outgoing_requests
2419            .push(responder_id.to_proto());
2420
2421        for connection_id in pool.user_connection_ids(requester_id) {
2422            session.peer.send(connection_id, update.clone())?;
2423        }
2424
2425        send_notifications(&pool, &session.peer, notifications);
2426    }
2427
2428    response.send(proto::Ack {})?;
2429    Ok(())
2430}
2431
2432/// Remove a contact.
2433async fn remove_contact(
2434    request: proto::RemoveContact,
2435    response: Response<proto::RemoveContact>,
2436    session: Session,
2437) -> Result<()> {
2438    let requester_id = session.user_id;
2439    let responder_id = UserId::from_proto(request.user_id);
2440    let db = session.db().await;
2441    let (contact_accepted, deleted_notification_id) =
2442        db.remove_contact(requester_id, responder_id).await?;
2443
2444    let pool = session.connection_pool().await;
2445    // Update outgoing contact requests of requester
2446    let mut update = proto::UpdateContacts::default();
2447    if contact_accepted {
2448        update.remove_contacts.push(responder_id.to_proto());
2449    } else {
2450        update
2451            .remove_outgoing_requests
2452            .push(responder_id.to_proto());
2453    }
2454    for connection_id in pool.user_connection_ids(requester_id) {
2455        session.peer.send(connection_id, update.clone())?;
2456    }
2457
2458    // Update incoming contact requests of responder
2459    let mut update = proto::UpdateContacts::default();
2460    if contact_accepted {
2461        update.remove_contacts.push(requester_id.to_proto());
2462    } else {
2463        update
2464            .remove_incoming_requests
2465            .push(requester_id.to_proto());
2466    }
2467    for connection_id in pool.user_connection_ids(responder_id) {
2468        session.peer.send(connection_id, update.clone())?;
2469        if let Some(notification_id) = deleted_notification_id {
2470            session.peer.send(
2471                connection_id,
2472                proto::DeleteNotification {
2473                    notification_id: notification_id.to_proto(),
2474                },
2475            )?;
2476        }
2477    }
2478
2479    response.send(proto::Ack {})?;
2480    Ok(())
2481}
2482
2483/// Creates a new channel.
2484async fn create_channel(
2485    request: proto::CreateChannel,
2486    response: Response<proto::CreateChannel>,
2487    session: Session,
2488) -> Result<()> {
2489    let db = session.db().await;
2490
2491    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2492    let (channel, membership) = db
2493        .create_channel(&request.name, parent_id, session.user_id)
2494        .await?;
2495
2496    let root_id = channel.root_id();
2497    let channel = Channel::from_model(channel);
2498
2499    response.send(proto::CreateChannelResponse {
2500        channel: Some(channel.to_proto()),
2501        parent_id: request.parent_id,
2502    })?;
2503
2504    let mut connection_pool = session.connection_pool().await;
2505    if let Some(membership) = membership {
2506        connection_pool.subscribe_to_channel(
2507            membership.user_id,
2508            membership.channel_id,
2509            membership.role,
2510        );
2511        let update = proto::UpdateUserChannels {
2512            channel_memberships: vec![proto::ChannelMembership {
2513                channel_id: membership.channel_id.to_proto(),
2514                role: membership.role.into(),
2515            }],
2516            ..Default::default()
2517        };
2518        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2519            session.peer.send(connection_id, update.clone())?;
2520        }
2521    }
2522
2523    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2524        if !role.can_see_channel(channel.visibility) {
2525            continue;
2526        }
2527
2528        let update = proto::UpdateChannels {
2529            channels: vec![channel.to_proto()],
2530            ..Default::default()
2531        };
2532        session.peer.send(connection_id, update.clone())?;
2533    }
2534
2535    Ok(())
2536}
2537
2538/// Delete a channel
2539async fn delete_channel(
2540    request: proto::DeleteChannel,
2541    response: Response<proto::DeleteChannel>,
2542    session: Session,
2543) -> Result<()> {
2544    let db = session.db().await;
2545
2546    let channel_id = request.channel_id;
2547    let (root_channel, removed_channels) = db
2548        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2549        .await?;
2550    response.send(proto::Ack {})?;
2551
2552    // Notify members of removed channels
2553    let mut update = proto::UpdateChannels::default();
2554    update
2555        .delete_channels
2556        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2557
2558    let connection_pool = session.connection_pool().await;
2559    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2560        session.peer.send(connection_id, update.clone())?;
2561    }
2562
2563    Ok(())
2564}
2565
2566/// Invite someone to join a channel.
2567async fn invite_channel_member(
2568    request: proto::InviteChannelMember,
2569    response: Response<proto::InviteChannelMember>,
2570    session: Session,
2571) -> Result<()> {
2572    let db = session.db().await;
2573    let channel_id = ChannelId::from_proto(request.channel_id);
2574    let invitee_id = UserId::from_proto(request.user_id);
2575    let InviteMemberResult {
2576        channel,
2577        notifications,
2578    } = db
2579        .invite_channel_member(
2580            channel_id,
2581            invitee_id,
2582            session.user_id,
2583            request.role().into(),
2584        )
2585        .await?;
2586
2587    let update = proto::UpdateChannels {
2588        channel_invitations: vec![channel.to_proto()],
2589        ..Default::default()
2590    };
2591
2592    let connection_pool = session.connection_pool().await;
2593    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2594        session.peer.send(connection_id, update.clone())?;
2595    }
2596
2597    send_notifications(&connection_pool, &session.peer, notifications);
2598
2599    response.send(proto::Ack {})?;
2600    Ok(())
2601}
2602
2603/// remove someone from a channel
2604async fn remove_channel_member(
2605    request: proto::RemoveChannelMember,
2606    response: Response<proto::RemoveChannelMember>,
2607    session: Session,
2608) -> Result<()> {
2609    let db = session.db().await;
2610    let channel_id = ChannelId::from_proto(request.channel_id);
2611    let member_id = UserId::from_proto(request.user_id);
2612
2613    let RemoveChannelMemberResult {
2614        membership_update,
2615        notification_id,
2616    } = db
2617        .remove_channel_member(channel_id, member_id, session.user_id)
2618        .await?;
2619
2620    let mut connection_pool = session.connection_pool().await;
2621    notify_membership_updated(
2622        &mut connection_pool,
2623        membership_update,
2624        member_id,
2625        &session.peer,
2626    );
2627    for connection_id in connection_pool.user_connection_ids(member_id) {
2628        if let Some(notification_id) = notification_id {
2629            session
2630                .peer
2631                .send(
2632                    connection_id,
2633                    proto::DeleteNotification {
2634                        notification_id: notification_id.to_proto(),
2635                    },
2636                )
2637                .trace_err();
2638        }
2639    }
2640
2641    response.send(proto::Ack {})?;
2642    Ok(())
2643}
2644
2645/// Toggle the channel between public and private.
2646/// Care is taken to maintain the invariant that public channels only descend from public channels,
2647/// (though members-only channels can appear at any point in the hierarchy).
2648async fn set_channel_visibility(
2649    request: proto::SetChannelVisibility,
2650    response: Response<proto::SetChannelVisibility>,
2651    session: Session,
2652) -> Result<()> {
2653    let db = session.db().await;
2654    let channel_id = ChannelId::from_proto(request.channel_id);
2655    let visibility = request.visibility().into();
2656
2657    let channel_model = db
2658        .set_channel_visibility(channel_id, visibility, session.user_id)
2659        .await?;
2660    let root_id = channel_model.root_id();
2661    let channel = Channel::from_model(channel_model);
2662
2663    let mut connection_pool = session.connection_pool().await;
2664    for (user_id, role) in connection_pool
2665        .channel_user_ids(root_id)
2666        .collect::<Vec<_>>()
2667        .into_iter()
2668    {
2669        let update = if role.can_see_channel(channel.visibility) {
2670            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2671            proto::UpdateChannels {
2672                channels: vec![channel.to_proto()],
2673                ..Default::default()
2674            }
2675        } else {
2676            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2677            proto::UpdateChannels {
2678                delete_channels: vec![channel.id.to_proto()],
2679                ..Default::default()
2680            }
2681        };
2682
2683        for connection_id in connection_pool.user_connection_ids(user_id) {
2684            session.peer.send(connection_id, update.clone())?;
2685        }
2686    }
2687
2688    response.send(proto::Ack {})?;
2689    Ok(())
2690}
2691
2692/// Alter the role for a user in the channel.
2693async fn set_channel_member_role(
2694    request: proto::SetChannelMemberRole,
2695    response: Response<proto::SetChannelMemberRole>,
2696    session: Session,
2697) -> Result<()> {
2698    let db = session.db().await;
2699    let channel_id = ChannelId::from_proto(request.channel_id);
2700    let member_id = UserId::from_proto(request.user_id);
2701    let result = db
2702        .set_channel_member_role(
2703            channel_id,
2704            session.user_id,
2705            member_id,
2706            request.role().into(),
2707        )
2708        .await?;
2709
2710    match result {
2711        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2712            let mut connection_pool = session.connection_pool().await;
2713            notify_membership_updated(
2714                &mut connection_pool,
2715                membership_update,
2716                member_id,
2717                &session.peer,
2718            )
2719        }
2720        db::SetMemberRoleResult::InviteUpdated(channel) => {
2721            let update = proto::UpdateChannels {
2722                channel_invitations: vec![channel.to_proto()],
2723                ..Default::default()
2724            };
2725
2726            for connection_id in session
2727                .connection_pool()
2728                .await
2729                .user_connection_ids(member_id)
2730            {
2731                session.peer.send(connection_id, update.clone())?;
2732            }
2733        }
2734    }
2735
2736    response.send(proto::Ack {})?;
2737    Ok(())
2738}
2739
2740/// Change the name of a channel
2741async fn rename_channel(
2742    request: proto::RenameChannel,
2743    response: Response<proto::RenameChannel>,
2744    session: Session,
2745) -> Result<()> {
2746    let db = session.db().await;
2747    let channel_id = ChannelId::from_proto(request.channel_id);
2748    let channel_model = db
2749        .rename_channel(channel_id, session.user_id, &request.name)
2750        .await?;
2751    let root_id = channel_model.root_id();
2752    let channel = Channel::from_model(channel_model);
2753
2754    response.send(proto::RenameChannelResponse {
2755        channel: Some(channel.to_proto()),
2756    })?;
2757
2758    let connection_pool = session.connection_pool().await;
2759    let update = proto::UpdateChannels {
2760        channels: vec![channel.to_proto()],
2761        ..Default::default()
2762    };
2763    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2764        if role.can_see_channel(channel.visibility) {
2765            session.peer.send(connection_id, update.clone())?;
2766        }
2767    }
2768
2769    Ok(())
2770}
2771
2772/// Move a channel to a new parent.
2773async fn move_channel(
2774    request: proto::MoveChannel,
2775    response: Response<proto::MoveChannel>,
2776    session: Session,
2777) -> Result<()> {
2778    let channel_id = ChannelId::from_proto(request.channel_id);
2779    let to = ChannelId::from_proto(request.to);
2780
2781    let (root_id, channels) = session
2782        .db()
2783        .await
2784        .move_channel(channel_id, to, session.user_id)
2785        .await?;
2786
2787    let connection_pool = session.connection_pool().await;
2788    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2789        let channels = channels
2790            .iter()
2791            .filter_map(|channel| {
2792                if role.can_see_channel(channel.visibility) {
2793                    Some(channel.to_proto())
2794                } else {
2795                    None
2796                }
2797            })
2798            .collect::<Vec<_>>();
2799        if channels.is_empty() {
2800            continue;
2801        }
2802
2803        let update = proto::UpdateChannels {
2804            channels,
2805            ..Default::default()
2806        };
2807
2808        session.peer.send(connection_id, update.clone())?;
2809    }
2810
2811    response.send(Ack {})?;
2812    Ok(())
2813}
2814
2815/// Get the list of channel members
2816async fn get_channel_members(
2817    request: proto::GetChannelMembers,
2818    response: Response<proto::GetChannelMembers>,
2819    session: Session,
2820) -> Result<()> {
2821    let db = session.db().await;
2822    let channel_id = ChannelId::from_proto(request.channel_id);
2823    let members = db
2824        .get_channel_participant_details(channel_id, session.user_id)
2825        .await?;
2826    response.send(proto::GetChannelMembersResponse { members })?;
2827    Ok(())
2828}
2829
2830/// Accept or decline a channel invitation.
2831async fn respond_to_channel_invite(
2832    request: proto::RespondToChannelInvite,
2833    response: Response<proto::RespondToChannelInvite>,
2834    session: Session,
2835) -> Result<()> {
2836    let db = session.db().await;
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838    let RespondToChannelInvite {
2839        membership_update,
2840        notifications,
2841    } = db
2842        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2843        .await?;
2844
2845    let mut connection_pool = session.connection_pool().await;
2846    if let Some(membership_update) = membership_update {
2847        notify_membership_updated(
2848            &mut connection_pool,
2849            membership_update,
2850            session.user_id,
2851            &session.peer,
2852        );
2853    } else {
2854        let update = proto::UpdateChannels {
2855            remove_channel_invitations: vec![channel_id.to_proto()],
2856            ..Default::default()
2857        };
2858
2859        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2860            session.peer.send(connection_id, update.clone())?;
2861        }
2862    };
2863
2864    send_notifications(&connection_pool, &session.peer, notifications);
2865
2866    response.send(proto::Ack {})?;
2867
2868    Ok(())
2869}
2870
2871/// Join the channels' room
2872async fn join_channel(
2873    request: proto::JoinChannel,
2874    response: Response<proto::JoinChannel>,
2875    session: Session,
2876) -> Result<()> {
2877    let channel_id = ChannelId::from_proto(request.channel_id);
2878    join_channel_internal(channel_id, Box::new(response), session).await
2879}
2880
2881trait JoinChannelInternalResponse {
2882    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2883}
2884impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2885    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2886        Response::<proto::JoinChannel>::send(self, result)
2887    }
2888}
2889impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2890    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2891        Response::<proto::JoinRoom>::send(self, result)
2892    }
2893}
2894
2895async fn join_channel_internal(
2896    channel_id: ChannelId,
2897    response: Box<impl JoinChannelInternalResponse>,
2898    session: Session,
2899) -> Result<()> {
2900    let joined_room = {
2901        leave_room_for_session(&session).await?;
2902        let db = session.db().await;
2903
2904        let (joined_room, membership_updated, role) = db
2905            .join_channel(channel_id, session.user_id, session.connection_id)
2906            .await?;
2907
2908        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2909            let (can_publish, token) = if role == ChannelRole::Guest {
2910                (
2911                    false,
2912                    live_kit
2913                        .guest_token(
2914                            &joined_room.room.live_kit_room,
2915                            &session.user_id.to_string(),
2916                        )
2917                        .trace_err()?,
2918                )
2919            } else {
2920                (
2921                    true,
2922                    live_kit
2923                        .room_token(
2924                            &joined_room.room.live_kit_room,
2925                            &session.user_id.to_string(),
2926                        )
2927                        .trace_err()?,
2928                )
2929            };
2930
2931            Some(LiveKitConnectionInfo {
2932                server_url: live_kit.url().into(),
2933                token,
2934                can_publish,
2935            })
2936        });
2937
2938        response.send(proto::JoinRoomResponse {
2939            room: Some(joined_room.room.clone()),
2940            channel_id: joined_room
2941                .channel
2942                .as_ref()
2943                .map(|channel| channel.id.to_proto()),
2944            live_kit_connection_info,
2945        })?;
2946
2947        let mut connection_pool = session.connection_pool().await;
2948        if let Some(membership_updated) = membership_updated {
2949            notify_membership_updated(
2950                &mut connection_pool,
2951                membership_updated,
2952                session.user_id,
2953                &session.peer,
2954            );
2955        }
2956
2957        room_updated(&joined_room.room, &session.peer);
2958
2959        joined_room
2960    };
2961
2962    channel_updated(
2963        &joined_room
2964            .channel
2965            .ok_or_else(|| anyhow!("channel not returned"))?,
2966        &joined_room.room,
2967        &session.peer,
2968        &*session.connection_pool().await,
2969    );
2970
2971    update_user_contacts(session.user_id, &session).await?;
2972    Ok(())
2973}
2974
2975/// Start editing the channel notes
2976async fn join_channel_buffer(
2977    request: proto::JoinChannelBuffer,
2978    response: Response<proto::JoinChannelBuffer>,
2979    session: Session,
2980) -> Result<()> {
2981    let db = session.db().await;
2982    let channel_id = ChannelId::from_proto(request.channel_id);
2983
2984    let open_response = db
2985        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2986        .await?;
2987
2988    let collaborators = open_response.collaborators.clone();
2989    response.send(open_response)?;
2990
2991    let update = UpdateChannelBufferCollaborators {
2992        channel_id: channel_id.to_proto(),
2993        collaborators: collaborators.clone(),
2994    };
2995    channel_buffer_updated(
2996        session.connection_id,
2997        collaborators
2998            .iter()
2999            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3000        &update,
3001        &session.peer,
3002    );
3003
3004    Ok(())
3005}
3006
3007/// Edit the channel notes
3008async fn update_channel_buffer(
3009    request: proto::UpdateChannelBuffer,
3010    session: Session,
3011) -> Result<()> {
3012    let db = session.db().await;
3013    let channel_id = ChannelId::from_proto(request.channel_id);
3014
3015    let (collaborators, non_collaborators, epoch, version) = db
3016        .update_channel_buffer(channel_id, session.user_id, &request.operations)
3017        .await?;
3018
3019    channel_buffer_updated(
3020        session.connection_id,
3021        collaborators,
3022        &proto::UpdateChannelBuffer {
3023            channel_id: channel_id.to_proto(),
3024            operations: request.operations,
3025        },
3026        &session.peer,
3027    );
3028
3029    let pool = &*session.connection_pool().await;
3030
3031    broadcast(
3032        None,
3033        non_collaborators
3034            .iter()
3035            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3036        |peer_id| {
3037            session.peer.send(
3038                peer_id,
3039                proto::UpdateChannels {
3040                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3041                        channel_id: channel_id.to_proto(),
3042                        epoch: epoch as u64,
3043                        version: version.clone(),
3044                    }],
3045                    ..Default::default()
3046                },
3047            )
3048        },
3049    );
3050
3051    Ok(())
3052}
3053
3054/// Rejoin the channel notes after a connection blip
3055async fn rejoin_channel_buffers(
3056    request: proto::RejoinChannelBuffers,
3057    response: Response<proto::RejoinChannelBuffers>,
3058    session: Session,
3059) -> Result<()> {
3060    let db = session.db().await;
3061    let buffers = db
3062        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
3063        .await?;
3064
3065    for rejoined_buffer in &buffers {
3066        let collaborators_to_notify = rejoined_buffer
3067            .buffer
3068            .collaborators
3069            .iter()
3070            .filter_map(|c| Some(c.peer_id?.into()));
3071        channel_buffer_updated(
3072            session.connection_id,
3073            collaborators_to_notify,
3074            &proto::UpdateChannelBufferCollaborators {
3075                channel_id: rejoined_buffer.buffer.channel_id,
3076                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3077            },
3078            &session.peer,
3079        );
3080    }
3081
3082    response.send(proto::RejoinChannelBuffersResponse {
3083        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3084    })?;
3085
3086    Ok(())
3087}
3088
3089/// Stop editing the channel notes
3090async fn leave_channel_buffer(
3091    request: proto::LeaveChannelBuffer,
3092    response: Response<proto::LeaveChannelBuffer>,
3093    session: Session,
3094) -> Result<()> {
3095    let db = session.db().await;
3096    let channel_id = ChannelId::from_proto(request.channel_id);
3097
3098    let left_buffer = db
3099        .leave_channel_buffer(channel_id, session.connection_id)
3100        .await?;
3101
3102    response.send(Ack {})?;
3103
3104    channel_buffer_updated(
3105        session.connection_id,
3106        left_buffer.connections,
3107        &proto::UpdateChannelBufferCollaborators {
3108            channel_id: channel_id.to_proto(),
3109            collaborators: left_buffer.collaborators,
3110        },
3111        &session.peer,
3112    );
3113
3114    Ok(())
3115}
3116
3117fn channel_buffer_updated<T: EnvelopedMessage>(
3118    sender_id: ConnectionId,
3119    collaborators: impl IntoIterator<Item = ConnectionId>,
3120    message: &T,
3121    peer: &Peer,
3122) {
3123    broadcast(Some(sender_id), collaborators, |peer_id| {
3124        peer.send(peer_id, message.clone())
3125    });
3126}
3127
3128fn send_notifications(
3129    connection_pool: &ConnectionPool,
3130    peer: &Peer,
3131    notifications: db::NotificationBatch,
3132) {
3133    for (user_id, notification) in notifications {
3134        for connection_id in connection_pool.user_connection_ids(user_id) {
3135            if let Err(error) = peer.send(
3136                connection_id,
3137                proto::AddNotification {
3138                    notification: Some(notification.clone()),
3139                },
3140            ) {
3141                tracing::error!(
3142                    "failed to send notification to {:?} {}",
3143                    connection_id,
3144                    error
3145                );
3146            }
3147        }
3148    }
3149}
3150
3151/// Send a message to the channel
3152async fn send_channel_message(
3153    request: proto::SendChannelMessage,
3154    response: Response<proto::SendChannelMessage>,
3155    session: Session,
3156) -> Result<()> {
3157    // Validate the message body.
3158    let body = request.body.trim().to_string();
3159    if body.len() > MAX_MESSAGE_LEN {
3160        return Err(anyhow!("message is too long"))?;
3161    }
3162    if body.is_empty() {
3163        return Err(anyhow!("message can't be blank"))?;
3164    }
3165
3166    // TODO: adjust mentions if body is trimmed
3167
3168    let timestamp = OffsetDateTime::now_utc();
3169    let nonce = request
3170        .nonce
3171        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3172
3173    let channel_id = ChannelId::from_proto(request.channel_id);
3174    let CreatedChannelMessage {
3175        message_id,
3176        participant_connection_ids,
3177        channel_members,
3178        notifications,
3179    } = session
3180        .db()
3181        .await
3182        .create_channel_message(
3183            channel_id,
3184            session.user_id,
3185            &body,
3186            &request.mentions,
3187            timestamp,
3188            nonce.clone().into(),
3189            match request.reply_to_message_id {
3190                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3191                None => None,
3192            },
3193        )
3194        .await?;
3195
3196    let message = proto::ChannelMessage {
3197        sender_id: session.user_id.to_proto(),
3198        id: message_id.to_proto(),
3199        body,
3200        mentions: request.mentions,
3201        timestamp: timestamp.unix_timestamp() as u64,
3202        nonce: Some(nonce),
3203        reply_to_message_id: request.reply_to_message_id,
3204        edited_at: None,
3205    };
3206    broadcast(
3207        Some(session.connection_id),
3208        participant_connection_ids,
3209        |connection| {
3210            session.peer.send(
3211                connection,
3212                proto::ChannelMessageSent {
3213                    channel_id: channel_id.to_proto(),
3214                    message: Some(message.clone()),
3215                },
3216            )
3217        },
3218    );
3219    response.send(proto::SendChannelMessageResponse {
3220        message: Some(message),
3221    })?;
3222
3223    let pool = &*session.connection_pool().await;
3224    broadcast(
3225        None,
3226        channel_members
3227            .iter()
3228            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3229        |peer_id| {
3230            session.peer.send(
3231                peer_id,
3232                proto::UpdateChannels {
3233                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3234                        channel_id: channel_id.to_proto(),
3235                        message_id: message_id.to_proto(),
3236                    }],
3237                    ..Default::default()
3238                },
3239            )
3240        },
3241    );
3242    send_notifications(pool, &session.peer, notifications);
3243
3244    Ok(())
3245}
3246
3247/// Delete a channel message
3248async fn remove_channel_message(
3249    request: proto::RemoveChannelMessage,
3250    response: Response<proto::RemoveChannelMessage>,
3251    session: Session,
3252) -> Result<()> {
3253    let channel_id = ChannelId::from_proto(request.channel_id);
3254    let message_id = MessageId::from_proto(request.message_id);
3255    let connection_ids = session
3256        .db()
3257        .await
3258        .remove_channel_message(channel_id, message_id, session.user_id)
3259        .await?;
3260    broadcast(Some(session.connection_id), connection_ids, |connection| {
3261        session.peer.send(connection, request.clone())
3262    });
3263    response.send(proto::Ack {})?;
3264    Ok(())
3265}
3266
3267async fn update_channel_message(
3268    request: proto::UpdateChannelMessage,
3269    response: Response<proto::UpdateChannelMessage>,
3270    session: Session,
3271) -> Result<()> {
3272    let channel_id = ChannelId::from_proto(request.channel_id);
3273    let message_id = MessageId::from_proto(request.message_id);
3274    let updated_at = OffsetDateTime::now_utc();
3275    let UpdatedChannelMessage {
3276        message_id,
3277        participant_connection_ids,
3278        notifications,
3279        reply_to_message_id,
3280        timestamp,
3281    } = session
3282        .db()
3283        .await
3284        .update_channel_message(
3285            channel_id,
3286            message_id,
3287            session.user_id,
3288            request.body.as_str(),
3289            &request.mentions,
3290            updated_at,
3291        )
3292        .await?;
3293
3294    let nonce = request
3295        .nonce
3296        .clone()
3297        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3298
3299    let message = proto::ChannelMessage {
3300        sender_id: session.user_id.to_proto(),
3301        id: message_id.to_proto(),
3302        body: request.body.clone(),
3303        mentions: request.mentions.clone(),
3304        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3305        nonce: Some(nonce),
3306        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3307        edited_at: Some(updated_at.unix_timestamp() as u64),
3308    };
3309
3310    response.send(proto::Ack {})?;
3311
3312    let pool = &*session.connection_pool().await;
3313    broadcast(
3314        Some(session.connection_id),
3315        participant_connection_ids,
3316        |connection| {
3317            session.peer.send(
3318                connection,
3319                proto::ChannelMessageUpdate {
3320                    channel_id: channel_id.to_proto(),
3321                    message: Some(message.clone()),
3322                },
3323            )
3324        },
3325    );
3326
3327    send_notifications(pool, &session.peer, notifications);
3328
3329    Ok(())
3330}
3331
3332/// Mark a channel message as read
3333async fn acknowledge_channel_message(
3334    request: proto::AckChannelMessage,
3335    session: Session,
3336) -> Result<()> {
3337    let channel_id = ChannelId::from_proto(request.channel_id);
3338    let message_id = MessageId::from_proto(request.message_id);
3339    let notifications = session
3340        .db()
3341        .await
3342        .observe_channel_message(channel_id, session.user_id, message_id)
3343        .await?;
3344    send_notifications(
3345        &*session.connection_pool().await,
3346        &session.peer,
3347        notifications,
3348    );
3349    Ok(())
3350}
3351
3352/// Mark a buffer version as synced
3353async fn acknowledge_buffer_version(
3354    request: proto::AckBufferOperation,
3355    session: Session,
3356) -> Result<()> {
3357    let buffer_id = BufferId::from_proto(request.buffer_id);
3358    session
3359        .db()
3360        .await
3361        .observe_buffer_version(
3362            buffer_id,
3363            session.user_id,
3364            request.epoch as i32,
3365            &request.version,
3366        )
3367        .await?;
3368    Ok(())
3369}
3370
3371struct CompleteWithLanguageModelRateLimit;
3372
3373impl RateLimit for CompleteWithLanguageModelRateLimit {
3374    fn capacity() -> usize {
3375        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3376            .ok()
3377            .and_then(|v| v.parse().ok())
3378            .unwrap_or(120) // Picked arbitrarily
3379    }
3380
3381    fn refill_duration() -> chrono::Duration {
3382        chrono::Duration::hours(1)
3383    }
3384
3385    fn db_name() -> &'static str {
3386        "complete-with-language-model"
3387    }
3388}
3389
3390async fn complete_with_language_model(
3391    request: proto::CompleteWithLanguageModel,
3392    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3393    session: Session,
3394    open_ai_api_key: Option<Arc<str>>,
3395    google_ai_api_key: Option<Arc<str>>,
3396) -> Result<()> {
3397    authorize_access_to_language_models(&session).await?;
3398    session
3399        .rate_limiter
3400        .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
3401        .await?;
3402
3403    if request.model.starts_with("gpt") {
3404        let api_key =
3405            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3406        complete_with_open_ai(request, response, session, api_key).await?;
3407    } else if request.model.starts_with("gemini") {
3408        let api_key = google_ai_api_key
3409            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3410        complete_with_google_ai(request, response, session, api_key).await?;
3411    }
3412
3413    Ok(())
3414}
3415
3416async fn complete_with_open_ai(
3417    request: proto::CompleteWithLanguageModel,
3418    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3419    session: Session,
3420    api_key: Arc<str>,
3421) -> Result<()> {
3422    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3423
3424    let mut completion_stream = open_ai::stream_completion(
3425        &session.http_client,
3426        OPEN_AI_API_URL,
3427        &api_key,
3428        crate::ai::language_model_request_to_open_ai(request)?,
3429    )
3430    .await
3431    .context("open_ai::stream_completion request failed")?;
3432
3433    while let Some(event) = completion_stream.next().await {
3434        let event = event?;
3435        response.send(proto::LanguageModelResponse {
3436            choices: event
3437                .choices
3438                .into_iter()
3439                .map(|choice| proto::LanguageModelChoiceDelta {
3440                    index: choice.index,
3441                    delta: Some(proto::LanguageModelResponseMessage {
3442                        role: choice.delta.role.map(|role| match role {
3443                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3444                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3445                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3446                        } as i32),
3447                        content: choice.delta.content,
3448                    }),
3449                    finish_reason: choice.finish_reason,
3450                })
3451                .collect(),
3452        })?;
3453    }
3454
3455    Ok(())
3456}
3457
3458async fn complete_with_google_ai(
3459    request: proto::CompleteWithLanguageModel,
3460    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3461    session: Session,
3462    api_key: Arc<str>,
3463) -> Result<()> {
3464    let mut stream = google_ai::stream_generate_content(
3465        &session.http_client,
3466        google_ai::API_URL,
3467        api_key.as_ref(),
3468        crate::ai::language_model_request_to_google_ai(request)?,
3469    )
3470    .await
3471    .context("google_ai::stream_generate_content request failed")?;
3472
3473    while let Some(event) = stream.next().await {
3474        let event = event?;
3475        response.send(proto::LanguageModelResponse {
3476            choices: event
3477                .candidates
3478                .unwrap_or_default()
3479                .into_iter()
3480                .map(|candidate| proto::LanguageModelChoiceDelta {
3481                    index: candidate.index as u32,
3482                    delta: Some(proto::LanguageModelResponseMessage {
3483                        role: Some(match candidate.content.role {
3484                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3485                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3486                        } as i32),
3487                        content: Some(
3488                            candidate
3489                                .content
3490                                .parts
3491                                .into_iter()
3492                                .filter_map(|part| match part {
3493                                    google_ai::Part::TextPart(part) => Some(part.text),
3494                                    google_ai::Part::InlineDataPart(_) => None,
3495                                })
3496                                .collect(),
3497                        ),
3498                    }),
3499                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3500                })
3501                .collect(),
3502        })?;
3503    }
3504
3505    Ok(())
3506}
3507
3508struct CountTokensWithLanguageModelRateLimit;
3509
3510impl RateLimit for CountTokensWithLanguageModelRateLimit {
3511    fn capacity() -> usize {
3512        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3513            .ok()
3514            .and_then(|v| v.parse().ok())
3515            .unwrap_or(600) // Picked arbitrarily
3516    }
3517
3518    fn refill_duration() -> chrono::Duration {
3519        chrono::Duration::hours(1)
3520    }
3521
3522    fn db_name() -> &'static str {
3523        "count-tokens-with-language-model"
3524    }
3525}
3526
3527async fn count_tokens_with_language_model(
3528    request: proto::CountTokensWithLanguageModel,
3529    response: Response<proto::CountTokensWithLanguageModel>,
3530    session: Session,
3531    google_ai_api_key: Option<Arc<str>>,
3532) -> Result<()> {
3533    authorize_access_to_language_models(&session).await?;
3534
3535    if !request.model.starts_with("gemini") {
3536        return Err(anyhow!(
3537            "counting tokens for model: {:?} is not supported",
3538            request.model
3539        ))?;
3540    }
3541
3542    session
3543        .rate_limiter
3544        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
3545        .await?;
3546
3547    let api_key = google_ai_api_key
3548        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3549    let tokens_response = google_ai::count_tokens(
3550        &session.http_client,
3551        google_ai::API_URL,
3552        &api_key,
3553        crate::ai::count_tokens_request_to_google_ai(request)?,
3554    )
3555    .await?;
3556    response.send(proto::CountTokensResponse {
3557        token_count: tokens_response.total_tokens as u32,
3558    })?;
3559    Ok(())
3560}
3561
3562async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
3563    let db = session.db().await;
3564    let flags = db.get_user_flags(session.user_id).await?;
3565    if flags.iter().any(|flag| flag == "language-models") {
3566        Ok(())
3567    } else {
3568        Err(anyhow!("permission denied"))?
3569    }
3570}
3571
3572/// Start receiving chat updates for a channel
3573async fn join_channel_chat(
3574    request: proto::JoinChannelChat,
3575    response: Response<proto::JoinChannelChat>,
3576    session: Session,
3577) -> Result<()> {
3578    let channel_id = ChannelId::from_proto(request.channel_id);
3579
3580    let db = session.db().await;
3581    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3582        .await?;
3583    let messages = db
3584        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3585        .await?;
3586    response.send(proto::JoinChannelChatResponse {
3587        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3588        messages,
3589    })?;
3590    Ok(())
3591}
3592
3593/// Stop receiving chat updates for a channel
3594async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3595    let channel_id = ChannelId::from_proto(request.channel_id);
3596    session
3597        .db()
3598        .await
3599        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3600        .await?;
3601    Ok(())
3602}
3603
3604/// Retrieve the chat history for a channel
3605async fn get_channel_messages(
3606    request: proto::GetChannelMessages,
3607    response: Response<proto::GetChannelMessages>,
3608    session: Session,
3609) -> Result<()> {
3610    let channel_id = ChannelId::from_proto(request.channel_id);
3611    let messages = session
3612        .db()
3613        .await
3614        .get_channel_messages(
3615            channel_id,
3616            session.user_id,
3617            MESSAGE_COUNT_PER_PAGE,
3618            Some(MessageId::from_proto(request.before_message_id)),
3619        )
3620        .await?;
3621    response.send(proto::GetChannelMessagesResponse {
3622        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3623        messages,
3624    })?;
3625    Ok(())
3626}
3627
3628/// Retrieve specific chat messages
3629async fn get_channel_messages_by_id(
3630    request: proto::GetChannelMessagesById,
3631    response: Response<proto::GetChannelMessagesById>,
3632    session: Session,
3633) -> Result<()> {
3634    let message_ids = request
3635        .message_ids
3636        .iter()
3637        .map(|id| MessageId::from_proto(*id))
3638        .collect::<Vec<_>>();
3639    let messages = session
3640        .db()
3641        .await
3642        .get_channel_messages_by_id(session.user_id, &message_ids)
3643        .await?;
3644    response.send(proto::GetChannelMessagesResponse {
3645        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3646        messages,
3647    })?;
3648    Ok(())
3649}
3650
3651/// Retrieve the current users notifications
3652async fn get_notifications(
3653    request: proto::GetNotifications,
3654    response: Response<proto::GetNotifications>,
3655    session: Session,
3656) -> Result<()> {
3657    let notifications = session
3658        .db()
3659        .await
3660        .get_notifications(
3661            session.user_id,
3662            NOTIFICATION_COUNT_PER_PAGE,
3663            request
3664                .before_id
3665                .map(|id| db::NotificationId::from_proto(id)),
3666        )
3667        .await?;
3668    response.send(proto::GetNotificationsResponse {
3669        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3670        notifications,
3671    })?;
3672    Ok(())
3673}
3674
3675/// Mark notifications as read
3676async fn mark_notification_as_read(
3677    request: proto::MarkNotificationRead,
3678    response: Response<proto::MarkNotificationRead>,
3679    session: Session,
3680) -> Result<()> {
3681    let database = &session.db().await;
3682    let notifications = database
3683        .mark_notification_as_read_by_id(
3684            session.user_id,
3685            NotificationId::from_proto(request.notification_id),
3686        )
3687        .await?;
3688    send_notifications(
3689        &*session.connection_pool().await,
3690        &session.peer,
3691        notifications,
3692    );
3693    response.send(proto::Ack {})?;
3694    Ok(())
3695}
3696
3697/// Get the current users information
3698async fn get_private_user_info(
3699    _request: proto::GetPrivateUserInfo,
3700    response: Response<proto::GetPrivateUserInfo>,
3701    session: Session,
3702) -> Result<()> {
3703    let db = session.db().await;
3704
3705    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3706    let user = db
3707        .get_user_by_id(session.user_id)
3708        .await?
3709        .ok_or_else(|| anyhow!("user not found"))?;
3710    let flags = db.get_user_flags(session.user_id).await?;
3711
3712    response.send(proto::GetPrivateUserInfoResponse {
3713        metrics_id,
3714        staff: user.admin,
3715        flags,
3716    })?;
3717    Ok(())
3718}
3719
3720fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3721    match message {
3722        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3723        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3724        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3725        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3726        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3727            code: frame.code.into(),
3728            reason: frame.reason,
3729        })),
3730    }
3731}
3732
3733fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3734    match message {
3735        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3736        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3737        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3738        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3739        AxumMessage::Close(frame) => {
3740            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3741                code: frame.code.into(),
3742                reason: frame.reason,
3743            }))
3744        }
3745    }
3746}
3747
3748fn notify_membership_updated(
3749    connection_pool: &mut ConnectionPool,
3750    result: MembershipUpdated,
3751    user_id: UserId,
3752    peer: &Peer,
3753) {
3754    for membership in &result.new_channels.channel_memberships {
3755        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3756    }
3757    for channel_id in &result.removed_channels {
3758        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3759    }
3760
3761    let user_channels_update = proto::UpdateUserChannels {
3762        channel_memberships: result
3763            .new_channels
3764            .channel_memberships
3765            .iter()
3766            .map(|cm| proto::ChannelMembership {
3767                channel_id: cm.channel_id.to_proto(),
3768                role: cm.role.into(),
3769            })
3770            .collect(),
3771        ..Default::default()
3772    };
3773
3774    let mut update = build_channels_update(result.new_channels, vec![]);
3775    update.delete_channels = result
3776        .removed_channels
3777        .into_iter()
3778        .map(|id| id.to_proto())
3779        .collect();
3780    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3781
3782    for connection_id in connection_pool.user_connection_ids(user_id) {
3783        peer.send(connection_id, user_channels_update.clone())
3784            .trace_err();
3785        peer.send(connection_id, update.clone()).trace_err();
3786    }
3787}
3788
3789fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3790    proto::UpdateUserChannels {
3791        channel_memberships: channels
3792            .channel_memberships
3793            .iter()
3794            .map(|m| proto::ChannelMembership {
3795                channel_id: m.channel_id.to_proto(),
3796                role: m.role.into(),
3797            })
3798            .collect(),
3799        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3800        observed_channel_message_id: channels.observed_channel_messages.clone(),
3801    }
3802}
3803
3804fn build_channels_update(
3805    channels: ChannelsForUser,
3806    channel_invites: Vec<db::Channel>,
3807) -> proto::UpdateChannels {
3808    let mut update = proto::UpdateChannels::default();
3809
3810    for channel in channels.channels {
3811        update.channels.push(channel.to_proto());
3812    }
3813
3814    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3815    update.latest_channel_message_ids = channels.latest_channel_messages;
3816
3817    for (channel_id, participants) in channels.channel_participants {
3818        update
3819            .channel_participants
3820            .push(proto::ChannelParticipants {
3821                channel_id: channel_id.to_proto(),
3822                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3823            });
3824    }
3825
3826    for channel in channel_invites {
3827        update.channel_invitations.push(channel.to_proto());
3828    }
3829    for project in channels.hosted_projects {
3830        update.hosted_projects.push(project);
3831    }
3832
3833    update
3834}
3835
3836fn build_initial_contacts_update(
3837    contacts: Vec<db::Contact>,
3838    pool: &ConnectionPool,
3839) -> proto::UpdateContacts {
3840    let mut update = proto::UpdateContacts::default();
3841
3842    for contact in contacts {
3843        match contact {
3844            db::Contact::Accepted { user_id, busy } => {
3845                update.contacts.push(contact_for_user(user_id, busy, &pool));
3846            }
3847            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3848            db::Contact::Incoming { user_id } => {
3849                update
3850                    .incoming_requests
3851                    .push(proto::IncomingContactRequest {
3852                        requester_id: user_id.to_proto(),
3853                    })
3854            }
3855        }
3856    }
3857
3858    update
3859}
3860
3861fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3862    proto::Contact {
3863        user_id: user_id.to_proto(),
3864        online: pool.is_user_online(user_id),
3865        busy,
3866    }
3867}
3868
3869fn room_updated(room: &proto::Room, peer: &Peer) {
3870    broadcast(
3871        None,
3872        room.participants
3873            .iter()
3874            .filter_map(|participant| Some(participant.peer_id?.into())),
3875        |peer_id| {
3876            peer.send(
3877                peer_id,
3878                proto::RoomUpdated {
3879                    room: Some(room.clone()),
3880                },
3881            )
3882        },
3883    );
3884}
3885
3886fn channel_updated(
3887    channel: &db::channel::Model,
3888    room: &proto::Room,
3889    peer: &Peer,
3890    pool: &ConnectionPool,
3891) {
3892    let participants = room
3893        .participants
3894        .iter()
3895        .map(|p| p.user_id)
3896        .collect::<Vec<_>>();
3897
3898    broadcast(
3899        None,
3900        pool.channel_connection_ids(channel.root_id())
3901            .filter_map(|(channel_id, role)| {
3902                role.can_see_channel(channel.visibility).then(|| channel_id)
3903            }),
3904        |peer_id| {
3905            peer.send(
3906                peer_id,
3907                proto::UpdateChannels {
3908                    channel_participants: vec![proto::ChannelParticipants {
3909                        channel_id: channel.id.to_proto(),
3910                        participant_user_ids: participants.clone(),
3911                    }],
3912                    ..Default::default()
3913                },
3914            )
3915        },
3916    );
3917}
3918
3919async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3920    let db = session.db().await;
3921
3922    let contacts = db.get_contacts(user_id).await?;
3923    let busy = db.is_user_busy(user_id).await?;
3924
3925    let pool = session.connection_pool().await;
3926    let updated_contact = contact_for_user(user_id, busy, &pool);
3927    for contact in contacts {
3928        if let db::Contact::Accepted {
3929            user_id: contact_user_id,
3930            ..
3931        } = contact
3932        {
3933            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3934                session
3935                    .peer
3936                    .send(
3937                        contact_conn_id,
3938                        proto::UpdateContacts {
3939                            contacts: vec![updated_contact.clone()],
3940                            remove_contacts: Default::default(),
3941                            incoming_requests: Default::default(),
3942                            remove_incoming_requests: Default::default(),
3943                            outgoing_requests: Default::default(),
3944                            remove_outgoing_requests: Default::default(),
3945                        },
3946                    )
3947                    .trace_err();
3948            }
3949        }
3950    }
3951    Ok(())
3952}
3953
3954async fn leave_room_for_session(session: &Session) -> Result<()> {
3955    let mut contacts_to_update = HashSet::default();
3956
3957    let room_id;
3958    let canceled_calls_to_user_ids;
3959    let live_kit_room;
3960    let delete_live_kit_room;
3961    let room;
3962    let channel;
3963
3964    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3965        contacts_to_update.insert(session.user_id);
3966
3967        for project in left_room.left_projects.values() {
3968            project_left(project, session);
3969        }
3970
3971        room_id = RoomId::from_proto(left_room.room.id);
3972        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3973        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3974        delete_live_kit_room = left_room.deleted;
3975        room = mem::take(&mut left_room.room);
3976        channel = mem::take(&mut left_room.channel);
3977
3978        room_updated(&room, &session.peer);
3979    } else {
3980        return Ok(());
3981    }
3982
3983    if let Some(channel) = channel {
3984        channel_updated(
3985            &channel,
3986            &room,
3987            &session.peer,
3988            &*session.connection_pool().await,
3989        );
3990    }
3991
3992    {
3993        let pool = session.connection_pool().await;
3994        for canceled_user_id in canceled_calls_to_user_ids {
3995            for connection_id in pool.user_connection_ids(canceled_user_id) {
3996                session
3997                    .peer
3998                    .send(
3999                        connection_id,
4000                        proto::CallCanceled {
4001                            room_id: room_id.to_proto(),
4002                        },
4003                    )
4004                    .trace_err();
4005            }
4006            contacts_to_update.insert(canceled_user_id);
4007        }
4008    }
4009
4010    for contact_user_id in contacts_to_update {
4011        update_user_contacts(contact_user_id, &session).await?;
4012    }
4013
4014    if let Some(live_kit) = session.live_kit_client.as_ref() {
4015        live_kit
4016            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
4017            .await
4018            .trace_err();
4019
4020        if delete_live_kit_room {
4021            live_kit.delete_room(live_kit_room).await.trace_err();
4022        }
4023    }
4024
4025    Ok(())
4026}
4027
4028async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4029    let left_channel_buffers = session
4030        .db()
4031        .await
4032        .leave_channel_buffers(session.connection_id)
4033        .await?;
4034
4035    for left_buffer in left_channel_buffers {
4036        channel_buffer_updated(
4037            session.connection_id,
4038            left_buffer.connections,
4039            &proto::UpdateChannelBufferCollaborators {
4040                channel_id: left_buffer.channel_id.to_proto(),
4041                collaborators: left_buffer.collaborators,
4042            },
4043            &session.peer,
4044        );
4045    }
4046
4047    Ok(())
4048}
4049
4050fn project_left(project: &db::LeftProject, session: &Session) {
4051    for connection_id in &project.connection_ids {
4052        if project.host_user_id == Some(session.user_id) {
4053            session
4054                .peer
4055                .send(
4056                    *connection_id,
4057                    proto::UnshareProject {
4058                        project_id: project.id.to_proto(),
4059                    },
4060                )
4061                .trace_err();
4062        } else {
4063            session
4064                .peer
4065                .send(
4066                    *connection_id,
4067                    proto::RemoveProjectCollaborator {
4068                        project_id: project.id.to_proto(),
4069                        peer_id: Some(session.connection_id.into()),
4070                    },
4071                )
4072                .trace_err();
4073        }
4074    }
4075}
4076
4077pub trait ResultExt {
4078    type Ok;
4079
4080    fn trace_err(self) -> Option<Self::Ok>;
4081}
4082
4083impl<T, E> ResultExt for Result<T, E>
4084where
4085    E: std::fmt::Debug,
4086{
4087    type Ok = T;
4088
4089    fn trace_err(self) -> Option<T> {
4090        match self {
4091            Ok(value) => Some(value),
4092            Err(error) => {
4093                tracing::error!("{:?}", error);
4094                None
4095            }
4096        }
4097    }
4098}