rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67
  68pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  69pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  70
  71const MESSAGE_COUNT_PER_PAGE: usize = 100;
  72const MAX_MESSAGE_LEN: usize = 1024;
  73const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  74
// Process-wide Prometheus gauges, created lazily on first access.
// `register_int_gauge!` returns a `Result`; the `unwrap` turns a failed
// registration into a panic at that first access.
lazy_static! {
    // Gauge registered under the metric name "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge registered under the metric name "shared_projects".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  84
  85type MessageHandler =
  86    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  87
  88struct Response<R> {
  89    peer: Arc<Peer>,
  90    receipt: Receipt<R>,
  91    responded: Arc<AtomicBool>,
  92}
  93
  94impl<R: RequestMessage> Response<R> {
  95    fn send(self, payload: R::Response) -> Result<()> {
  96        self.responded.store(true, SeqCst);
  97        self.peer.respond(self.receipt, payload)?;
  98        Ok(())
  99    }
 100}
 101
 102#[derive(Clone)]
 103struct Session {
 104    zed_environment: Arc<str>,
 105    user_id: UserId,
 106    connection_id: ConnectionId,
 107    db: Arc<tokio::sync::Mutex<DbHandle>>,
 108    peer: Arc<Peer>,
 109    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 110    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 111    _executor: Executor,
 112}
 113
 114impl Session {
 115    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        let guard = self.db.lock().await;
 119        #[cfg(test)]
 120        tokio::task::yield_now().await;
 121        guard
 122    }
 123
 124    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 125        #[cfg(test)]
 126        tokio::task::yield_now().await;
 127        let guard = self.connection_pool.lock();
 128        ConnectionPoolGuard {
 129            guard,
 130            _not_send: PhantomData,
 131        }
 132    }
 133}
 134
 135impl fmt::Debug for Session {
 136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 137        f.debug_struct("Session")
 138            .field("user_id", &self.user_id)
 139            .field("connection_id", &self.connection_id)
 140            .finish()
 141    }
 142}
 143
 144struct DbHandle(Arc<Database>);
 145
 146impl Deref for DbHandle {
 147    type Target = Database;
 148
 149    fn deref(&self) -> &Self::Target {
 150        self.0.as_ref()
 151    }
 152}
 153
/// The collaboration RPC server. It owns the `Peer`, the shared connection pool
/// and the table of registered message handlers.
pub struct Server {
    // Current server id; kept behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown()`; each `handle_connection` loop subscribes and exits on change.
    teardown: watch::Sender<()>,
}
 163
/// Guard over the locked `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Handler futures must
/// be `Send` (`BoxFuture`), so a handler cannot keep this synchronous lock across
/// an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 168
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The snapshot holds the pool guard, so the pool stays locked for as long as
/// the snapshot lives. The guard is serialized through `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 175
 176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 177where
 178    S: Serializer,
 179    T: Deref<Target = U>,
 180    U: Serialize,
 181{
 182    Serialize::serialize(value.deref(), serializer)
 183}
 184
 185impl Server {
    /// Creates the server with the given id and registers every RPC handler.
    ///
    /// `add_handler` panics if a message type is registered twice, so each message
    /// type appears at most once in the chain below. Handlers registered through
    /// `add_request_handler` receive a `Response` they must reply through; those
    /// registered through `add_message_handler` receive none.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // The peer id is the numeric server id.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; `handle_connection` subscribes to it.
            teardown: watch::channel(()).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects and worktrees.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through `forward_read_only_project_request`.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            // Project requests relayed through `forward_mutating_project_request`.
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, channel chat and notifications.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 294
    /// Spawns a detached background task that cleans up resources the database
    /// reports as stale for this server.
    ///
    /// The task first sleeps for `CLEANUP_TIMEOUT`. It then fetches the stale room
    /// and channel ids for this environment and server id
    /// (`stale_server_resource_ids`).
    ///
    /// For each stale channel buffer it clears stale collaborators and sends the
    /// refreshed collaborator list to the remaining connections.
    ///
    /// For each stale room it:
    /// - clears stale participants and broadcasts room/channel updates;
    /// - sends `CallCanceled` to the users whose calls were canceled;
    /// - pushes an `UpdateContacts` to the accepted contacts of each affected user;
    /// - deletes the LiveKit room if a LiveKit client is configured and the room
    ///   has no participants left.
    ///
    /// Finally it calls `delete_stale_servers`. Every database or send failure is
    /// only logged via `trace_err`, so one failing step does not stop the rest.
    ///
    /// Returns `Ok(())` right away; the cleanup runs in the spawned task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the new
                    // collaborator list to every connection the database returned.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in while refreshing the room; all stay empty/false
                        // if `clear_stale_room_participants` fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and canceled callees need a
                            // contact-status refresh.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Notify only when both lookups succeeded; the pool is locked
                            // after the awaits and released at the end of this block.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 446
 447    pub fn teardown(&self) {
 448        self.peer.teardown();
 449        self.connection_pool.lock().reset();
 450        let _ = self.teardown.send(());
 451    }
 452
 453    #[cfg(test)]
 454    pub fn reset(&self, id: ServerId) {
 455        self.teardown();
 456        *self.id.lock() = id;
 457        self.peer.reset(id.0 as u32);
 458    }
 459
 460    #[cfg(test)]
 461    pub fn id(&self) -> ServerId {
 462        *self.id.lock()
 463    }
 464
 465    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 466    where
 467        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 468        Fut: 'static + Send + Future<Output = Result<()>>,
 469        M: EnvelopedMessage,
 470    {
 471        let prev_handler = self.handlers.insert(
 472            TypeId::of::<M>(),
 473            Box::new(move |envelope, session| {
 474                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 475                let span = info_span!(
 476                    "handle message",
 477                    payload_type = envelope.payload_type_name()
 478                );
 479                span.in_scope(|| {
 480                    tracing::info!(
 481                        payload_type = envelope.payload_type_name(),
 482                        "message received"
 483                    );
 484                });
 485                let start_time = Instant::now();
 486                let future = (handler)(*envelope, session);
 487                async move {
 488                    let result = future.await;
 489                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 490                    match result {
 491                        Err(error) => {
 492                            tracing::error!(%error, ?duration_ms, "error handling message")
 493                        }
 494                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 495                    }
 496                }
 497                .instrument(span)
 498                .boxed()
 499            }),
 500        );
 501        if prev_handler.is_some() {
 502            panic!("registered a handler for the same message twice");
 503        }
 504        self
 505    }
 506
 507    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 508    where
 509        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 510        Fut: 'static + Send + Future<Output = Result<()>>,
 511        M: EnvelopedMessage,
 512    {
 513        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 514        self
 515    }
 516
 517    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 518    where
 519        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 520        Fut: Send + Future<Output = Result<()>>,
 521        M: RequestMessage,
 522    {
 523        let handler = Arc::new(handler);
 524        self.add_handler(move |envelope, session| {
 525            let receipt = envelope.receipt();
 526            let handler = handler.clone();
 527            async move {
 528                let peer = session.peer.clone();
 529                let responded = Arc::new(AtomicBool::default());
 530                let response = Response {
 531                    peer: peer.clone(),
 532                    responded: responded.clone(),
 533                    receipt,
 534                };
 535                match (handler)(envelope.payload, response, session).await {
 536                    Ok(()) => {
 537                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 538                            Ok(())
 539                        } else {
 540                            Err(anyhow!("handler did not send a response"))?
 541                        }
 542                    }
 543                    Err(error) => {
 544                        let proto_err = match &error {
 545                            Error::Internal(err) => err.to_proto(),
 546                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 547                        };
 548                        peer.respond_with_error(receipt, proto_err)?;
 549                        Err(error)
 550                    }
 551                }
 552            }
 553        })
 554    }
 555
    /// Serves one authenticated client connection from handshake to sign-out.
    ///
    /// Setup:
    /// - registers the connection with the peer and sends `Hello`;
    /// - on the user's first connection, sends `ShowContacts` and persists
    ///   `connected_once`;
    /// - pushes the initial contacts, channel memberships and channels/invites
    ///   state, plus any pending incoming call;
    /// - if `send_connection_id` is provided, reports the new connection id
    ///   through it.
    ///
    /// After setup it dispatches incoming messages to the registered handlers. It
    /// stops when the client closes the connection, I/O fails, or the server is
    /// torn down. On close or I/O error, `connection_lost` runs; on teardown the
    /// future returns `Ok(())` immediately without calling it. Setup failures
    /// propagate through `?`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed eagerly, before the future is returned, so any later
        // `teardown()` shows up as a change on this receiver.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Report the assigned connection id to the caller, if it asked for it.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial state under a single
            // pool lock; no `.await` happens inside this scope.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user.channel_memberships))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) are in flight at once.
            // Each permit is released when its handler future completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before reading, so no new message is read
                // while all 256 permits are taken.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O, then progress on
                // running foreground handlers, then new incoming messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is dropped when it finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run independently of this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancels foreground handlers that are still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 696
 697    pub async fn invite_code_redeemed(
 698        self: &Arc<Self>,
 699        inviter_id: UserId,
 700        invitee_id: UserId,
 701    ) -> Result<()> {
 702        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 703            if let Some(code) = &user.invite_code {
 704                let pool = self.connection_pool.lock();
 705                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 706                for connection_id in pool.user_connection_ids(inviter_id) {
 707                    self.peer.send(
 708                        connection_id,
 709                        proto::UpdateContacts {
 710                            contacts: vec![invitee_contact.clone()],
 711                            ..Default::default()
 712                        },
 713                    )?;
 714                    self.peer.send(
 715                        connection_id,
 716                        proto::UpdateInviteInfo {
 717                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 718                            count: user.invite_count as u32,
 719                        },
 720                    )?;
 721                }
 722            }
 723        }
 724        Ok(())
 725    }
 726
 727    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 728        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 729            if let Some(invite_code) = &user.invite_code {
 730                let pool = self.connection_pool.lock();
 731                for connection_id in pool.user_connection_ids(user_id) {
 732                    self.peer.send(
 733                        connection_id,
 734                        proto::UpdateInviteInfo {
 735                            url: format!(
 736                                "{}{}",
 737                                self.app_state.config.invite_link_prefix, invite_code
 738                            ),
 739                            count: user.invite_count as u32,
 740                        },
 741                    )?;
 742                }
 743            }
 744        }
 745        Ok(())
 746    }
 747
 748    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 749        ServerSnapshot {
 750            connection_pool: ConnectionPoolGuard {
 751                guard: self.connection_pool.lock(),
 752                _not_send: PhantomData,
 753            },
 754            peer: &self.peer,
 755        }
 756    }
 757}
 758
 759impl<'a> Deref for ConnectionPoolGuard<'a> {
 760    type Target = ConnectionPool;
 761
 762    fn deref(&self) -> &Self::Target {
 763        &*self.guard
 764    }
 765}
 766
 767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 768    fn deref_mut(&mut self) -> &mut Self::Target {
 769        &mut *self.guard
 770    }
 771}
 772
/// On release of the guard, test builds call `check_invariants()` so the pool is validated after
/// every use made through the guard.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`; in other builds dropping the guard does nothing extra
        // beyond releasing the inner lock guard.
        #[cfg(test)]
        self.check_invariants();
    }
}
 779
 780fn broadcast<F>(
 781    sender_id: Option<ConnectionId>,
 782    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 783    mut f: F,
 784) where
 785    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 786{
 787    for receiver_id in receiver_ids {
 788        if Some(receiver_id) != sender_id {
 789            if let Err(error) = f(receiver_id) {
 790                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 791            }
 792        }
 793    }
 794}
 795
// Name of the HTTP header through which clients report their RPC protocol version; returned by
// `ProtocolVersion::name()`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 799
 800pub struct ProtocolVersion(u32);
 801
 802impl Header for ProtocolVersion {
 803    fn name() -> &'static HeaderName {
 804        &ZED_PROTOCOL_VERSION
 805    }
 806
 807    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 808    where
 809        Self: Sized,
 810        I: Iterator<Item = &'i axum::http::HeaderValue>,
 811    {
 812        let version = values
 813            .next()
 814            .ok_or_else(axum::headers::Error::invalid)?
 815            .to_str()
 816            .map_err(|_| axum::headers::Error::invalid())?
 817            .parse()
 818            .map_err(|_| axum::headers::Error::invalid())?;
 819        Ok(Self(version))
 820    }
 821
 822    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 823        values.extend([self.0.to_string().parse().unwrap()]);
 824    }
 825}
 826
/// Builds the HTTP router for the RPC server.
///
/// Axum's `Router::layer` only wraps routes that were added before the call, so:
/// - `/rpc` runs behind the `AppState` extension and the `auth::validate_header` middleware;
/// - `/metrics` is registered after that layer and is therefore not wrapped by it;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc` (the only route registered so far).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Registered last, so both `/rpc` and `/metrics` can extract `Extension<Arc<Server>>`.
        .layer(Extension(server))
}
 838
 839pub async fn handle_websocket_request(
 840    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 841    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 842    Extension(server): Extension<Arc<Server>>,
 843    Extension(user): Extension<User>,
 844    Extension(impersonator): Extension<Impersonator>,
 845    ws: WebSocketUpgrade,
 846) -> axum::response::Response {
 847    if protocol_version != rpc::PROTOCOL_VERSION {
 848        return (
 849            StatusCode::UPGRADE_REQUIRED,
 850            "client must be upgraded".to_string(),
 851        )
 852            .into_response();
 853    }
 854    let socket_address = socket_address.to_string();
 855    ws.on_upgrade(move |socket| {
 856        use util::ResultExt;
 857        let socket = socket
 858            .map_ok(to_tungstenite_message)
 859            .err_into()
 860            .with(|message| async move { Ok(to_axum_message(message)) });
 861        let connection = Connection::new(Box::pin(socket));
 862        async move {
 863            server
 864                .handle_connection(
 865                    connection,
 866                    socket_address,
 867                    user,
 868                    impersonator.0,
 869                    None,
 870                    Executor::Production,
 871                )
 872                .await
 873                .log_err();
 874        }
 875    })
 876}
 877
 878pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 879    let connections = server
 880        .connection_pool
 881        .lock()
 882        .connections()
 883        .filter(|connection| !connection.admin)
 884        .count();
 885
 886    METRIC_CONNECTIONS.set(connections as _);
 887
 888    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 889    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 890
 891    let encoder = prometheus::TextEncoder::new();
 892    let metric_families = prometheus::gather();
 893    let encoded_metrics = encoder
 894        .encode_to_string(&metric_families)
 895        .map_err(|err| anyhow!("{}", err))?;
 896    Ok(encoded_metrics)
 897}
 898
/// Tears down a session after its connection has closed.
///
/// Immediately disconnects the peer, removes the connection from the pool, and reports the lost
/// connection to the database (a database error there is only traced). It then races
/// `RECONNECT_TIMEOUT` against `teardown`:
/// - if the timeout elapses first, the session leaves its room and channel buffers, the user's
///   call is declined via `decline_call(None, ..)` when the user has no other online connection,
///   and the user's contacts are updated;
/// - if `teardown` changes first, none of that delayed cleanup runs.
///
/// # Errors
/// Returns an error if `remove_connection` or `update_user_contacts` fails; the other cleanup
/// steps trace their errors and continue.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its arms in the order written: the timeout arm before teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // NOTE(review): this uses `log::info!`, while the rest of the file logs through `tracing`.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 944
 945/// Acknowledges a ping from a client, used to keep the connection alive.
 946async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 947    response.send(proto::Ack {})?;
 948    Ok(())
 949}
 950
 951/// Creates a new room for calling (outside of channels)
 952async fn create_room(
 953    _request: proto::CreateRoom,
 954    response: Response<proto::CreateRoom>,
 955    session: Session,
 956) -> Result<()> {
 957    let live_kit_room = nanoid::nanoid!(30);
 958
 959    let live_kit_connection_info = {
 960        let live_kit_room = live_kit_room.clone();
 961        let live_kit = session.live_kit_client.as_ref();
 962
 963        util::async_maybe!({
 964            let live_kit = live_kit?;
 965
 966            let token = live_kit
 967                .room_token(&live_kit_room, &session.user_id.to_string())
 968                .trace_err()?;
 969
 970            Some(proto::LiveKitConnectionInfo {
 971                server_url: live_kit.url().into(),
 972                token,
 973                can_publish: true,
 974            })
 975        })
 976    }
 977    .await;
 978
 979    let room = session
 980        .db()
 981        .await
 982        .create_room(
 983            session.user_id,
 984            session.connection_id,
 985            &live_kit_room,
 986            &session.zed_environment,
 987        )
 988        .await?;
 989
 990    response.send(proto::CreateRoomResponse {
 991        room: Some(room.clone()),
 992        live_kit_connection_info,
 993    })?;
 994
 995    update_user_contacts(session.user_id, &session).await?;
 996    Ok(())
 997}
 998
 999/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1000async fn join_room(
1001    request: proto::JoinRoom,
1002    response: Response<proto::JoinRoom>,
1003    session: Session,
1004) -> Result<()> {
1005    let room_id = RoomId::from_proto(request.id);
1006
1007    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1008
1009    if let Some(channel_id) = channel_id {
1010        return join_channel_internal(channel_id, Box::new(response), session).await;
1011    }
1012
1013    let joined_room = {
1014        let room = session
1015            .db()
1016            .await
1017            .join_room(
1018                room_id,
1019                session.user_id,
1020                session.connection_id,
1021                session.zed_environment.as_ref(),
1022            )
1023            .await?;
1024        room_updated(&room.room, &session.peer);
1025        room.into_inner()
1026    };
1027
1028    for connection_id in session
1029        .connection_pool()
1030        .await
1031        .user_connection_ids(session.user_id)
1032    {
1033        session
1034            .peer
1035            .send(
1036                connection_id,
1037                proto::CallCanceled {
1038                    room_id: room_id.to_proto(),
1039                },
1040            )
1041            .trace_err();
1042    }
1043
1044    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1045        if let Some(token) = live_kit
1046            .room_token(
1047                &joined_room.room.live_kit_room,
1048                &session.user_id.to_string(),
1049            )
1050            .trace_err()
1051        {
1052            Some(proto::LiveKitConnectionInfo {
1053                server_url: live_kit.url().into(),
1054                token,
1055                can_publish: true,
1056            })
1057        } else {
1058            None
1059        }
1060    } else {
1061        None
1062    };
1063
1064    response.send(proto::JoinRoomResponse {
1065        room: Some(joined_room.room),
1066        channel_id: None,
1067        live_kit_connection_info,
1068    })?;
1069
1070    update_user_contacts(session.user_id, &session).await?;
1071    Ok(())
1072}
1073
1074/// Rejoin room is used to reconnect to a room after connection errors.
1075async fn rejoin_room(
1076    request: proto::RejoinRoom,
1077    response: Response<proto::RejoinRoom>,
1078    session: Session,
1079) -> Result<()> {
1080    let room;
1081    let channel_id;
1082    let channel_members;
1083    {
1084        let mut rejoined_room = session
1085            .db()
1086            .await
1087            .rejoin_room(request, session.user_id, session.connection_id)
1088            .await?;
1089
1090        response.send(proto::RejoinRoomResponse {
1091            room: Some(rejoined_room.room.clone()),
1092            reshared_projects: rejoined_room
1093                .reshared_projects
1094                .iter()
1095                .map(|project| proto::ResharedProject {
1096                    id: project.id.to_proto(),
1097                    collaborators: project
1098                        .collaborators
1099                        .iter()
1100                        .map(|collaborator| collaborator.to_proto())
1101                        .collect(),
1102                })
1103                .collect(),
1104            rejoined_projects: rejoined_room
1105                .rejoined_projects
1106                .iter()
1107                .map(|rejoined_project| proto::RejoinedProject {
1108                    id: rejoined_project.id.to_proto(),
1109                    worktrees: rejoined_project
1110                        .worktrees
1111                        .iter()
1112                        .map(|worktree| proto::WorktreeMetadata {
1113                            id: worktree.id,
1114                            root_name: worktree.root_name.clone(),
1115                            visible: worktree.visible,
1116                            abs_path: worktree.abs_path.clone(),
1117                        })
1118                        .collect(),
1119                    collaborators: rejoined_project
1120                        .collaborators
1121                        .iter()
1122                        .map(|collaborator| collaborator.to_proto())
1123                        .collect(),
1124                    language_servers: rejoined_project.language_servers.clone(),
1125                })
1126                .collect(),
1127        })?;
1128        room_updated(&rejoined_room.room, &session.peer);
1129
1130        for project in &rejoined_room.reshared_projects {
1131            for collaborator in &project.collaborators {
1132                session
1133                    .peer
1134                    .send(
1135                        collaborator.connection_id,
1136                        proto::UpdateProjectCollaborator {
1137                            project_id: project.id.to_proto(),
1138                            old_peer_id: Some(project.old_connection_id.into()),
1139                            new_peer_id: Some(session.connection_id.into()),
1140                        },
1141                    )
1142                    .trace_err();
1143            }
1144
1145            broadcast(
1146                Some(session.connection_id),
1147                project
1148                    .collaborators
1149                    .iter()
1150                    .map(|collaborator| collaborator.connection_id),
1151                |connection_id| {
1152                    session.peer.forward_send(
1153                        session.connection_id,
1154                        connection_id,
1155                        proto::UpdateProject {
1156                            project_id: project.id.to_proto(),
1157                            worktrees: project.worktrees.clone(),
1158                        },
1159                    )
1160                },
1161            );
1162        }
1163
1164        for project in &rejoined_room.rejoined_projects {
1165            for collaborator in &project.collaborators {
1166                session
1167                    .peer
1168                    .send(
1169                        collaborator.connection_id,
1170                        proto::UpdateProjectCollaborator {
1171                            project_id: project.id.to_proto(),
1172                            old_peer_id: Some(project.old_connection_id.into()),
1173                            new_peer_id: Some(session.connection_id.into()),
1174                        },
1175                    )
1176                    .trace_err();
1177            }
1178        }
1179
1180        for project in &mut rejoined_room.rejoined_projects {
1181            for worktree in mem::take(&mut project.worktrees) {
1182                #[cfg(any(test, feature = "test-support"))]
1183                const MAX_CHUNK_SIZE: usize = 2;
1184                #[cfg(not(any(test, feature = "test-support")))]
1185                const MAX_CHUNK_SIZE: usize = 256;
1186
1187                // Stream this worktree's entries.
1188                let message = proto::UpdateWorktree {
1189                    project_id: project.id.to_proto(),
1190                    worktree_id: worktree.id,
1191                    abs_path: worktree.abs_path.clone(),
1192                    root_name: worktree.root_name,
1193                    updated_entries: worktree.updated_entries,
1194                    removed_entries: worktree.removed_entries,
1195                    scan_id: worktree.scan_id,
1196                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1197                    updated_repositories: worktree.updated_repositories,
1198                    removed_repositories: worktree.removed_repositories,
1199                };
1200                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1201                    session.peer.send(session.connection_id, update.clone())?;
1202                }
1203
1204                // Stream this worktree's diagnostics.
1205                for summary in worktree.diagnostic_summaries {
1206                    session.peer.send(
1207                        session.connection_id,
1208                        proto::UpdateDiagnosticSummary {
1209                            project_id: project.id.to_proto(),
1210                            worktree_id: worktree.id,
1211                            summary: Some(summary),
1212                        },
1213                    )?;
1214                }
1215
1216                for settings_file in worktree.settings_files {
1217                    session.peer.send(
1218                        session.connection_id,
1219                        proto::UpdateWorktreeSettings {
1220                            project_id: project.id.to_proto(),
1221                            worktree_id: worktree.id,
1222                            path: settings_file.path,
1223                            content: Some(settings_file.content),
1224                        },
1225                    )?;
1226                }
1227            }
1228
1229            for language_server in &project.language_servers {
1230                session.peer.send(
1231                    session.connection_id,
1232                    proto::UpdateLanguageServer {
1233                        project_id: project.id.to_proto(),
1234                        language_server_id: language_server.id,
1235                        variant: Some(
1236                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1237                                proto::LspDiskBasedDiagnosticsUpdated {},
1238                            ),
1239                        ),
1240                    },
1241                )?;
1242            }
1243        }
1244
1245        let rejoined_room = rejoined_room.into_inner();
1246
1247        room = rejoined_room.room;
1248        channel_id = rejoined_room.channel_id;
1249        channel_members = rejoined_room.channel_members;
1250    }
1251
1252    if let Some(channel_id) = channel_id {
1253        channel_updated(
1254            channel_id,
1255            &room,
1256            &channel_members,
1257            &session.peer,
1258            &*session.connection_pool().await,
1259        );
1260    }
1261
1262    update_user_contacts(session.user_id, &session).await?;
1263    Ok(())
1264}
1265
1266/// leave room disconnects from the room.
1267async fn leave_room(
1268    _: proto::LeaveRoom,
1269    response: Response<proto::LeaveRoom>,
1270    session: Session,
1271) -> Result<()> {
1272    leave_room_for_session(&session).await?;
1273    response.send(proto::Ack {})?;
1274    Ok(())
1275}
1276
1277/// Updates the permissions of someone else in the room.
1278async fn set_room_participant_role(
1279    request: proto::SetRoomParticipantRole,
1280    response: Response<proto::SetRoomParticipantRole>,
1281    session: Session,
1282) -> Result<()> {
1283    let (live_kit_room, can_publish) = {
1284        let room = session
1285            .db()
1286            .await
1287            .set_room_participant_role(
1288                session.user_id,
1289                RoomId::from_proto(request.room_id),
1290                UserId::from_proto(request.user_id),
1291                ChannelRole::from(request.role()),
1292            )
1293            .await?;
1294
1295        let live_kit_room = room.live_kit_room.clone();
1296        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1297        room_updated(&room, &session.peer);
1298        (live_kit_room, can_publish)
1299    };
1300
1301    if let Some(live_kit) = session.live_kit_client.as_ref() {
1302        live_kit
1303            .update_participant(
1304                live_kit_room.clone(),
1305                request.user_id.to_string(),
1306                live_kit_server::proto::ParticipantPermission {
1307                    can_subscribe: true,
1308                    can_publish,
1309                    can_publish_data: can_publish,
1310                    hidden: false,
1311                    recorder: false,
1312                },
1313            )
1314            .await
1315            .trace_err();
1316    }
1317
1318    response.send(proto::Ack {})?;
1319    Ok(())
1320}
1321
1322/// Call someone else into the current room
1323async fn call(
1324    request: proto::Call,
1325    response: Response<proto::Call>,
1326    session: Session,
1327) -> Result<()> {
1328    let room_id = RoomId::from_proto(request.room_id);
1329    let calling_user_id = session.user_id;
1330    let calling_connection_id = session.connection_id;
1331    let called_user_id = UserId::from_proto(request.called_user_id);
1332    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1333    if !session
1334        .db()
1335        .await
1336        .has_contact(calling_user_id, called_user_id)
1337        .await?
1338    {
1339        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1340    }
1341
1342    let incoming_call = {
1343        let (room, incoming_call) = &mut *session
1344            .db()
1345            .await
1346            .call(
1347                room_id,
1348                calling_user_id,
1349                calling_connection_id,
1350                called_user_id,
1351                initial_project_id,
1352            )
1353            .await?;
1354        room_updated(&room, &session.peer);
1355        mem::take(incoming_call)
1356    };
1357    update_user_contacts(called_user_id, &session).await?;
1358
1359    let mut calls = session
1360        .connection_pool()
1361        .await
1362        .user_connection_ids(called_user_id)
1363        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1364        .collect::<FuturesUnordered<_>>();
1365
1366    while let Some(call_response) = calls.next().await {
1367        match call_response.as_ref() {
1368            Ok(_) => {
1369                response.send(proto::Ack {})?;
1370                return Ok(());
1371            }
1372            Err(_) => {
1373                call_response.trace_err();
1374            }
1375        }
1376    }
1377
1378    {
1379        let room = session
1380            .db()
1381            .await
1382            .call_failed(room_id, called_user_id)
1383            .await?;
1384        room_updated(&room, &session.peer);
1385    }
1386    update_user_contacts(called_user_id, &session).await?;
1387
1388    Err(anyhow!("failed to ring user"))?
1389}
1390
1391/// Cancel an outgoing call.
1392async fn cancel_call(
1393    request: proto::CancelCall,
1394    response: Response<proto::CancelCall>,
1395    session: Session,
1396) -> Result<()> {
1397    let called_user_id = UserId::from_proto(request.called_user_id);
1398    let room_id = RoomId::from_proto(request.room_id);
1399    {
1400        let room = session
1401            .db()
1402            .await
1403            .cancel_call(room_id, session.connection_id, called_user_id)
1404            .await?;
1405        room_updated(&room, &session.peer);
1406    }
1407
1408    for connection_id in session
1409        .connection_pool()
1410        .await
1411        .user_connection_ids(called_user_id)
1412    {
1413        session
1414            .peer
1415            .send(
1416                connection_id,
1417                proto::CallCanceled {
1418                    room_id: room_id.to_proto(),
1419                },
1420            )
1421            .trace_err();
1422    }
1423    response.send(proto::Ack {})?;
1424
1425    update_user_contacts(called_user_id, &session).await?;
1426    Ok(())
1427}
1428
1429/// Decline an incoming call.
1430async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1431    let room_id = RoomId::from_proto(message.room_id);
1432    {
1433        let room = session
1434            .db()
1435            .await
1436            .decline_call(Some(room_id), session.user_id)
1437            .await?
1438            .ok_or_else(|| anyhow!("failed to decline call"))?;
1439        room_updated(&room, &session.peer);
1440    }
1441
1442    for connection_id in session
1443        .connection_pool()
1444        .await
1445        .user_connection_ids(session.user_id)
1446    {
1447        session
1448            .peer
1449            .send(
1450                connection_id,
1451                proto::CallCanceled {
1452                    room_id: room_id.to_proto(),
1453                },
1454            )
1455            .trace_err();
1456    }
1457    update_user_contacts(session.user_id, &session).await?;
1458    Ok(())
1459}
1460
1461/// Updates other participants in the room with your current location.
1462async fn update_participant_location(
1463    request: proto::UpdateParticipantLocation,
1464    response: Response<proto::UpdateParticipantLocation>,
1465    session: Session,
1466) -> Result<()> {
1467    let room_id = RoomId::from_proto(request.room_id);
1468    let location = request
1469        .location
1470        .ok_or_else(|| anyhow!("invalid location"))?;
1471
1472    let db = session.db().await;
1473    let room = db
1474        .update_room_participant_location(room_id, session.connection_id, location)
1475        .await?;
1476
1477    room_updated(&room, &session.peer);
1478    response.send(proto::Ack {})?;
1479    Ok(())
1480}
1481
1482/// Share a project into the room.
1483async fn share_project(
1484    request: proto::ShareProject,
1485    response: Response<proto::ShareProject>,
1486    session: Session,
1487) -> Result<()> {
1488    let (project_id, room) = &*session
1489        .db()
1490        .await
1491        .share_project(
1492            RoomId::from_proto(request.room_id),
1493            session.connection_id,
1494            &request.worktrees,
1495        )
1496        .await?;
1497    response.send(proto::ShareProjectResponse {
1498        project_id: project_id.to_proto(),
1499    })?;
1500    room_updated(&room, &session.peer);
1501
1502    Ok(())
1503}
1504
1505/// Unshare a project from the room.
1506async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1507    let project_id = ProjectId::from_proto(message.project_id);
1508
1509    let (room, guest_connection_ids) = &*session
1510        .db()
1511        .await
1512        .unshare_project(project_id, session.connection_id)
1513        .await?;
1514
1515    broadcast(
1516        Some(session.connection_id),
1517        guest_connection_ids.iter().copied(),
1518        |conn_id| session.peer.send(conn_id, message.clone()),
1519    );
1520    room_updated(&room, &session.peer);
1521
1522    Ok(())
1523}
1524
1525/// Join someone elses shared project.
1526async fn join_project(
1527    request: proto::JoinProject,
1528    response: Response<proto::JoinProject>,
1529    session: Session,
1530) -> Result<()> {
1531    let project_id = ProjectId::from_proto(request.project_id);
1532    let guest_user_id = session.user_id;
1533
1534    tracing::info!(%project_id, "join project");
1535
1536    let (project, replica_id) = &mut *session
1537        .db()
1538        .await
1539        .join_project(project_id, session.connection_id)
1540        .await?;
1541
1542    let collaborators = project
1543        .collaborators
1544        .iter()
1545        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1546        .map(|collaborator| collaborator.to_proto())
1547        .collect::<Vec<_>>();
1548
1549    let worktrees = project
1550        .worktrees
1551        .iter()
1552        .map(|(id, worktree)| proto::WorktreeMetadata {
1553            id: *id,
1554            root_name: worktree.root_name.clone(),
1555            visible: worktree.visible,
1556            abs_path: worktree.abs_path.clone(),
1557        })
1558        .collect::<Vec<_>>();
1559
1560    for collaborator in &collaborators {
1561        session
1562            .peer
1563            .send(
1564                collaborator.peer_id.unwrap().into(),
1565                proto::AddProjectCollaborator {
1566                    project_id: project_id.to_proto(),
1567                    collaborator: Some(proto::Collaborator {
1568                        peer_id: Some(session.connection_id.into()),
1569                        replica_id: replica_id.0 as u32,
1570                        user_id: guest_user_id.to_proto(),
1571                    }),
1572                },
1573            )
1574            .trace_err();
1575    }
1576
1577    // First, we send the metadata associated with each worktree.
1578    response.send(proto::JoinProjectResponse {
1579        worktrees: worktrees.clone(),
1580        replica_id: replica_id.0 as u32,
1581        collaborators: collaborators.clone(),
1582        language_servers: project.language_servers.clone(),
1583    })?;
1584
1585    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1586        #[cfg(any(test, feature = "test-support"))]
1587        const MAX_CHUNK_SIZE: usize = 2;
1588        #[cfg(not(any(test, feature = "test-support")))]
1589        const MAX_CHUNK_SIZE: usize = 256;
1590
1591        // Stream this worktree's entries.
1592        let message = proto::UpdateWorktree {
1593            project_id: project_id.to_proto(),
1594            worktree_id,
1595            abs_path: worktree.abs_path.clone(),
1596            root_name: worktree.root_name,
1597            updated_entries: worktree.entries,
1598            removed_entries: Default::default(),
1599            scan_id: worktree.scan_id,
1600            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1601            updated_repositories: worktree.repository_entries.into_values().collect(),
1602            removed_repositories: Default::default(),
1603        };
1604        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1605            session.peer.send(session.connection_id, update.clone())?;
1606        }
1607
1608        // Stream this worktree's diagnostics.
1609        for summary in worktree.diagnostic_summaries {
1610            session.peer.send(
1611                session.connection_id,
1612                proto::UpdateDiagnosticSummary {
1613                    project_id: project_id.to_proto(),
1614                    worktree_id: worktree.id,
1615                    summary: Some(summary),
1616                },
1617            )?;
1618        }
1619
1620        for settings_file in worktree.settings_files {
1621            session.peer.send(
1622                session.connection_id,
1623                proto::UpdateWorktreeSettings {
1624                    project_id: project_id.to_proto(),
1625                    worktree_id: worktree.id,
1626                    path: settings_file.path,
1627                    content: Some(settings_file.content),
1628                },
1629            )?;
1630        }
1631    }
1632
1633    for language_server in &project.language_servers {
1634        session.peer.send(
1635            session.connection_id,
1636            proto::UpdateLanguageServer {
1637                project_id: project_id.to_proto(),
1638                language_server_id: language_server.id,
1639                variant: Some(
1640                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1641                        proto::LspDiskBasedDiagnosticsUpdated {},
1642                    ),
1643                ),
1644            },
1645        )?;
1646    }
1647
1648    Ok(())
1649}
1650
1651/// Leave someone elses shared project.
1652async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1653    let sender_id = session.connection_id;
1654    let project_id = ProjectId::from_proto(request.project_id);
1655
1656    let (room, project) = &*session
1657        .db()
1658        .await
1659        .leave_project(project_id, sender_id)
1660        .await?;
1661    tracing::info!(
1662        %project_id,
1663        host_user_id = %project.host_user_id,
1664        host_connection_id = %project.host_connection_id,
1665        "leave project"
1666    );
1667
1668    project_left(&project, &session);
1669    room_updated(&room, &session.peer);
1670
1671    Ok(())
1672}
1673
1674/// Updates other participants with changes to the project
1675async fn update_project(
1676    request: proto::UpdateProject,
1677    response: Response<proto::UpdateProject>,
1678    session: Session,
1679) -> Result<()> {
1680    let project_id = ProjectId::from_proto(request.project_id);
1681    let (room, guest_connection_ids) = &*session
1682        .db()
1683        .await
1684        .update_project(project_id, session.connection_id, &request.worktrees)
1685        .await?;
1686    broadcast(
1687        Some(session.connection_id),
1688        guest_connection_ids.iter().copied(),
1689        |connection_id| {
1690            session
1691                .peer
1692                .forward_send(session.connection_id, connection_id, request.clone())
1693        },
1694    );
1695    room_updated(&room, &session.peer);
1696    response.send(proto::Ack {})?;
1697
1698    Ok(())
1699}
1700
1701/// Updates other participants with changes to the worktree
1702async fn update_worktree(
1703    request: proto::UpdateWorktree,
1704    response: Response<proto::UpdateWorktree>,
1705    session: Session,
1706) -> Result<()> {
1707    let guest_connection_ids = session
1708        .db()
1709        .await
1710        .update_worktree(&request, session.connection_id)
1711        .await?;
1712
1713    broadcast(
1714        Some(session.connection_id),
1715        guest_connection_ids.iter().copied(),
1716        |connection_id| {
1717            session
1718                .peer
1719                .forward_send(session.connection_id, connection_id, request.clone())
1720        },
1721    );
1722    response.send(proto::Ack {})?;
1723    Ok(())
1724}
1725
1726/// Updates other participants with changes to the diagnostics
1727async fn update_diagnostic_summary(
1728    message: proto::UpdateDiagnosticSummary,
1729    session: Session,
1730) -> Result<()> {
1731    let guest_connection_ids = session
1732        .db()
1733        .await
1734        .update_diagnostic_summary(&message, session.connection_id)
1735        .await?;
1736
1737    broadcast(
1738        Some(session.connection_id),
1739        guest_connection_ids.iter().copied(),
1740        |connection_id| {
1741            session
1742                .peer
1743                .forward_send(session.connection_id, connection_id, message.clone())
1744        },
1745    );
1746
1747    Ok(())
1748}
1749
1750/// Updates other participants with changes to the worktree settings
1751async fn update_worktree_settings(
1752    message: proto::UpdateWorktreeSettings,
1753    session: Session,
1754) -> Result<()> {
1755    let guest_connection_ids = session
1756        .db()
1757        .await
1758        .update_worktree_settings(&message, session.connection_id)
1759        .await?;
1760
1761    broadcast(
1762        Some(session.connection_id),
1763        guest_connection_ids.iter().copied(),
1764        |connection_id| {
1765            session
1766                .peer
1767                .forward_send(session.connection_id, connection_id, message.clone())
1768        },
1769    );
1770
1771    Ok(())
1772}
1773
1774/// Notify other participants that a  language server has started.
1775async fn start_language_server(
1776    request: proto::StartLanguageServer,
1777    session: Session,
1778) -> Result<()> {
1779    let guest_connection_ids = session
1780        .db()
1781        .await
1782        .start_language_server(&request, session.connection_id)
1783        .await?;
1784
1785    broadcast(
1786        Some(session.connection_id),
1787        guest_connection_ids.iter().copied(),
1788        |connection_id| {
1789            session
1790                .peer
1791                .forward_send(session.connection_id, connection_id, request.clone())
1792        },
1793    );
1794    Ok(())
1795}
1796
1797/// Notify other participants that a language server has changed.
1798async fn update_language_server(
1799    request: proto::UpdateLanguageServer,
1800    session: Session,
1801) -> Result<()> {
1802    let project_id = ProjectId::from_proto(request.project_id);
1803    let project_connection_ids = session
1804        .db()
1805        .await
1806        .project_connection_ids(project_id, session.connection_id)
1807        .await?;
1808    broadcast(
1809        Some(session.connection_id),
1810        project_connection_ids.iter().copied(),
1811        |connection_id| {
1812            session
1813                .peer
1814                .forward_send(session.connection_id, connection_id, request.clone())
1815        },
1816    );
1817    Ok(())
1818}
1819
1820/// forward a project request to the host. These requests should be read only
1821/// as guests are allowed to send them.
1822async fn forward_read_only_project_request<T>(
1823    request: T,
1824    response: Response<T>,
1825    session: Session,
1826) -> Result<()>
1827where
1828    T: EntityMessage + RequestMessage,
1829{
1830    let project_id = ProjectId::from_proto(request.remote_entity_id());
1831    let host_connection_id = session
1832        .db()
1833        .await
1834        .host_for_read_only_project_request(project_id, session.connection_id)
1835        .await?;
1836    let payload = session
1837        .peer
1838        .forward_request(session.connection_id, host_connection_id, request)
1839        .await?;
1840    response.send(payload)?;
1841    Ok(())
1842}
1843
1844/// forward a project request to the host. These requests are disallowed
1845/// for guests.
1846async fn forward_mutating_project_request<T>(
1847    request: T,
1848    response: Response<T>,
1849    session: Session,
1850) -> Result<()>
1851where
1852    T: EntityMessage + RequestMessage,
1853{
1854    let project_id = ProjectId::from_proto(request.remote_entity_id());
1855    let host_connection_id = session
1856        .db()
1857        .await
1858        .host_for_mutating_project_request(project_id, session.connection_id)
1859        .await?;
1860    let payload = session
1861        .peer
1862        .forward_request(session.connection_id, host_connection_id, request)
1863        .await?;
1864    response.send(payload)?;
1865    Ok(())
1866}
1867
1868/// Notify other participants that a new buffer has been created
1869async fn create_buffer_for_peer(
1870    request: proto::CreateBufferForPeer,
1871    session: Session,
1872) -> Result<()> {
1873    session
1874        .db()
1875        .await
1876        .check_user_is_project_host(
1877            ProjectId::from_proto(request.project_id),
1878            session.connection_id,
1879        )
1880        .await?;
1881    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1882    session
1883        .peer
1884        .forward_send(session.connection_id, peer_id.into(), request)?;
1885    Ok(())
1886}
1887
1888/// Notify other participants that a buffer has been updated. This is
1889/// allowed for guests as long as the update is limited to selections.
1890async fn update_buffer(
1891    request: proto::UpdateBuffer,
1892    response: Response<proto::UpdateBuffer>,
1893    session: Session,
1894) -> Result<()> {
1895    let project_id = ProjectId::from_proto(request.project_id);
1896    let mut guest_connection_ids;
1897    let mut host_connection_id = None;
1898
1899    let mut requires_write_permission = false;
1900
1901    for op in request.operations.iter() {
1902        match op.variant {
1903            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1904            Some(_) => requires_write_permission = true,
1905        }
1906    }
1907
1908    {
1909        let collaborators = session
1910            .db()
1911            .await
1912            .project_collaborators_for_buffer_update(
1913                project_id,
1914                session.connection_id,
1915                requires_write_permission,
1916            )
1917            .await?;
1918        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1919        for collaborator in collaborators.iter() {
1920            if collaborator.is_host {
1921                host_connection_id = Some(collaborator.connection_id);
1922            } else {
1923                guest_connection_ids.push(collaborator.connection_id);
1924            }
1925        }
1926    }
1927    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1928
1929    broadcast(
1930        Some(session.connection_id),
1931        guest_connection_ids,
1932        |connection_id| {
1933            session
1934                .peer
1935                .forward_send(session.connection_id, connection_id, request.clone())
1936        },
1937    );
1938    if host_connection_id != session.connection_id {
1939        session
1940            .peer
1941            .forward_request(session.connection_id, host_connection_id, request.clone())
1942            .await?;
1943    }
1944
1945    response.send(proto::Ack {})?;
1946    Ok(())
1947}
1948
1949/// Notify other participants that a project has been updated.
1950async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1951    request: T,
1952    session: Session,
1953) -> Result<()> {
1954    let project_id = ProjectId::from_proto(request.remote_entity_id());
1955    let project_connection_ids = session
1956        .db()
1957        .await
1958        .project_connection_ids(project_id, session.connection_id)
1959        .await?;
1960
1961    broadcast(
1962        Some(session.connection_id),
1963        project_connection_ids.iter().copied(),
1964        |connection_id| {
1965            session
1966                .peer
1967                .forward_send(session.connection_id, connection_id, request.clone())
1968        },
1969    );
1970    Ok(())
1971}
1972
1973/// Start following another user in a call.
1974async fn follow(
1975    request: proto::Follow,
1976    response: Response<proto::Follow>,
1977    session: Session,
1978) -> Result<()> {
1979    let room_id = RoomId::from_proto(request.room_id);
1980    let project_id = request.project_id.map(ProjectId::from_proto);
1981    let leader_id = request
1982        .leader_id
1983        .ok_or_else(|| anyhow!("invalid leader id"))?
1984        .into();
1985    let follower_id = session.connection_id;
1986
1987    session
1988        .db()
1989        .await
1990        .check_room_participants(room_id, leader_id, session.connection_id)
1991        .await?;
1992
1993    let response_payload = session
1994        .peer
1995        .forward_request(session.connection_id, leader_id, request)
1996        .await?;
1997    response.send(response_payload)?;
1998
1999    if let Some(project_id) = project_id {
2000        let room = session
2001            .db()
2002            .await
2003            .follow(room_id, project_id, leader_id, follower_id)
2004            .await?;
2005        room_updated(&room, &session.peer);
2006    }
2007
2008    Ok(())
2009}
2010
2011/// Stop following another user in a call.
2012async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2013    let room_id = RoomId::from_proto(request.room_id);
2014    let project_id = request.project_id.map(ProjectId::from_proto);
2015    let leader_id = request
2016        .leader_id
2017        .ok_or_else(|| anyhow!("invalid leader id"))?
2018        .into();
2019    let follower_id = session.connection_id;
2020
2021    session
2022        .db()
2023        .await
2024        .check_room_participants(room_id, leader_id, session.connection_id)
2025        .await?;
2026
2027    session
2028        .peer
2029        .forward_send(session.connection_id, leader_id, request)?;
2030
2031    if let Some(project_id) = project_id {
2032        let room = session
2033            .db()
2034            .await
2035            .unfollow(room_id, project_id, leader_id, follower_id)
2036            .await?;
2037        room_updated(&room, &session.peer);
2038    }
2039
2040    Ok(())
2041}
2042
2043/// Notify everyone following you of your current location.
2044async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2045    let room_id = RoomId::from_proto(request.room_id);
2046    let database = session.db.lock().await;
2047
2048    let connection_ids = if let Some(project_id) = request.project_id {
2049        let project_id = ProjectId::from_proto(project_id);
2050        database
2051            .project_connection_ids(project_id, session.connection_id)
2052            .await?
2053    } else {
2054        database
2055            .room_connection_ids(room_id, session.connection_id)
2056            .await?
2057    };
2058
2059    // For now, don't send view update messages back to that view's current leader.
2060    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2061        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2062        _ => None,
2063    });
2064
2065    for follower_peer_id in request.follower_ids.iter().copied() {
2066        let follower_connection_id = follower_peer_id.into();
2067        if Some(follower_peer_id) != connection_id_to_omit
2068            && connection_ids.contains(&follower_connection_id)
2069        {
2070            session.peer.forward_send(
2071                session.connection_id,
2072                follower_connection_id,
2073                request.clone(),
2074            )?;
2075        }
2076    }
2077    Ok(())
2078}
2079
2080/// Get public data about users.
2081async fn get_users(
2082    request: proto::GetUsers,
2083    response: Response<proto::GetUsers>,
2084    session: Session,
2085) -> Result<()> {
2086    let user_ids = request
2087        .user_ids
2088        .into_iter()
2089        .map(UserId::from_proto)
2090        .collect();
2091    let users = session
2092        .db()
2093        .await
2094        .get_users_by_ids(user_ids)
2095        .await?
2096        .into_iter()
2097        .map(|user| proto::User {
2098            id: user.id.to_proto(),
2099            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2100            github_login: user.github_login,
2101        })
2102        .collect();
2103    response.send(proto::UsersResponse { users })?;
2104    Ok(())
2105}
2106
2107/// Search for users (to invite) buy Github login
2108async fn fuzzy_search_users(
2109    request: proto::FuzzySearchUsers,
2110    response: Response<proto::FuzzySearchUsers>,
2111    session: Session,
2112) -> Result<()> {
2113    let query = request.query;
2114    let users = match query.len() {
2115        0 => vec![],
2116        1 | 2 => session
2117            .db()
2118            .await
2119            .get_user_by_github_login(&query)
2120            .await?
2121            .into_iter()
2122            .collect(),
2123        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2124    };
2125    let users = users
2126        .into_iter()
2127        .filter(|user| user.id != session.user_id)
2128        .map(|user| proto::User {
2129            id: user.id.to_proto(),
2130            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2131            github_login: user.github_login,
2132        })
2133        .collect();
2134    response.send(proto::UsersResponse { users })?;
2135    Ok(())
2136}
2137
2138/// Send a contact request to another user.
2139async fn request_contact(
2140    request: proto::RequestContact,
2141    response: Response<proto::RequestContact>,
2142    session: Session,
2143) -> Result<()> {
2144    let requester_id = session.user_id;
2145    let responder_id = UserId::from_proto(request.responder_id);
2146    if requester_id == responder_id {
2147        return Err(anyhow!("cannot add yourself as a contact"))?;
2148    }
2149
2150    let notifications = session
2151        .db()
2152        .await
2153        .send_contact_request(requester_id, responder_id)
2154        .await?;
2155
2156    // Update outgoing contact requests of requester
2157    let mut update = proto::UpdateContacts::default();
2158    update.outgoing_requests.push(responder_id.to_proto());
2159    for connection_id in session
2160        .connection_pool()
2161        .await
2162        .user_connection_ids(requester_id)
2163    {
2164        session.peer.send(connection_id, update.clone())?;
2165    }
2166
2167    // Update incoming contact requests of responder
2168    let mut update = proto::UpdateContacts::default();
2169    update
2170        .incoming_requests
2171        .push(proto::IncomingContactRequest {
2172            requester_id: requester_id.to_proto(),
2173        });
2174    let connection_pool = session.connection_pool().await;
2175    for connection_id in connection_pool.user_connection_ids(responder_id) {
2176        session.peer.send(connection_id, update.clone())?;
2177    }
2178
2179    send_notifications(&*connection_pool, &session.peer, notifications);
2180
2181    response.send(proto::Ack {})?;
2182    Ok(())
2183}
2184
2185/// Accept or decline a contact request
2186async fn respond_to_contact_request(
2187    request: proto::RespondToContactRequest,
2188    response: Response<proto::RespondToContactRequest>,
2189    session: Session,
2190) -> Result<()> {
2191    let responder_id = session.user_id;
2192    let requester_id = UserId::from_proto(request.requester_id);
2193    let db = session.db().await;
2194    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2195        db.dismiss_contact_notification(responder_id, requester_id)
2196            .await?;
2197    } else {
2198        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2199
2200        let notifications = db
2201            .respond_to_contact_request(responder_id, requester_id, accept)
2202            .await?;
2203        let requester_busy = db.is_user_busy(requester_id).await?;
2204        let responder_busy = db.is_user_busy(responder_id).await?;
2205
2206        let pool = session.connection_pool().await;
2207        // Update responder with new contact
2208        let mut update = proto::UpdateContacts::default();
2209        if accept {
2210            update
2211                .contacts
2212                .push(contact_for_user(requester_id, requester_busy, &pool));
2213        }
2214        update
2215            .remove_incoming_requests
2216            .push(requester_id.to_proto());
2217        for connection_id in pool.user_connection_ids(responder_id) {
2218            session.peer.send(connection_id, update.clone())?;
2219        }
2220
2221        // Update requester with new contact
2222        let mut update = proto::UpdateContacts::default();
2223        if accept {
2224            update
2225                .contacts
2226                .push(contact_for_user(responder_id, responder_busy, &pool));
2227        }
2228        update
2229            .remove_outgoing_requests
2230            .push(responder_id.to_proto());
2231
2232        for connection_id in pool.user_connection_ids(requester_id) {
2233            session.peer.send(connection_id, update.clone())?;
2234        }
2235
2236        send_notifications(&*pool, &session.peer, notifications);
2237    }
2238
2239    response.send(proto::Ack {})?;
2240    Ok(())
2241}
2242
2243/// Remove a contact.
2244async fn remove_contact(
2245    request: proto::RemoveContact,
2246    response: Response<proto::RemoveContact>,
2247    session: Session,
2248) -> Result<()> {
2249    let requester_id = session.user_id;
2250    let responder_id = UserId::from_proto(request.user_id);
2251    let db = session.db().await;
2252    let (contact_accepted, deleted_notification_id) =
2253        db.remove_contact(requester_id, responder_id).await?;
2254
2255    let pool = session.connection_pool().await;
2256    // Update outgoing contact requests of requester
2257    let mut update = proto::UpdateContacts::default();
2258    if contact_accepted {
2259        update.remove_contacts.push(responder_id.to_proto());
2260    } else {
2261        update
2262            .remove_outgoing_requests
2263            .push(responder_id.to_proto());
2264    }
2265    for connection_id in pool.user_connection_ids(requester_id) {
2266        session.peer.send(connection_id, update.clone())?;
2267    }
2268
2269    // Update incoming contact requests of responder
2270    let mut update = proto::UpdateContacts::default();
2271    if contact_accepted {
2272        update.remove_contacts.push(requester_id.to_proto());
2273    } else {
2274        update
2275            .remove_incoming_requests
2276            .push(requester_id.to_proto());
2277    }
2278    for connection_id in pool.user_connection_ids(responder_id) {
2279        session.peer.send(connection_id, update.clone())?;
2280        if let Some(notification_id) = deleted_notification_id {
2281            session.peer.send(
2282                connection_id,
2283                proto::DeleteNotification {
2284                    notification_id: notification_id.to_proto(),
2285                },
2286            )?;
2287        }
2288    }
2289
2290    response.send(proto::Ack {})?;
2291    Ok(())
2292}
2293
2294/// Creates a new channel.
2295async fn create_channel(
2296    request: proto::CreateChannel,
2297    response: Response<proto::CreateChannel>,
2298    session: Session,
2299) -> Result<()> {
2300    let db = session.db().await;
2301
2302    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2303    let (channel, owner, channel_members) = db
2304        .create_channel(&request.name, parent_id, session.user_id)
2305        .await?;
2306
2307    response.send(proto::CreateChannelResponse {
2308        channel: Some(channel.to_proto()),
2309        parent_id: request.parent_id,
2310    })?;
2311
2312    let connection_pool = session.connection_pool().await;
2313    if let Some(owner) = owner {
2314        let update = proto::UpdateUserChannels {
2315            channel_memberships: vec![proto::ChannelMembership {
2316                channel_id: owner.channel_id.to_proto(),
2317                role: owner.role.into(),
2318            }],
2319            ..Default::default()
2320        };
2321        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2322            session.peer.send(connection_id, update.clone())?;
2323        }
2324    }
2325
2326    for channel_member in channel_members {
2327        if !channel_member.role.can_see_channel(channel.visibility) {
2328            continue;
2329        }
2330
2331        let update = proto::UpdateChannels {
2332            channels: vec![channel.to_proto()],
2333            ..Default::default()
2334        };
2335        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2336            session.peer.send(connection_id, update.clone())?;
2337        }
2338    }
2339
2340    Ok(())
2341}
2342
2343/// Delete a channel
2344async fn delete_channel(
2345    request: proto::DeleteChannel,
2346    response: Response<proto::DeleteChannel>,
2347    session: Session,
2348) -> Result<()> {
2349    let db = session.db().await;
2350
2351    let channel_id = request.channel_id;
2352    let (removed_channels, member_ids) = db
2353        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2354        .await?;
2355    response.send(proto::Ack {})?;
2356
2357    // Notify members of removed channels
2358    let mut update = proto::UpdateChannels::default();
2359    update
2360        .delete_channels
2361        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2362
2363    let connection_pool = session.connection_pool().await;
2364    for member_id in member_ids {
2365        for connection_id in connection_pool.user_connection_ids(member_id) {
2366            session.peer.send(connection_id, update.clone())?;
2367        }
2368    }
2369
2370    Ok(())
2371}
2372
2373/// Invite someone to join a channel.
2374async fn invite_channel_member(
2375    request: proto::InviteChannelMember,
2376    response: Response<proto::InviteChannelMember>,
2377    session: Session,
2378) -> Result<()> {
2379    let db = session.db().await;
2380    let channel_id = ChannelId::from_proto(request.channel_id);
2381    let invitee_id = UserId::from_proto(request.user_id);
2382    let InviteMemberResult {
2383        channel,
2384        notifications,
2385    } = db
2386        .invite_channel_member(
2387            channel_id,
2388            invitee_id,
2389            session.user_id,
2390            request.role().into(),
2391        )
2392        .await?;
2393
2394    let update = proto::UpdateChannels {
2395        channel_invitations: vec![channel.to_proto()],
2396        ..Default::default()
2397    };
2398
2399    let connection_pool = session.connection_pool().await;
2400    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2401        session.peer.send(connection_id, update.clone())?;
2402    }
2403
2404    send_notifications(&*connection_pool, &session.peer, notifications);
2405
2406    response.send(proto::Ack {})?;
2407    Ok(())
2408}
2409
2410/// remove someone from a channel
2411async fn remove_channel_member(
2412    request: proto::RemoveChannelMember,
2413    response: Response<proto::RemoveChannelMember>,
2414    session: Session,
2415) -> Result<()> {
2416    let db = session.db().await;
2417    let channel_id = ChannelId::from_proto(request.channel_id);
2418    let member_id = UserId::from_proto(request.user_id);
2419
2420    let RemoveChannelMemberResult {
2421        membership_update,
2422        notification_id,
2423    } = db
2424        .remove_channel_member(channel_id, member_id, session.user_id)
2425        .await?;
2426
2427    let connection_pool = &session.connection_pool().await;
2428    notify_membership_updated(
2429        &connection_pool,
2430        membership_update,
2431        member_id,
2432        &session.peer,
2433    );
2434    for connection_id in connection_pool.user_connection_ids(member_id) {
2435        if let Some(notification_id) = notification_id {
2436            session
2437                .peer
2438                .send(
2439                    connection_id,
2440                    proto::DeleteNotification {
2441                        notification_id: notification_id.to_proto(),
2442                    },
2443                )
2444                .trace_err();
2445        }
2446    }
2447
2448    response.send(proto::Ack {})?;
2449    Ok(())
2450}
2451
2452/// Toggle the channel between public and private.
2453/// Care is taken to maintain the invariant that public channels only descend from public channels,
2454/// (though members-only channels can appear at any point in the hierarchy).
2455async fn set_channel_visibility(
2456    request: proto::SetChannelVisibility,
2457    response: Response<proto::SetChannelVisibility>,
2458    session: Session,
2459) -> Result<()> {
2460    let db = session.db().await;
2461    let channel_id = ChannelId::from_proto(request.channel_id);
2462    let visibility = request.visibility().into();
2463
2464    let (channel, channel_members) = db
2465        .set_channel_visibility(channel_id, visibility, session.user_id)
2466        .await?;
2467
2468    let connection_pool = session.connection_pool().await;
2469    for member in channel_members {
2470        let update = if member.role.can_see_channel(channel.visibility) {
2471            proto::UpdateChannels {
2472                channels: vec![channel.to_proto()],
2473                ..Default::default()
2474            }
2475        } else {
2476            proto::UpdateChannels {
2477                delete_channels: vec![channel.id.to_proto()],
2478                ..Default::default()
2479            }
2480        };
2481
2482        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2483            session.peer.send(connection_id, update.clone())?;
2484        }
2485    }
2486
2487    response.send(proto::Ack {})?;
2488    Ok(())
2489}
2490
2491/// Alter the role for a user in the channel.
2492async fn set_channel_member_role(
2493    request: proto::SetChannelMemberRole,
2494    response: Response<proto::SetChannelMemberRole>,
2495    session: Session,
2496) -> Result<()> {
2497    let db = session.db().await;
2498    let channel_id = ChannelId::from_proto(request.channel_id);
2499    let member_id = UserId::from_proto(request.user_id);
2500    let result = db
2501        .set_channel_member_role(
2502            channel_id,
2503            session.user_id,
2504            member_id,
2505            request.role().into(),
2506        )
2507        .await?;
2508
2509    match result {
2510        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2511            let connection_pool = session.connection_pool().await;
2512            notify_membership_updated(
2513                &connection_pool,
2514                membership_update,
2515                member_id,
2516                &session.peer,
2517            )
2518        }
2519        db::SetMemberRoleResult::InviteUpdated(channel) => {
2520            let update = proto::UpdateChannels {
2521                channel_invitations: vec![channel.to_proto()],
2522                ..Default::default()
2523            };
2524
2525            for connection_id in session
2526                .connection_pool()
2527                .await
2528                .user_connection_ids(member_id)
2529            {
2530                session.peer.send(connection_id, update.clone())?;
2531            }
2532        }
2533    }
2534
2535    response.send(proto::Ack {})?;
2536    Ok(())
2537}
2538
2539/// Change the name of a channel
2540async fn rename_channel(
2541    request: proto::RenameChannel,
2542    response: Response<proto::RenameChannel>,
2543    session: Session,
2544) -> Result<()> {
2545    let db = session.db().await;
2546    let channel_id = ChannelId::from_proto(request.channel_id);
2547    let (channel, channel_members) = db
2548        .rename_channel(channel_id, session.user_id, &request.name)
2549        .await?;
2550
2551    response.send(proto::RenameChannelResponse {
2552        channel: Some(channel.to_proto()),
2553    })?;
2554
2555    let connection_pool = session.connection_pool().await;
2556    for channel_member in channel_members {
2557        if !channel_member.role.can_see_channel(channel.visibility) {
2558            continue;
2559        }
2560        let update = proto::UpdateChannels {
2561            channels: vec![channel.to_proto()],
2562            ..Default::default()
2563        };
2564        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2565            session.peer.send(connection_id, update.clone())?;
2566        }
2567    }
2568
2569    Ok(())
2570}
2571
2572/// Move a channel to a new parent.
2573async fn move_channel(
2574    request: proto::MoveChannel,
2575    response: Response<proto::MoveChannel>,
2576    session: Session,
2577) -> Result<()> {
2578    let channel_id = ChannelId::from_proto(request.channel_id);
2579    let to = ChannelId::from_proto(request.to);
2580
2581    let (channels, channel_members) = session
2582        .db()
2583        .await
2584        .move_channel(channel_id, to, session.user_id)
2585        .await?;
2586
2587    let connection_pool = session.connection_pool().await;
2588    for member in channel_members {
2589        let channels = channels
2590            .iter()
2591            .filter_map(|channel| {
2592                if member.role.can_see_channel(channel.visibility) {
2593                    Some(channel.to_proto())
2594                } else {
2595                    None
2596                }
2597            })
2598            .collect::<Vec<_>>();
2599        if channels.is_empty() {
2600            continue;
2601        }
2602
2603        let update = proto::UpdateChannels {
2604            channels,
2605            ..Default::default()
2606        };
2607
2608        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2609            session.peer.send(connection_id, update.clone())?;
2610        }
2611    }
2612
2613    response.send(Ack {})?;
2614    Ok(())
2615}
2616
2617/// Get the list of channel members
2618async fn get_channel_members(
2619    request: proto::GetChannelMembers,
2620    response: Response<proto::GetChannelMembers>,
2621    session: Session,
2622) -> Result<()> {
2623    let db = session.db().await;
2624    let channel_id = ChannelId::from_proto(request.channel_id);
2625    let members = db
2626        .get_channel_participant_details(channel_id, session.user_id)
2627        .await?;
2628    response.send(proto::GetChannelMembersResponse { members })?;
2629    Ok(())
2630}
2631
2632/// Accept or decline a channel invitation.
2633async fn respond_to_channel_invite(
2634    request: proto::RespondToChannelInvite,
2635    response: Response<proto::RespondToChannelInvite>,
2636    session: Session,
2637) -> Result<()> {
2638    let db = session.db().await;
2639    let channel_id = ChannelId::from_proto(request.channel_id);
2640    let RespondToChannelInvite {
2641        membership_update,
2642        notifications,
2643    } = db
2644        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2645        .await?;
2646
2647    let connection_pool = session.connection_pool().await;
2648    if let Some(membership_update) = membership_update {
2649        notify_membership_updated(
2650            &connection_pool,
2651            membership_update,
2652            session.user_id,
2653            &session.peer,
2654        );
2655    } else {
2656        let update = proto::UpdateChannels {
2657            remove_channel_invitations: vec![channel_id.to_proto()],
2658            ..Default::default()
2659        };
2660
2661        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2662            session.peer.send(connection_id, update.clone())?;
2663        }
2664    };
2665
2666    send_notifications(&*connection_pool, &session.peer, notifications);
2667
2668    response.send(proto::Ack {})?;
2669
2670    Ok(())
2671}
2672
2673/// Join the channels' room
2674async fn join_channel(
2675    request: proto::JoinChannel,
2676    response: Response<proto::JoinChannel>,
2677    session: Session,
2678) -> Result<()> {
2679    let channel_id = ChannelId::from_proto(request.channel_id);
2680    join_channel_internal(channel_id, Box::new(response), session).await
2681}
2682
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either request.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` request be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s inherent `send`
        // (inherent items win over trait items), so this does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` request for a channel room be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s inherent `send`
        // (inherent items win over trait items), so this does not recurse.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2696
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Leave any room this session is currently in (`leave_room_for_session`).
/// 2. Join the channel's room in the database. This also yields an optional
///    membership update and the caller's `ChannelRole` in the channel.
/// 3. If a LiveKit client is configured, mint a token. `ChannelRole::Guest` gets a
///    `guest_token` with `can_publish: false`; every other role gets a `room_token`
///    with `can_publish: true`. A token error goes through `trace_err()?`, leaving
///    `live_kit_connection_info` as `None`.
/// 4. Reply to the requester, push any membership update to the caller, and
///    broadcast the updated room to its participants.
/// 5. Send the channel's new participant list to all channel members, then refresh
///    the caller's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards. Presumably this
    // releases them before `channel_updated` and `update_user_contacts` below
    // re-acquire them.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2777
2778/// Start editing the channel notes
2779async fn join_channel_buffer(
2780    request: proto::JoinChannelBuffer,
2781    response: Response<proto::JoinChannelBuffer>,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785    let channel_id = ChannelId::from_proto(request.channel_id);
2786
2787    let open_response = db
2788        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2789        .await?;
2790
2791    let collaborators = open_response.collaborators.clone();
2792    response.send(open_response)?;
2793
2794    let update = UpdateChannelBufferCollaborators {
2795        channel_id: channel_id.to_proto(),
2796        collaborators: collaborators.clone(),
2797    };
2798    channel_buffer_updated(
2799        session.connection_id,
2800        collaborators
2801            .iter()
2802            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2803        &update,
2804        &session.peer,
2805    );
2806
2807    Ok(())
2808}
2809
2810/// Edit the channel notes
2811async fn update_channel_buffer(
2812    request: proto::UpdateChannelBuffer,
2813    session: Session,
2814) -> Result<()> {
2815    let db = session.db().await;
2816    let channel_id = ChannelId::from_proto(request.channel_id);
2817
2818    let (collaborators, non_collaborators, epoch, version) = db
2819        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2820        .await?;
2821
2822    channel_buffer_updated(
2823        session.connection_id,
2824        collaborators,
2825        &proto::UpdateChannelBuffer {
2826            channel_id: channel_id.to_proto(),
2827            operations: request.operations,
2828        },
2829        &session.peer,
2830    );
2831
2832    let pool = &*session.connection_pool().await;
2833
2834    broadcast(
2835        None,
2836        non_collaborators
2837            .iter()
2838            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2839        |peer_id| {
2840            session.peer.send(
2841                peer_id.into(),
2842                proto::UpdateChannels {
2843                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2844                        channel_id: channel_id.to_proto(),
2845                        epoch: epoch as u64,
2846                        version: version.clone(),
2847                    }],
2848                    ..Default::default()
2849                },
2850            )
2851        },
2852    );
2853
2854    Ok(())
2855}
2856
2857/// Rejoin the channel notes after a connection blip
2858async fn rejoin_channel_buffers(
2859    request: proto::RejoinChannelBuffers,
2860    response: Response<proto::RejoinChannelBuffers>,
2861    session: Session,
2862) -> Result<()> {
2863    let db = session.db().await;
2864    let buffers = db
2865        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2866        .await?;
2867
2868    for rejoined_buffer in &buffers {
2869        let collaborators_to_notify = rejoined_buffer
2870            .buffer
2871            .collaborators
2872            .iter()
2873            .filter_map(|c| Some(c.peer_id?.into()));
2874        channel_buffer_updated(
2875            session.connection_id,
2876            collaborators_to_notify,
2877            &proto::UpdateChannelBufferCollaborators {
2878                channel_id: rejoined_buffer.buffer.channel_id,
2879                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2880            },
2881            &session.peer,
2882        );
2883    }
2884
2885    response.send(proto::RejoinChannelBuffersResponse {
2886        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2887    })?;
2888
2889    Ok(())
2890}
2891
2892/// Stop editing the channel notes
2893async fn leave_channel_buffer(
2894    request: proto::LeaveChannelBuffer,
2895    response: Response<proto::LeaveChannelBuffer>,
2896    session: Session,
2897) -> Result<()> {
2898    let db = session.db().await;
2899    let channel_id = ChannelId::from_proto(request.channel_id);
2900
2901    let left_buffer = db
2902        .leave_channel_buffer(channel_id, session.connection_id)
2903        .await?;
2904
2905    response.send(Ack {})?;
2906
2907    channel_buffer_updated(
2908        session.connection_id,
2909        left_buffer.connections,
2910        &proto::UpdateChannelBufferCollaborators {
2911            channel_id: channel_id.to_proto(),
2912            collaborators: left_buffer.collaborators,
2913        },
2914        &session.peer,
2915    );
2916
2917    Ok(())
2918}
2919
2920fn channel_buffer_updated<T: EnvelopedMessage>(
2921    sender_id: ConnectionId,
2922    collaborators: impl IntoIterator<Item = ConnectionId>,
2923    message: &T,
2924    peer: &Peer,
2925) {
2926    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2927        peer.send(peer_id.into(), message.clone())
2928    });
2929}
2930
2931fn send_notifications(
2932    connection_pool: &ConnectionPool,
2933    peer: &Peer,
2934    notifications: db::NotificationBatch,
2935) {
2936    for (user_id, notification) in notifications {
2937        for connection_id in connection_pool.user_connection_ids(user_id) {
2938            if let Err(error) = peer.send(
2939                connection_id,
2940                proto::AddNotification {
2941                    notification: Some(notification.clone()),
2942                },
2943            ) {
2944                tracing::error!(
2945                    "failed to send notification to {:?} {}",
2946                    connection_id,
2947                    error
2948                );
2949            }
2950        }
2951    }
2952}
2953
/// Post a chat message to a channel.
///
/// Steps, in order:
/// 1. Trim the body; reject it if it exceeds `MAX_MESSAGE_LEN` or is empty.
/// 2. Require a client-supplied nonce.
/// 3. Persist the message via the database.
/// 4. Broadcast `ChannelMessageSent` to the chat participants returned by the
///    database. This session's connection is passed as the sender.
/// 5. Reply with the stored message.
/// 6. Send every channel member's connections the new latest message id, then
///    deliver any notifications (e.g. for mentions) the database produced.
///
/// # Errors
/// Returns an error for an over-long, blank, or nonce-less message, or if the
/// database write fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // NOTE(review): `len()` counts UTF-8 bytes, not characters; confirm that
    // MAX_MESSAGE_LEN is meant as a byte limit.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged alongside the
    // trimmed body. If mention ranges are offsets into the original body, leading
    // whitespace will shift them. Verify against the mention format.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        channel_members,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    // Live chat participants; the sender's own connection is passed as the sender.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Every channel member (no excluded sender) learns the channel's newest message id.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        channel_members
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    latest_channel_message_ids: vec![proto::ChannelMessageId {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3042
3043/// Delete a channel message
3044async fn remove_channel_message(
3045    request: proto::RemoveChannelMessage,
3046    response: Response<proto::RemoveChannelMessage>,
3047    session: Session,
3048) -> Result<()> {
3049    let channel_id = ChannelId::from_proto(request.channel_id);
3050    let message_id = MessageId::from_proto(request.message_id);
3051    let connection_ids = session
3052        .db()
3053        .await
3054        .remove_channel_message(channel_id, message_id, session.user_id)
3055        .await?;
3056    broadcast(Some(session.connection_id), connection_ids, |connection| {
3057        session.peer.send(connection, request.clone())
3058    });
3059    response.send(proto::Ack {})?;
3060    Ok(())
3061}
3062
3063/// Mark a channel message as read
3064async fn acknowledge_channel_message(
3065    request: proto::AckChannelMessage,
3066    session: Session,
3067) -> Result<()> {
3068    let channel_id = ChannelId::from_proto(request.channel_id);
3069    let message_id = MessageId::from_proto(request.message_id);
3070    let notifications = session
3071        .db()
3072        .await
3073        .observe_channel_message(channel_id, session.user_id, message_id)
3074        .await?;
3075    send_notifications(
3076        &*session.connection_pool().await,
3077        &session.peer,
3078        notifications,
3079    );
3080    Ok(())
3081}
3082
3083/// Mark a buffer version as synced
3084async fn acknowledge_buffer_version(
3085    request: proto::AckBufferOperation,
3086    session: Session,
3087) -> Result<()> {
3088    let buffer_id = BufferId::from_proto(request.buffer_id);
3089    session
3090        .db()
3091        .await
3092        .observe_buffer_version(
3093            buffer_id,
3094            session.user_id,
3095            request.epoch as i32,
3096            &request.version,
3097        )
3098        .await?;
3099    Ok(())
3100}
3101
3102/// Start receiving chat updates for a channel
3103async fn join_channel_chat(
3104    request: proto::JoinChannelChat,
3105    response: Response<proto::JoinChannelChat>,
3106    session: Session,
3107) -> Result<()> {
3108    let channel_id = ChannelId::from_proto(request.channel_id);
3109
3110    let db = session.db().await;
3111    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3112        .await?;
3113    let messages = db
3114        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3115        .await?;
3116    response.send(proto::JoinChannelChatResponse {
3117        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3118        messages,
3119    })?;
3120    Ok(())
3121}
3122
3123/// Stop receiving chat updates for a channel
3124async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3125    let channel_id = ChannelId::from_proto(request.channel_id);
3126    session
3127        .db()
3128        .await
3129        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3130        .await?;
3131    Ok(())
3132}
3133
3134/// Retrieve the chat history for a channel
3135async fn get_channel_messages(
3136    request: proto::GetChannelMessages,
3137    response: Response<proto::GetChannelMessages>,
3138    session: Session,
3139) -> Result<()> {
3140    let channel_id = ChannelId::from_proto(request.channel_id);
3141    let messages = session
3142        .db()
3143        .await
3144        .get_channel_messages(
3145            channel_id,
3146            session.user_id,
3147            MESSAGE_COUNT_PER_PAGE,
3148            Some(MessageId::from_proto(request.before_message_id)),
3149        )
3150        .await?;
3151    response.send(proto::GetChannelMessagesResponse {
3152        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3153        messages,
3154    })?;
3155    Ok(())
3156}
3157
3158/// Retrieve specific chat messages
3159async fn get_channel_messages_by_id(
3160    request: proto::GetChannelMessagesById,
3161    response: Response<proto::GetChannelMessagesById>,
3162    session: Session,
3163) -> Result<()> {
3164    let message_ids = request
3165        .message_ids
3166        .iter()
3167        .map(|id| MessageId::from_proto(*id))
3168        .collect::<Vec<_>>();
3169    let messages = session
3170        .db()
3171        .await
3172        .get_channel_messages_by_id(session.user_id, &message_ids)
3173        .await?;
3174    response.send(proto::GetChannelMessagesResponse {
3175        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3176        messages,
3177    })?;
3178    Ok(())
3179}
3180
3181/// Retrieve the current users notifications
3182async fn get_notifications(
3183    request: proto::GetNotifications,
3184    response: Response<proto::GetNotifications>,
3185    session: Session,
3186) -> Result<()> {
3187    let notifications = session
3188        .db()
3189        .await
3190        .get_notifications(
3191            session.user_id,
3192            NOTIFICATION_COUNT_PER_PAGE,
3193            request
3194                .before_id
3195                .map(|id| db::NotificationId::from_proto(id)),
3196        )
3197        .await?;
3198    response.send(proto::GetNotificationsResponse {
3199        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3200        notifications,
3201    })?;
3202    Ok(())
3203}
3204
3205/// Mark notifications as read
3206async fn mark_notification_as_read(
3207    request: proto::MarkNotificationRead,
3208    response: Response<proto::MarkNotificationRead>,
3209    session: Session,
3210) -> Result<()> {
3211    let database = &session.db().await;
3212    let notifications = database
3213        .mark_notification_as_read_by_id(
3214            session.user_id,
3215            NotificationId::from_proto(request.notification_id),
3216        )
3217        .await?;
3218    send_notifications(
3219        &*session.connection_pool().await,
3220        &session.peer,
3221        notifications,
3222    );
3223    response.send(proto::Ack {})?;
3224    Ok(())
3225}
3226
3227/// Get the current users information
3228async fn get_private_user_info(
3229    _request: proto::GetPrivateUserInfo,
3230    response: Response<proto::GetPrivateUserInfo>,
3231    session: Session,
3232) -> Result<()> {
3233    let db = session.db().await;
3234
3235    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3236    let user = db
3237        .get_user_by_id(session.user_id)
3238        .await?
3239        .ok_or_else(|| anyhow!("user not found"))?;
3240    let flags = db.get_user_flags(session.user_id).await?;
3241
3242    response.send(proto::GetPrivateUserInfoResponse {
3243        metrics_id,
3244        staff: user.admin,
3245        flags,
3246    })?;
3247    Ok(())
3248}
3249
3250fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3251    match message {
3252        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3253        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3254        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3255        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3256        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3257            code: frame.code.into(),
3258            reason: frame.reason,
3259        })),
3260    }
3261}
3262
3263fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3264    match message {
3265        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3266        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3267        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3268        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3269        AxumMessage::Close(frame) => {
3270            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3271                code: frame.code.into(),
3272                reason: frame.reason,
3273            }))
3274        }
3275    }
3276}
3277
3278fn notify_membership_updated(
3279    connection_pool: &ConnectionPool,
3280    result: MembershipUpdated,
3281    user_id: UserId,
3282    peer: &Peer,
3283) {
3284    let user_channels_update = proto::UpdateUserChannels {
3285        channel_memberships: result
3286            .new_channels
3287            .channel_memberships
3288            .iter()
3289            .map(|cm| proto::ChannelMembership {
3290                channel_id: cm.channel_id.to_proto(),
3291                role: cm.role.into(),
3292            })
3293            .collect(),
3294        ..Default::default()
3295    };
3296    let mut update = build_channels_update(result.new_channels, vec![]);
3297    update.delete_channels = result
3298        .removed_channels
3299        .into_iter()
3300        .map(|id| id.to_proto())
3301        .collect();
3302    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3303
3304    for connection_id in connection_pool.user_connection_ids(user_id) {
3305        peer.send(connection_id, user_channels_update.clone())
3306            .trace_err();
3307        peer.send(connection_id, update.clone()).trace_err();
3308    }
3309}
3310
3311fn build_update_user_channels(
3312    memberships: &Vec<db::channel_member::Model>,
3313) -> proto::UpdateUserChannels {
3314    proto::UpdateUserChannels {
3315        channel_memberships: memberships
3316            .iter()
3317            .map(|m| proto::ChannelMembership {
3318                channel_id: m.channel_id.to_proto(),
3319                role: m.role.into(),
3320            })
3321            .collect(),
3322        ..Default::default()
3323    }
3324}
3325
3326fn build_channels_update(
3327    channels: ChannelsForUser,
3328    channel_invites: Vec<db::Channel>,
3329) -> proto::UpdateChannels {
3330    let mut update = proto::UpdateChannels::default();
3331
3332    for channel in channels.channels {
3333        update.channels.push(channel.to_proto());
3334    }
3335
3336    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3337    update.latest_channel_message_ids = channels.latest_channel_messages;
3338
3339    for (channel_id, participants) in channels.channel_participants {
3340        update
3341            .channel_participants
3342            .push(proto::ChannelParticipants {
3343                channel_id: channel_id.to_proto(),
3344                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3345            });
3346    }
3347
3348    for channel in channel_invites {
3349        update.channel_invitations.push(channel.to_proto());
3350    }
3351
3352    update
3353}
3354
3355fn build_initial_contacts_update(
3356    contacts: Vec<db::Contact>,
3357    pool: &ConnectionPool,
3358) -> proto::UpdateContacts {
3359    let mut update = proto::UpdateContacts::default();
3360
3361    for contact in contacts {
3362        match contact {
3363            db::Contact::Accepted { user_id, busy } => {
3364                update.contacts.push(contact_for_user(user_id, busy, &pool));
3365            }
3366            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3367            db::Contact::Incoming { user_id } => {
3368                update
3369                    .incoming_requests
3370                    .push(proto::IncomingContactRequest {
3371                        requester_id: user_id.to_proto(),
3372                    })
3373            }
3374        }
3375    }
3376
3377    update
3378}
3379
3380fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3381    proto::Contact {
3382        user_id: user_id.to_proto(),
3383        online: pool.is_user_online(user_id),
3384        busy,
3385    }
3386}
3387
3388fn room_updated(room: &proto::Room, peer: &Peer) {
3389    broadcast(
3390        None,
3391        room.participants
3392            .iter()
3393            .filter_map(|participant| Some(participant.peer_id?.into())),
3394        |peer_id| {
3395            peer.send(
3396                peer_id.into(),
3397                proto::RoomUpdated {
3398                    room: Some(room.clone()),
3399                },
3400            )
3401        },
3402    );
3403}
3404
3405fn channel_updated(
3406    channel_id: ChannelId,
3407    room: &proto::Room,
3408    channel_members: &[UserId],
3409    peer: &Peer,
3410    pool: &ConnectionPool,
3411) {
3412    let participants = room
3413        .participants
3414        .iter()
3415        .map(|p| p.user_id)
3416        .collect::<Vec<_>>();
3417
3418    broadcast(
3419        None,
3420        channel_members
3421            .iter()
3422            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3423        |peer_id| {
3424            peer.send(
3425                peer_id.into(),
3426                proto::UpdateChannels {
3427                    channel_participants: vec![proto::ChannelParticipants {
3428                        channel_id: channel_id.to_proto(),
3429                        participant_user_ids: participants.clone(),
3430                    }],
3431                    ..Default::default()
3432                },
3433            )
3434        },
3435    );
3436}
3437
3438async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3439    let db = session.db().await;
3440
3441    let contacts = db.get_contacts(user_id).await?;
3442    let busy = db.is_user_busy(user_id).await?;
3443
3444    let pool = session.connection_pool().await;
3445    let updated_contact = contact_for_user(user_id, busy, &pool);
3446    for contact in contacts {
3447        if let db::Contact::Accepted {
3448            user_id: contact_user_id,
3449            ..
3450        } = contact
3451        {
3452            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3453                session
3454                    .peer
3455                    .send(
3456                        contact_conn_id,
3457                        proto::UpdateContacts {
3458                            contacts: vec![updated_contact.clone()],
3459                            remove_contacts: Default::default(),
3460                            incoming_requests: Default::default(),
3461                            remove_incoming_requests: Default::default(),
3462                            outgoing_requests: Default::default(),
3463                            remove_outgoing_requests: Default::default(),
3464                        },
3465                    )
3466                    .trace_err();
3467            }
3468        }
3469    }
3470    Ok(())
3471}
3472
3473async fn leave_room_for_session(session: &Session) -> Result<()> {
3474    let mut contacts_to_update = HashSet::default();
3475
3476    let room_id;
3477    let canceled_calls_to_user_ids;
3478    let live_kit_room;
3479    let delete_live_kit_room;
3480    let room;
3481    let channel_members;
3482    let channel_id;
3483
3484    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3485        contacts_to_update.insert(session.user_id);
3486
3487        for project in left_room.left_projects.values() {
3488            project_left(project, session);
3489        }
3490
3491        room_id = RoomId::from_proto(left_room.room.id);
3492        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3493        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3494        delete_live_kit_room = left_room.deleted;
3495        room = mem::take(&mut left_room.room);
3496        channel_members = mem::take(&mut left_room.channel_members);
3497        channel_id = left_room.channel_id;
3498
3499        room_updated(&room, &session.peer);
3500    } else {
3501        return Ok(());
3502    }
3503
3504    if let Some(channel_id) = channel_id {
3505        channel_updated(
3506            channel_id,
3507            &room,
3508            &channel_members,
3509            &session.peer,
3510            &*session.connection_pool().await,
3511        );
3512    }
3513
3514    {
3515        let pool = session.connection_pool().await;
3516        for canceled_user_id in canceled_calls_to_user_ids {
3517            for connection_id in pool.user_connection_ids(canceled_user_id) {
3518                session
3519                    .peer
3520                    .send(
3521                        connection_id,
3522                        proto::CallCanceled {
3523                            room_id: room_id.to_proto(),
3524                        },
3525                    )
3526                    .trace_err();
3527            }
3528            contacts_to_update.insert(canceled_user_id);
3529        }
3530    }
3531
3532    for contact_user_id in contacts_to_update {
3533        update_user_contacts(contact_user_id, &session).await?;
3534    }
3535
3536    if let Some(live_kit) = session.live_kit_client.as_ref() {
3537        live_kit
3538            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3539            .await
3540            .trace_err();
3541
3542        if delete_live_kit_room {
3543            live_kit.delete_room(live_kit_room).await.trace_err();
3544        }
3545    }
3546
3547    Ok(())
3548}
3549
3550async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3551    let left_channel_buffers = session
3552        .db()
3553        .await
3554        .leave_channel_buffers(session.connection_id)
3555        .await?;
3556
3557    for left_buffer in left_channel_buffers {
3558        channel_buffer_updated(
3559            session.connection_id,
3560            left_buffer.connections,
3561            &proto::UpdateChannelBufferCollaborators {
3562                channel_id: left_buffer.channel_id.to_proto(),
3563                collaborators: left_buffer.collaborators,
3564            },
3565            &session.peer,
3566        );
3567    }
3568
3569    Ok(())
3570}
3571
3572fn project_left(project: &db::LeftProject, session: &Session) {
3573    for connection_id in &project.connection_ids {
3574        if project.host_user_id == session.user_id {
3575            session
3576                .peer
3577                .send(
3578                    *connection_id,
3579                    proto::UnshareProject {
3580                        project_id: project.id.to_proto(),
3581                    },
3582                )
3583                .trace_err();
3584        } else {
3585            session
3586                .peer
3587                .send(
3588                    *connection_id,
3589                    proto::RemoveProjectCollaborator {
3590                        project_id: project.id.to_proto(),
3591                        peer_id: Some(session.connection_id.into()),
3592                    },
3593                )
3594                .trace_err();
3595        }
3596    }
3597}
3598
3599pub trait ResultExt {
3600    type Ok;
3601
3602    fn trace_err(self) -> Option<Self::Ok>;
3603}
3604
3605impl<T, E> ResultExt for Result<T, E>
3606where
3607    E: std::fmt::Debug,
3608{
3609    type Ok = T;
3610
3611    fn trace_err(self) -> Option<T> {
3612        match self {
3613            Ok(value) => Some(value),
3614            Err(error) => {
3615                tracing::error!("{:?}", error);
3616                None
3617            }
3618        }
3619    }
3620}