rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RenameChannelResult, RespondToChannelInvite, RoomId, ServerId,
   9        SetChannelVisibilityResult, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, Result,
  13};
  14use anyhow::anyhow;
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::ConnectionPool;
  33use futures::{
  34    channel::oneshot,
  35    future::{self, BoxFuture},
  36    stream::FuturesUnordered,
  37    FutureExt, SinkExt, StreamExt, TryStreamExt,
  38};
  39use lazy_static::lazy_static;
  40use prometheus::{register_int_gauge, IntGauge};
  41use rpc::{
  42    proto::{
  43        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  44        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  45    },
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47};
  48use serde::{Serialize, Serializer};
  49use std::{
  50    any::TypeId,
  51    fmt,
  52    future::Future,
  53    marker::PhantomData,
  54    mem,
  55    net::SocketAddr,
  56    ops::{Deref, DerefMut},
  57    rc::Rc,
  58    sync::{
  59        atomic::{AtomicBool, Ordering::SeqCst},
  60        Arc,
  61    },
  62    time::{Duration, Instant},
  63};
  64use time::OffsetDateTime;
  65use tokio::sync::{watch, Semaphore};
  66use tower::ServiceBuilder;
  67use tracing::{field, info_span, instrument, Instrument};
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
lazy_static! {
    // Prometheus gauge registered under the name "connections".
    // NOTE(review): it is not updated in the visible part of this file; confirm
    // where it is incremented and decremented.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge registered under the name "shared_projects".
    // Registration happens lazily on first access, and `unwrap` panics there if
    // `register_int_gauge!` returns an error.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  85
/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the message type it handles.
///
/// It receives the boxed envelope plus the sender's `Session` and returns a boxed
/// `'static` future. `Server::add_handler` builds these by downcasting the
/// envelope back to the concrete `TypedEnvelope<M>`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
/// One-shot reply handle given to handlers registered with
/// `Server::add_request_handler`; `send` consumes it.
struct Response<R> {
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Shared with the request dispatcher, which reads it after the handler
    /// finishes to detect handlers that never replied.
    responded: Arc<AtomicBool>,
}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
/// Per-connection context handed to every message handler.
///
/// `handle_connection` clones it for each incoming message, so the shared
/// handles are reference-counted.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    /// The user this connection belongs to.
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    /// `handle_connection` creates a fresh mutex per connection, so it only
    /// serializes database access among that connection's handlers.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
 114
 115impl Session {
 116    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        let guard = self.db.lock().await;
 120        #[cfg(test)]
 121        tokio::task::yield_now().await;
 122        guard
 123    }
 124
 125    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 126        #[cfg(test)]
 127        tokio::task::yield_now().await;
 128        let guard = self.connection_pool.lock();
 129        ConnectionPoolGuard {
 130            guard,
 131            _not_send: PhantomData,
 132        }
 133    }
 134}
 135
 136impl fmt::Debug for Session {
 137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 138        f.debug_struct("Session")
 139            .field("user_id", &self.user_id)
 140            .field("connection_id", &self.connection_id)
 141            .finish()
 142    }
 143}
 144
/// Newtype around the shared `Database` that derefs to `Database`.
/// `handle_connection` wraps one per connection in `Session::db`.
struct DbHandle(Arc<Database>);
 146
 147impl Deref for DbHandle {
 148    type Target = Database;
 149
 150    fn deref(&self) -> &Self::Target {
 151        self.0.as_ref()
 152    }
 153}
 154
/// The collaboration RPC server. It owns the `Peer`, the shared connection pool,
/// and the table of registered message handlers.
pub struct Server {
    /// In the visible code, this is rewritten only by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Type-erased handlers keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Fired by `teardown`. Each connection loop in `handle_connection`
    /// subscribes to it and returns when it changes.
    teardown: watch::Sender<()>,
}
 164
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. A future holding it
/// across an `.await` is therefore not `Send` and cannot be used as a message
/// handler future, because `add_handler` requires `Send` futures.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 169
/// Serializable view of the server's peer and connection pool.
///
/// NOTE(review): this holds the connection-pool mutex guard for as long as the
/// snapshot lives, so keep snapshots short-lived.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through its deref target, presumably because the guard itself
    /// does not implement `Serialize`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 176
 177pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 178where
 179    S: Serializer,
 180    T: Deref<Target = U>,
 181    U: Serialize,
 182{
 183    Serialize::serialize(value.deref(), serializer)
 184}
 185
 186impl Server {
    /// Creates a server with the given id and registers every RPC handler.
    ///
    /// `add_request_handler` is used for messages that expect a reply; the handler
    /// must send one through its `Response`. `add_message_handler` is used for
    /// messages without a reply. Registering two handlers for the same message
    /// type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if the inner server id does
            // not fit in a u32; confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection creates its receiver via
            // `subscribe` in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            // Connectivity, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree / language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another peer. The read-only and
            // mutating helpers are defined outside this view; presumably they
            // differ in the permission they require.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffer synchronization and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel membership and channel buffers.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            // Channel chat.
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Miscellaneous.
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 295
    /// Schedules cleanup of stale resources recorded for this server's environment,
    /// then returns `Ok(())` immediately.
    ///
    /// The detached task first waits `CLEANUP_TIMEOUT`. It then asks the database
    /// for stale room and channel-buffer ids and handles each one:
    /// - Channel buffer: clears stale collaborators and sends
    ///   `UpdateChannelBufferCollaborators` to the remaining connections.
    /// - Room: clears stale participants and broadcasts room/channel updates. It
    ///   also sends `CallCanceled` to users whose calls were canceled, pushes
    ///   contact updates, and deletes the LiveKit room once it has no
    ///   participants.
    ///
    /// Finally it deletes stale servers from the database.
    ///
    /// Every fallible step goes through `trace_err()`, which yields an `Option`.
    /// A failure therefore skips only that step: for example, if fetching the
    /// stale ids fails, `delete_stale_servers` still runs.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining collaborator about the new roster.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults mean that, if clearing the room fails, no
                        // calls are canceled, no contacts are updated, and no LiveKit
                        // room is deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is held only for these sends and is released
                        // before the database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database first; the synchronous pool lock is
                            // taken only after both awaits complete.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Push the refreshed contact entry to every connection
                                // of each accepted contact.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if retrieving the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 447
 448    pub fn teardown(&self) {
 449        self.peer.teardown();
 450        self.connection_pool.lock().reset();
 451        let _ = self.teardown.send(());
 452    }
 453
 454    #[cfg(test)]
 455    pub fn reset(&self, id: ServerId) {
 456        self.teardown();
 457        *self.id.lock() = id;
 458        self.peer.reset(id.0 as u32);
 459    }
 460
 461    #[cfg(test)]
 462    pub fn id(&self) -> ServerId {
 463        *self.id.lock()
 464    }
 465
 466    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 467    where
 468        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 469        Fut: 'static + Send + Future<Output = Result<()>>,
 470        M: EnvelopedMessage,
 471    {
 472        let prev_handler = self.handlers.insert(
 473            TypeId::of::<M>(),
 474            Box::new(move |envelope, session| {
 475                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 476                let span = info_span!(
 477                    "handle message",
 478                    payload_type = envelope.payload_type_name()
 479                );
 480                span.in_scope(|| {
 481                    tracing::info!(
 482                        payload_type = envelope.payload_type_name(),
 483                        "message received"
 484                    );
 485                });
 486                let start_time = Instant::now();
 487                let future = (handler)(*envelope, session);
 488                async move {
 489                    let result = future.await;
 490                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 491                    match result {
 492                        Err(error) => {
 493                            tracing::error!(%error, ?duration_ms, "error handling message")
 494                        }
 495                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 496                    }
 497                }
 498                .instrument(span)
 499                .boxed()
 500            }),
 501        );
 502        if prev_handler.is_some() {
 503            panic!("registered a handler for the same message twice");
 504        }
 505        self
 506    }
 507
 508    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 511        Fut: 'static + Send + Future<Output = Result<()>>,
 512        M: EnvelopedMessage,
 513    {
 514        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 515        self
 516    }
 517
 518    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 519    where
 520        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 521        Fut: Send + Future<Output = Result<()>>,
 522        M: RequestMessage,
 523    {
 524        let handler = Arc::new(handler);
 525        self.add_handler(move |envelope, session| {
 526            let receipt = envelope.receipt();
 527            let handler = handler.clone();
 528            async move {
 529                let peer = session.peer.clone();
 530                let responded = Arc::new(AtomicBool::default());
 531                let response = Response {
 532                    peer: peer.clone(),
 533                    responded: responded.clone(),
 534                    receipt,
 535                };
 536                match (handler)(envelope.payload, response, session).await {
 537                    Ok(()) => {
 538                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 539                            Ok(())
 540                        } else {
 541                            Err(anyhow!("handler did not send a response"))?
 542                        }
 543                    }
 544                    Err(error) => {
 545                        let proto_err = match &error {
 546                            Error::Internal(err) => err.to_proto(),
 547                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 548                        };
 549                        peer.respond_with_error(receipt, proto_err)?;
 550                        Err(error)
 551                    }
 552                }
 553            }
 554        })
 555    }
 556
    /// Returns the future that drives one client connection to completion.
    ///
    /// Setup, in order:
    /// 1. Registers the connection with the peer and sends `Hello`.
    /// 2. Hands the new `ConnectionId` to `send_connection_id`, if provided.
    /// 3. On the user's first connection, sends `ShowContacts` and records it.
    /// 4. Adds the connection to the pool and pushes the initial contacts and
    ///    channels, then any pending incoming call.
    /// 5. Refreshes the user's contacts.
    ///
    /// Any error during setup is returned to the caller.
    ///
    /// After setup, incoming messages are dispatched to the registered handlers
    /// until the client closes the connection, the I/O future finishes, or the
    /// server is torn down. In the first two cases the user is signed out via
    /// `connection_lost` and `Ok(())` is returned. On teardown the future returns
    /// `Ok(())` immediately without calling `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed before the future starts, so a `teardown()` issued after this
        // call is observed by the loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // The connection joins the pool before the initial contacts update is
            // built, and both happen under a single pool lock.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            // A fresh DB mutex per connection: it serializes DB access only among
            // this connection's handlers.
            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256. A permit
            // is taken before reading each message and released when its handler
            // finishes; a message with no handler releases its permit immediately.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased select: branches are checked in order (teardown, I/O,
                // finished foreground handlers, then new messages).
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is
                                // released only when it completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers that are still running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 696
 697    pub async fn invite_code_redeemed(
 698        self: &Arc<Self>,
 699        inviter_id: UserId,
 700        invitee_id: UserId,
 701    ) -> Result<()> {
 702        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 703            if let Some(code) = &user.invite_code {
 704                let pool = self.connection_pool.lock();
 705                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 706                for connection_id in pool.user_connection_ids(inviter_id) {
 707                    self.peer.send(
 708                        connection_id,
 709                        proto::UpdateContacts {
 710                            contacts: vec![invitee_contact.clone()],
 711                            ..Default::default()
 712                        },
 713                    )?;
 714                    self.peer.send(
 715                        connection_id,
 716                        proto::UpdateInviteInfo {
 717                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 718                            count: user.invite_count as u32,
 719                        },
 720                    )?;
 721                }
 722            }
 723        }
 724        Ok(())
 725    }
 726
 727    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 728        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 729            if let Some(invite_code) = &user.invite_code {
 730                let pool = self.connection_pool.lock();
 731                for connection_id in pool.user_connection_ids(user_id) {
 732                    self.peer.send(
 733                        connection_id,
 734                        proto::UpdateInviteInfo {
 735                            url: format!(
 736                                "{}{}",
 737                                self.app_state.config.invite_link_prefix, invite_code
 738                            ),
 739                            count: user.invite_count as u32,
 740                        },
 741                    )?;
 742                }
 743            }
 744        }
 745        Ok(())
 746    }
 747
 748    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 749        ServerSnapshot {
 750            connection_pool: ConnectionPoolGuard {
 751                guard: self.connection_pool.lock(),
 752                _not_send: PhantomData,
 753            },
 754            peer: &self.peer,
 755        }
 756    }
 757}
 758
 759impl<'a> Deref for ConnectionPoolGuard<'a> {
 760    type Target = ConnectionPool;
 761
 762    fn deref(&self) -> &Self::Target {
 763        &*self.guard
 764    }
 765}
 766
 767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 768    fn deref_mut(&mut self) -> &mut Self::Target {
 769        &mut *self.guard
 770    }
 771}
 772
/// In test builds, checks the connection pool's invariants whenever a
/// `ConnectionPoolGuard` is dropped.
///
/// `Drop::drop` runs before the guard's fields are dropped, so the check still
/// sees the pool while the lock is held.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; non-test builds do nothing here.
        #[cfg(test)]
        self.check_invariants();
    }
}
 779
 780fn broadcast<F>(
 781    sender_id: Option<ConnectionId>,
 782    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 783    mut f: F,
 784) where
 785    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 786{
 787    for receiver_id in receiver_ids {
 788        if Some(receiver_id) != sender_id {
 789            if let Err(error) = f(receiver_id) {
 790                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 791            }
 792        }
 793    }
 794}
 795
lazy_static! {
    /// Name of the request header carrying the client's RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed value of the `x-zed-protocol-version` header, a decimal `u32`.
/// `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
 801
 802impl Header for ProtocolVersion {
 803    fn name() -> &'static HeaderName {
 804        &ZED_PROTOCOL_VERSION
 805    }
 806
 807    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 808    where
 809        Self: Sized,
 810        I: Iterator<Item = &'i axum::http::HeaderValue>,
 811    {
 812        let version = values
 813            .next()
 814            .ok_or_else(axum::headers::Error::invalid)?
 815            .to_str()
 816            .map_err(|_| axum::headers::Error::invalid())?
 817            .parse()
 818            .map_err(|_| axum::headers::Error::invalid())?;
 819        Ok(Self(version))
 820    }
 821
 822    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 823        values.extend([self.0.to_string().parse().unwrap()]);
 824    }
 825}
 826
 827pub fn routes(server: Arc<Server>) -> Router<Body> {
 828    Router::new()
 829        .route("/rpc", get(handle_websocket_request))
 830        .layer(
 831            ServiceBuilder::new()
 832                .layer(Extension(server.app_state.clone()))
 833                .layer(middleware::from_fn(auth::validate_header)),
 834        )
 835        .route("/metrics", get(handle_metrics))
 836        .layer(Extension(server))
 837}
 838
 839pub async fn handle_websocket_request(
 840    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 841    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 842    Extension(server): Extension<Arc<Server>>,
 843    Extension(user): Extension<User>,
 844    Extension(impersonator): Extension<Impersonator>,
 845    ws: WebSocketUpgrade,
 846) -> axum::response::Response {
 847    if protocol_version != rpc::PROTOCOL_VERSION {
 848        return (
 849            StatusCode::UPGRADE_REQUIRED,
 850            "client must be upgraded".to_string(),
 851        )
 852            .into_response();
 853    }
 854    let socket_address = socket_address.to_string();
 855    ws.on_upgrade(move |socket| {
 856        use util::ResultExt;
 857        let socket = socket
 858            .map_ok(to_tungstenite_message)
 859            .err_into()
 860            .with(|message| async move { Ok(to_axum_message(message)) });
 861        let connection = Connection::new(Box::pin(socket));
 862        async move {
 863            server
 864                .handle_connection(
 865                    connection,
 866                    socket_address,
 867                    user,
 868                    impersonator.0,
 869                    None,
 870                    Executor::Production,
 871                )
 872                .await
 873                .log_err();
 874        }
 875    })
 876}
 877
 878pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 879    let connections = server
 880        .connection_pool
 881        .lock()
 882        .connections()
 883        .filter(|connection| !connection.admin)
 884        .count();
 885
 886    METRIC_CONNECTIONS.set(connections as _);
 887
 888    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 889    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 890
 891    let encoder = prometheus::TextEncoder::new();
 892    let metric_families = prometheus::gather();
 893    let encoded_metrics = encoder
 894        .encode_to_string(&metric_families)
 895        .map_err(|err| anyhow!("{}", err))?;
 896    Ok(encoded_metrics)
 897}
 898
/// Cleans up after a client connection ends.
///
/// Steps:
/// 1. Disconnects the peer.
/// 2. Removes the connection from the pool, returning early if that fails.
/// 3. Marks the connection as lost in the database. Errors are logged and
///    ignored.
/// 4. Waits up to `RECONNECT_TIMEOUT`. If `teardown` changes first, nothing
///    else is done. Otherwise the session leaves its room and channel buffers.
///    If the user has no remaining online connection, `decline_call(None, ..)`
///    is issued for them. Finally their contacts are updated.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged but does not abort cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the reconnect
    // timer is checked before `teardown` each time the select is polled.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls once no connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                // No specific room id is passed (`None`); if the database
                // returns a room, its participants are notified.
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Presumably signalled when the server is being torn down — TODO
        // confirm; in that case the timed cleanup above is skipped.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 944
 945/// Acknowledges a ping from a client, used to keep the connection alive.
 946async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 947    response.send(proto::Ack {})?;
 948    Ok(())
 949}
 950
 951/// Creates a new room for calling (outside of channels)
 952async fn create_room(
 953    _request: proto::CreateRoom,
 954    response: Response<proto::CreateRoom>,
 955    session: Session,
 956) -> Result<()> {
 957    let live_kit_room = nanoid::nanoid!(30);
 958
 959    let live_kit_connection_info = {
 960        let live_kit_room = live_kit_room.clone();
 961        let live_kit = session.live_kit_client.as_ref();
 962
 963        util::async_maybe!({
 964            let live_kit = live_kit?;
 965
 966            let token = live_kit
 967                .room_token(&live_kit_room, &session.user_id.to_string())
 968                .trace_err()?;
 969
 970            Some(proto::LiveKitConnectionInfo {
 971                server_url: live_kit.url().into(),
 972                token,
 973                can_publish: true,
 974            })
 975        })
 976    }
 977    .await;
 978
 979    let room = session
 980        .db()
 981        .await
 982        .create_room(
 983            session.user_id,
 984            session.connection_id,
 985            &live_kit_room,
 986            &session.zed_environment,
 987        )
 988        .await?;
 989
 990    response.send(proto::CreateRoomResponse {
 991        room: Some(room.clone()),
 992        live_kit_connection_info,
 993    })?;
 994
 995    update_user_contacts(session.user_id, &session).await?;
 996    Ok(())
 997}
 998
 999/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1000async fn join_room(
1001    request: proto::JoinRoom,
1002    response: Response<proto::JoinRoom>,
1003    session: Session,
1004) -> Result<()> {
1005    let room_id = RoomId::from_proto(request.id);
1006
1007    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1008
1009    if let Some(channel_id) = channel_id {
1010        return join_channel_internal(channel_id, Box::new(response), session).await;
1011    }
1012
1013    let joined_room = {
1014        let room = session
1015            .db()
1016            .await
1017            .join_room(
1018                room_id,
1019                session.user_id,
1020                session.connection_id,
1021                session.zed_environment.as_ref(),
1022            )
1023            .await?;
1024        room_updated(&room.room, &session.peer);
1025        room.into_inner()
1026    };
1027
1028    for connection_id in session
1029        .connection_pool()
1030        .await
1031        .user_connection_ids(session.user_id)
1032    {
1033        session
1034            .peer
1035            .send(
1036                connection_id,
1037                proto::CallCanceled {
1038                    room_id: room_id.to_proto(),
1039                },
1040            )
1041            .trace_err();
1042    }
1043
1044    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1045        if let Some(token) = live_kit
1046            .room_token(
1047                &joined_room.room.live_kit_room,
1048                &session.user_id.to_string(),
1049            )
1050            .trace_err()
1051        {
1052            Some(proto::LiveKitConnectionInfo {
1053                server_url: live_kit.url().into(),
1054                token,
1055                can_publish: true,
1056            })
1057        } else {
1058            None
1059        }
1060    } else {
1061        None
1062    };
1063
1064    response.send(proto::JoinRoomResponse {
1065        room: Some(joined_room.room),
1066        channel_id: None,
1067        live_kit_connection_info,
1068    })?;
1069
1070    update_user_contacts(session.user_id, &session).await?;
1071    Ok(())
1072}
1073
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Asks the database to rejoin the room with this new connection.
/// 2. Replies with the room plus the reshared and rejoined project state, and
///    broadcasts the room.
/// 3. Tells each project's collaborators that this user's peer id changed from
///    the old connection to the current one.
/// 4. Streams the worktree entries, diagnostic summaries, settings files and
///    language-server status of every rejoined project back to this client.
/// 5. If the room belongs to a channel, broadcasts a channel update.
/// 6. Refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the inner block so these values outlive `rejoined_room`,
    // which is consumed by `into_inner()` at the end of the block (presumably
    // releasing a room guard — TODO confirm) before `channel_updated` runs.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: announce the peer-id change to every collaborator,
        // then forward the project's worktree list to all of them except this
        // connection. Per-recipient failures are only logged.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: announce the peer-id change to every collaborator.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's state back to the reconnecting client.
        // Worktrees are taken by value with `mem::take` so their contents can be
        // moved into messages without cloning. Unlike the sends above, a failed
        // send here aborts the handler via `?`.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size, presumably to exercise
                // `split_worktree_update`'s chunking.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    // The update is final only if the latest scan has completed.
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // For every language server of the project, send a
            // `DiskBasedDiagnosticsUpdated` status update to this client.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Channel-backed rooms also notify the channel's members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1265
1266/// leave room disonnects from the room.
1267async fn leave_room(
1268    _: proto::LeaveRoom,
1269    response: Response<proto::LeaveRoom>,
1270    session: Session,
1271) -> Result<()> {
1272    leave_room_for_session(&session).await?;
1273    response.send(proto::Ack {})?;
1274    Ok(())
1275}
1276
1277/// Updates the permissions of someone else in the room.
1278async fn set_room_participant_role(
1279    request: proto::SetRoomParticipantRole,
1280    response: Response<proto::SetRoomParticipantRole>,
1281    session: Session,
1282) -> Result<()> {
1283    let (live_kit_room, can_publish) = {
1284        let room = session
1285            .db()
1286            .await
1287            .set_room_participant_role(
1288                session.user_id,
1289                RoomId::from_proto(request.room_id),
1290                UserId::from_proto(request.user_id),
1291                ChannelRole::from(request.role()),
1292            )
1293            .await?;
1294
1295        let live_kit_room = room.live_kit_room.clone();
1296        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1297        room_updated(&room, &session.peer);
1298        (live_kit_room, can_publish)
1299    };
1300
1301    if let Some(live_kit) = session.live_kit_client.as_ref() {
1302        live_kit
1303            .update_participant(
1304                live_kit_room.clone(),
1305                request.user_id.to_string(),
1306                live_kit_server::proto::ParticipantPermission {
1307                    can_subscribe: true,
1308                    can_publish,
1309                    can_publish_data: can_publish,
1310                    hidden: false,
1311                    recorder: false,
1312                },
1313            )
1314            .await
1315            .trace_err();
1316    }
1317
1318    response.send(proto::Ack {})?;
1319    Ok(())
1320}
1321
1322/// Call someone else into the current room
1323async fn call(
1324    request: proto::Call,
1325    response: Response<proto::Call>,
1326    session: Session,
1327) -> Result<()> {
1328    let room_id = RoomId::from_proto(request.room_id);
1329    let calling_user_id = session.user_id;
1330    let calling_connection_id = session.connection_id;
1331    let called_user_id = UserId::from_proto(request.called_user_id);
1332    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1333    if !session
1334        .db()
1335        .await
1336        .has_contact(calling_user_id, called_user_id)
1337        .await?
1338    {
1339        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1340    }
1341
1342    let incoming_call = {
1343        let (room, incoming_call) = &mut *session
1344            .db()
1345            .await
1346            .call(
1347                room_id,
1348                calling_user_id,
1349                calling_connection_id,
1350                called_user_id,
1351                initial_project_id,
1352            )
1353            .await?;
1354        room_updated(&room, &session.peer);
1355        mem::take(incoming_call)
1356    };
1357    update_user_contacts(called_user_id, &session).await?;
1358
1359    let mut calls = session
1360        .connection_pool()
1361        .await
1362        .user_connection_ids(called_user_id)
1363        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1364        .collect::<FuturesUnordered<_>>();
1365
1366    while let Some(call_response) = calls.next().await {
1367        match call_response.as_ref() {
1368            Ok(_) => {
1369                response.send(proto::Ack {})?;
1370                return Ok(());
1371            }
1372            Err(_) => {
1373                call_response.trace_err();
1374            }
1375        }
1376    }
1377
1378    {
1379        let room = session
1380            .db()
1381            .await
1382            .call_failed(room_id, called_user_id)
1383            .await?;
1384        room_updated(&room, &session.peer);
1385    }
1386    update_user_contacts(called_user_id, &session).await?;
1387
1388    Err(anyhow!("failed to ring user"))?
1389}
1390
1391/// Cancel an outgoing call.
1392async fn cancel_call(
1393    request: proto::CancelCall,
1394    response: Response<proto::CancelCall>,
1395    session: Session,
1396) -> Result<()> {
1397    let called_user_id = UserId::from_proto(request.called_user_id);
1398    let room_id = RoomId::from_proto(request.room_id);
1399    {
1400        let room = session
1401            .db()
1402            .await
1403            .cancel_call(room_id, session.connection_id, called_user_id)
1404            .await?;
1405        room_updated(&room, &session.peer);
1406    }
1407
1408    for connection_id in session
1409        .connection_pool()
1410        .await
1411        .user_connection_ids(called_user_id)
1412    {
1413        session
1414            .peer
1415            .send(
1416                connection_id,
1417                proto::CallCanceled {
1418                    room_id: room_id.to_proto(),
1419                },
1420            )
1421            .trace_err();
1422    }
1423    response.send(proto::Ack {})?;
1424
1425    update_user_contacts(called_user_id, &session).await?;
1426    Ok(())
1427}
1428
1429/// Decline an incoming call.
1430async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1431    let room_id = RoomId::from_proto(message.room_id);
1432    {
1433        let room = session
1434            .db()
1435            .await
1436            .decline_call(Some(room_id), session.user_id)
1437            .await?
1438            .ok_or_else(|| anyhow!("failed to decline call"))?;
1439        room_updated(&room, &session.peer);
1440    }
1441
1442    for connection_id in session
1443        .connection_pool()
1444        .await
1445        .user_connection_ids(session.user_id)
1446    {
1447        session
1448            .peer
1449            .send(
1450                connection_id,
1451                proto::CallCanceled {
1452                    room_id: room_id.to_proto(),
1453                },
1454            )
1455            .trace_err();
1456    }
1457    update_user_contacts(session.user_id, &session).await?;
1458    Ok(())
1459}
1460
1461/// Updates other participants in the room with your current location.
1462async fn update_participant_location(
1463    request: proto::UpdateParticipantLocation,
1464    response: Response<proto::UpdateParticipantLocation>,
1465    session: Session,
1466) -> Result<()> {
1467    let room_id = RoomId::from_proto(request.room_id);
1468    let location = request
1469        .location
1470        .ok_or_else(|| anyhow!("invalid location"))?;
1471
1472    let db = session.db().await;
1473    let room = db
1474        .update_room_participant_location(room_id, session.connection_id, location)
1475        .await?;
1476
1477    room_updated(&room, &session.peer);
1478    response.send(proto::Ack {})?;
1479    Ok(())
1480}
1481
1482/// Share a project into the room.
1483async fn share_project(
1484    request: proto::ShareProject,
1485    response: Response<proto::ShareProject>,
1486    session: Session,
1487) -> Result<()> {
1488    let (project_id, room) = &*session
1489        .db()
1490        .await
1491        .share_project(
1492            RoomId::from_proto(request.room_id),
1493            session.connection_id,
1494            &request.worktrees,
1495        )
1496        .await?;
1497    response.send(proto::ShareProjectResponse {
1498        project_id: project_id.to_proto(),
1499    })?;
1500    room_updated(&room, &session.peer);
1501
1502    Ok(())
1503}
1504
1505/// Unshare a project from the room.
1506async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1507    let project_id = ProjectId::from_proto(message.project_id);
1508
1509    let (room, guest_connection_ids) = &*session
1510        .db()
1511        .await
1512        .unshare_project(project_id, session.connection_id)
1513        .await?;
1514
1515    broadcast(
1516        Some(session.connection_id),
1517        guest_connection_ids.iter().copied(),
1518        |conn_id| session.peer.send(conn_id, message.clone()),
1519    );
1520    room_updated(&room, &session.peer);
1521
1522    Ok(())
1523}
1524
1525/// Join someone elses shared project.
1526async fn join_project(
1527    request: proto::JoinProject,
1528    response: Response<proto::JoinProject>,
1529    session: Session,
1530) -> Result<()> {
1531    let project_id = ProjectId::from_proto(request.project_id);
1532    let guest_user_id = session.user_id;
1533
1534    tracing::info!(%project_id, "join project");
1535
1536    let (project, replica_id) = &mut *session
1537        .db()
1538        .await
1539        .join_project(project_id, session.connection_id)
1540        .await?;
1541
1542    let collaborators = project
1543        .collaborators
1544        .iter()
1545        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1546        .map(|collaborator| collaborator.to_proto())
1547        .collect::<Vec<_>>();
1548
1549    let worktrees = project
1550        .worktrees
1551        .iter()
1552        .map(|(id, worktree)| proto::WorktreeMetadata {
1553            id: *id,
1554            root_name: worktree.root_name.clone(),
1555            visible: worktree.visible,
1556            abs_path: worktree.abs_path.clone(),
1557        })
1558        .collect::<Vec<_>>();
1559
1560    for collaborator in &collaborators {
1561        session
1562            .peer
1563            .send(
1564                collaborator.peer_id.unwrap().into(),
1565                proto::AddProjectCollaborator {
1566                    project_id: project_id.to_proto(),
1567                    collaborator: Some(proto::Collaborator {
1568                        peer_id: Some(session.connection_id.into()),
1569                        replica_id: replica_id.0 as u32,
1570                        user_id: guest_user_id.to_proto(),
1571                    }),
1572                },
1573            )
1574            .trace_err();
1575    }
1576
1577    // First, we send the metadata associated with each worktree.
1578    response.send(proto::JoinProjectResponse {
1579        worktrees: worktrees.clone(),
1580        replica_id: replica_id.0 as u32,
1581        collaborators: collaborators.clone(),
1582        language_servers: project.language_servers.clone(),
1583    })?;
1584
1585    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1586        #[cfg(any(test, feature = "test-support"))]
1587        const MAX_CHUNK_SIZE: usize = 2;
1588        #[cfg(not(any(test, feature = "test-support")))]
1589        const MAX_CHUNK_SIZE: usize = 256;
1590
1591        // Stream this worktree's entries.
1592        let message = proto::UpdateWorktree {
1593            project_id: project_id.to_proto(),
1594            worktree_id,
1595            abs_path: worktree.abs_path.clone(),
1596            root_name: worktree.root_name,
1597            updated_entries: worktree.entries,
1598            removed_entries: Default::default(),
1599            scan_id: worktree.scan_id,
1600            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1601            updated_repositories: worktree.repository_entries.into_values().collect(),
1602            removed_repositories: Default::default(),
1603        };
1604        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1605            session.peer.send(session.connection_id, update.clone())?;
1606        }
1607
1608        // Stream this worktree's diagnostics.
1609        for summary in worktree.diagnostic_summaries {
1610            session.peer.send(
1611                session.connection_id,
1612                proto::UpdateDiagnosticSummary {
1613                    project_id: project_id.to_proto(),
1614                    worktree_id: worktree.id,
1615                    summary: Some(summary),
1616                },
1617            )?;
1618        }
1619
1620        for settings_file in worktree.settings_files {
1621            session.peer.send(
1622                session.connection_id,
1623                proto::UpdateWorktreeSettings {
1624                    project_id: project_id.to_proto(),
1625                    worktree_id: worktree.id,
1626                    path: settings_file.path,
1627                    content: Some(settings_file.content),
1628                },
1629            )?;
1630        }
1631    }
1632
1633    for language_server in &project.language_servers {
1634        session.peer.send(
1635            session.connection_id,
1636            proto::UpdateLanguageServer {
1637                project_id: project_id.to_proto(),
1638                language_server_id: language_server.id,
1639                variant: Some(
1640                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1641                        proto::LspDiskBasedDiagnosticsUpdated {},
1642                    ),
1643                ),
1644            },
1645        )?;
1646    }
1647
1648    Ok(())
1649}
1650
1651/// Leave someone elses shared project.
1652async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1653    let sender_id = session.connection_id;
1654    let project_id = ProjectId::from_proto(request.project_id);
1655
1656    let (room, project) = &*session
1657        .db()
1658        .await
1659        .leave_project(project_id, sender_id)
1660        .await?;
1661    tracing::info!(
1662        %project_id,
1663        host_user_id = %project.host_user_id,
1664        host_connection_id = %project.host_connection_id,
1665        "leave project"
1666    );
1667
1668    project_left(&project, &session);
1669    room_updated(&room, &session.peer);
1670
1671    Ok(())
1672}
1673
1674/// Updates other participants with changes to the project
1675async fn update_project(
1676    request: proto::UpdateProject,
1677    response: Response<proto::UpdateProject>,
1678    session: Session,
1679) -> Result<()> {
1680    let project_id = ProjectId::from_proto(request.project_id);
1681    let (room, guest_connection_ids) = &*session
1682        .db()
1683        .await
1684        .update_project(project_id, session.connection_id, &request.worktrees)
1685        .await?;
1686    broadcast(
1687        Some(session.connection_id),
1688        guest_connection_ids.iter().copied(),
1689        |connection_id| {
1690            session
1691                .peer
1692                .forward_send(session.connection_id, connection_id, request.clone())
1693        },
1694    );
1695    room_updated(&room, &session.peer);
1696    response.send(proto::Ack {})?;
1697
1698    Ok(())
1699}
1700
1701/// Updates other participants with changes to the worktree
1702async fn update_worktree(
1703    request: proto::UpdateWorktree,
1704    response: Response<proto::UpdateWorktree>,
1705    session: Session,
1706) -> Result<()> {
1707    let guest_connection_ids = session
1708        .db()
1709        .await
1710        .update_worktree(&request, session.connection_id)
1711        .await?;
1712
1713    broadcast(
1714        Some(session.connection_id),
1715        guest_connection_ids.iter().copied(),
1716        |connection_id| {
1717            session
1718                .peer
1719                .forward_send(session.connection_id, connection_id, request.clone())
1720        },
1721    );
1722    response.send(proto::Ack {})?;
1723    Ok(())
1724}
1725
1726/// Updates other participants with changes to the diagnostics
1727async fn update_diagnostic_summary(
1728    message: proto::UpdateDiagnosticSummary,
1729    session: Session,
1730) -> Result<()> {
1731    let guest_connection_ids = session
1732        .db()
1733        .await
1734        .update_diagnostic_summary(&message, session.connection_id)
1735        .await?;
1736
1737    broadcast(
1738        Some(session.connection_id),
1739        guest_connection_ids.iter().copied(),
1740        |connection_id| {
1741            session
1742                .peer
1743                .forward_send(session.connection_id, connection_id, message.clone())
1744        },
1745    );
1746
1747    Ok(())
1748}
1749
1750/// Updates other participants with changes to the worktree settings
1751async fn update_worktree_settings(
1752    message: proto::UpdateWorktreeSettings,
1753    session: Session,
1754) -> Result<()> {
1755    let guest_connection_ids = session
1756        .db()
1757        .await
1758        .update_worktree_settings(&message, session.connection_id)
1759        .await?;
1760
1761    broadcast(
1762        Some(session.connection_id),
1763        guest_connection_ids.iter().copied(),
1764        |connection_id| {
1765            session
1766                .peer
1767                .forward_send(session.connection_id, connection_id, message.clone())
1768        },
1769    );
1770
1771    Ok(())
1772}
1773
1774/// Notify other participants that a  language server has started.
1775async fn start_language_server(
1776    request: proto::StartLanguageServer,
1777    session: Session,
1778) -> Result<()> {
1779    let guest_connection_ids = session
1780        .db()
1781        .await
1782        .start_language_server(&request, session.connection_id)
1783        .await?;
1784
1785    broadcast(
1786        Some(session.connection_id),
1787        guest_connection_ids.iter().copied(),
1788        |connection_id| {
1789            session
1790                .peer
1791                .forward_send(session.connection_id, connection_id, request.clone())
1792        },
1793    );
1794    Ok(())
1795}
1796
1797/// Notify other participants that a language server has changed.
1798async fn update_language_server(
1799    request: proto::UpdateLanguageServer,
1800    session: Session,
1801) -> Result<()> {
1802    let project_id = ProjectId::from_proto(request.project_id);
1803    let project_connection_ids = session
1804        .db()
1805        .await
1806        .project_connection_ids(project_id, session.connection_id)
1807        .await?;
1808    broadcast(
1809        Some(session.connection_id),
1810        project_connection_ids.iter().copied(),
1811        |connection_id| {
1812            session
1813                .peer
1814                .forward_send(session.connection_id, connection_id, request.clone())
1815        },
1816    );
1817    Ok(())
1818}
1819
1820/// forward a project request to the host. These requests should be read only
1821/// as guests are allowed to send them.
1822async fn forward_read_only_project_request<T>(
1823    request: T,
1824    response: Response<T>,
1825    session: Session,
1826) -> Result<()>
1827where
1828    T: EntityMessage + RequestMessage,
1829{
1830    let project_id = ProjectId::from_proto(request.remote_entity_id());
1831    let host_connection_id = session
1832        .db()
1833        .await
1834        .host_for_read_only_project_request(project_id, session.connection_id)
1835        .await?;
1836    let payload = session
1837        .peer
1838        .forward_request(session.connection_id, host_connection_id, request)
1839        .await?;
1840    response.send(payload)?;
1841    Ok(())
1842}
1843
1844/// forward a project request to the host. These requests are disallowed
1845/// for guests.
1846async fn forward_mutating_project_request<T>(
1847    request: T,
1848    response: Response<T>,
1849    session: Session,
1850) -> Result<()>
1851where
1852    T: EntityMessage + RequestMessage,
1853{
1854    let project_id = ProjectId::from_proto(request.remote_entity_id());
1855    let host_connection_id = session
1856        .db()
1857        .await
1858        .host_for_mutating_project_request(project_id, session.connection_id)
1859        .await?;
1860    let payload = session
1861        .peer
1862        .forward_request(session.connection_id, host_connection_id, request)
1863        .await?;
1864    response.send(payload)?;
1865    Ok(())
1866}
1867
1868/// Notify other participants that a new buffer has been created
1869async fn create_buffer_for_peer(
1870    request: proto::CreateBufferForPeer,
1871    session: Session,
1872) -> Result<()> {
1873    session
1874        .db()
1875        .await
1876        .check_user_is_project_host(
1877            ProjectId::from_proto(request.project_id),
1878            session.connection_id,
1879        )
1880        .await?;
1881    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1882    session
1883        .peer
1884        .forward_send(session.connection_id, peer_id.into(), request)?;
1885    Ok(())
1886}
1887
1888/// Notify other participants that a buffer has been updated. This is
1889/// allowed for guests as long as the update is limited to selections.
1890async fn update_buffer(
1891    request: proto::UpdateBuffer,
1892    response: Response<proto::UpdateBuffer>,
1893    session: Session,
1894) -> Result<()> {
1895    let project_id = ProjectId::from_proto(request.project_id);
1896    let mut guest_connection_ids;
1897    let mut host_connection_id = None;
1898
1899    let mut requires_write_permission = false;
1900
1901    for op in request.operations.iter() {
1902        match op.variant {
1903            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1904            Some(_) => requires_write_permission = true,
1905        }
1906    }
1907
1908    {
1909        let collaborators = session
1910            .db()
1911            .await
1912            .project_collaborators_for_buffer_update(
1913                project_id,
1914                session.connection_id,
1915                requires_write_permission,
1916            )
1917            .await?;
1918        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1919        for collaborator in collaborators.iter() {
1920            if collaborator.is_host {
1921                host_connection_id = Some(collaborator.connection_id);
1922            } else {
1923                guest_connection_ids.push(collaborator.connection_id);
1924            }
1925        }
1926    }
1927    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1928
1929    broadcast(
1930        Some(session.connection_id),
1931        guest_connection_ids,
1932        |connection_id| {
1933            session
1934                .peer
1935                .forward_send(session.connection_id, connection_id, request.clone())
1936        },
1937    );
1938    if host_connection_id != session.connection_id {
1939        session
1940            .peer
1941            .forward_request(session.connection_id, host_connection_id, request.clone())
1942            .await?;
1943    }
1944
1945    response.send(proto::Ack {})?;
1946    Ok(())
1947}
1948
1949/// Notify other participants that a project has been updated.
1950async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1951    request: T,
1952    session: Session,
1953) -> Result<()> {
1954    let project_id = ProjectId::from_proto(request.remote_entity_id());
1955    let project_connection_ids = session
1956        .db()
1957        .await
1958        .project_connection_ids(project_id, session.connection_id)
1959        .await?;
1960
1961    broadcast(
1962        Some(session.connection_id),
1963        project_connection_ids.iter().copied(),
1964        |connection_id| {
1965            session
1966                .peer
1967                .forward_send(session.connection_id, connection_id, request.clone())
1968        },
1969    );
1970    Ok(())
1971}
1972
1973/// Start following another user in a call.
1974async fn follow(
1975    request: proto::Follow,
1976    response: Response<proto::Follow>,
1977    session: Session,
1978) -> Result<()> {
1979    let room_id = RoomId::from_proto(request.room_id);
1980    let project_id = request.project_id.map(ProjectId::from_proto);
1981    let leader_id = request
1982        .leader_id
1983        .ok_or_else(|| anyhow!("invalid leader id"))?
1984        .into();
1985    let follower_id = session.connection_id;
1986
1987    session
1988        .db()
1989        .await
1990        .check_room_participants(room_id, leader_id, session.connection_id)
1991        .await?;
1992
1993    let response_payload = session
1994        .peer
1995        .forward_request(session.connection_id, leader_id, request)
1996        .await?;
1997    response.send(response_payload)?;
1998
1999    if let Some(project_id) = project_id {
2000        let room = session
2001            .db()
2002            .await
2003            .follow(room_id, project_id, leader_id, follower_id)
2004            .await?;
2005        room_updated(&room, &session.peer);
2006    }
2007
2008    Ok(())
2009}
2010
2011/// Stop following another user in a call.
2012async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2013    let room_id = RoomId::from_proto(request.room_id);
2014    let project_id = request.project_id.map(ProjectId::from_proto);
2015    let leader_id = request
2016        .leader_id
2017        .ok_or_else(|| anyhow!("invalid leader id"))?
2018        .into();
2019    let follower_id = session.connection_id;
2020
2021    session
2022        .db()
2023        .await
2024        .check_room_participants(room_id, leader_id, session.connection_id)
2025        .await?;
2026
2027    session
2028        .peer
2029        .forward_send(session.connection_id, leader_id, request)?;
2030
2031    if let Some(project_id) = project_id {
2032        let room = session
2033            .db()
2034            .await
2035            .unfollow(room_id, project_id, leader_id, follower_id)
2036            .await?;
2037        room_updated(&room, &session.peer);
2038    }
2039
2040    Ok(())
2041}
2042
2043/// Notify everyone following you of your current location.
2044async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2045    let room_id = RoomId::from_proto(request.room_id);
2046    let database = session.db.lock().await;
2047
2048    let connection_ids = if let Some(project_id) = request.project_id {
2049        let project_id = ProjectId::from_proto(project_id);
2050        database
2051            .project_connection_ids(project_id, session.connection_id)
2052            .await?
2053    } else {
2054        database
2055            .room_connection_ids(room_id, session.connection_id)
2056            .await?
2057    };
2058
2059    // For now, don't send view update messages back to that view's current leader.
2060    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2061        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2062        _ => None,
2063    });
2064
2065    for follower_peer_id in request.follower_ids.iter().copied() {
2066        let follower_connection_id = follower_peer_id.into();
2067        if Some(follower_peer_id) != connection_id_to_omit
2068            && connection_ids.contains(&follower_connection_id)
2069        {
2070            session.peer.forward_send(
2071                session.connection_id,
2072                follower_connection_id,
2073                request.clone(),
2074            )?;
2075        }
2076    }
2077    Ok(())
2078}
2079
2080/// Get public data about users.
2081async fn get_users(
2082    request: proto::GetUsers,
2083    response: Response<proto::GetUsers>,
2084    session: Session,
2085) -> Result<()> {
2086    let user_ids = request
2087        .user_ids
2088        .into_iter()
2089        .map(UserId::from_proto)
2090        .collect();
2091    let users = session
2092        .db()
2093        .await
2094        .get_users_by_ids(user_ids)
2095        .await?
2096        .into_iter()
2097        .map(|user| proto::User {
2098            id: user.id.to_proto(),
2099            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2100            github_login: user.github_login,
2101        })
2102        .collect();
2103    response.send(proto::UsersResponse { users })?;
2104    Ok(())
2105}
2106
2107/// Search for users (to invite) buy Github login
2108async fn fuzzy_search_users(
2109    request: proto::FuzzySearchUsers,
2110    response: Response<proto::FuzzySearchUsers>,
2111    session: Session,
2112) -> Result<()> {
2113    let query = request.query;
2114    let users = match query.len() {
2115        0 => vec![],
2116        1 | 2 => session
2117            .db()
2118            .await
2119            .get_user_by_github_login(&query)
2120            .await?
2121            .into_iter()
2122            .collect(),
2123        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2124    };
2125    let users = users
2126        .into_iter()
2127        .filter(|user| user.id != session.user_id)
2128        .map(|user| proto::User {
2129            id: user.id.to_proto(),
2130            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2131            github_login: user.github_login,
2132        })
2133        .collect();
2134    response.send(proto::UsersResponse { users })?;
2135    Ok(())
2136}
2137
2138/// Send a contact request to another user.
2139async fn request_contact(
2140    request: proto::RequestContact,
2141    response: Response<proto::RequestContact>,
2142    session: Session,
2143) -> Result<()> {
2144    let requester_id = session.user_id;
2145    let responder_id = UserId::from_proto(request.responder_id);
2146    if requester_id == responder_id {
2147        return Err(anyhow!("cannot add yourself as a contact"))?;
2148    }
2149
2150    let notifications = session
2151        .db()
2152        .await
2153        .send_contact_request(requester_id, responder_id)
2154        .await?;
2155
2156    // Update outgoing contact requests of requester
2157    let mut update = proto::UpdateContacts::default();
2158    update.outgoing_requests.push(responder_id.to_proto());
2159    for connection_id in session
2160        .connection_pool()
2161        .await
2162        .user_connection_ids(requester_id)
2163    {
2164        session.peer.send(connection_id, update.clone())?;
2165    }
2166
2167    // Update incoming contact requests of responder
2168    let mut update = proto::UpdateContacts::default();
2169    update
2170        .incoming_requests
2171        .push(proto::IncomingContactRequest {
2172            requester_id: requester_id.to_proto(),
2173        });
2174    let connection_pool = session.connection_pool().await;
2175    for connection_id in connection_pool.user_connection_ids(responder_id) {
2176        session.peer.send(connection_id, update.clone())?;
2177    }
2178
2179    send_notifications(&*connection_pool, &session.peer, notifications);
2180
2181    response.send(proto::Ack {})?;
2182    Ok(())
2183}
2184
2185/// Accept or decline a contact request
2186async fn respond_to_contact_request(
2187    request: proto::RespondToContactRequest,
2188    response: Response<proto::RespondToContactRequest>,
2189    session: Session,
2190) -> Result<()> {
2191    let responder_id = session.user_id;
2192    let requester_id = UserId::from_proto(request.requester_id);
2193    let db = session.db().await;
2194    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2195        db.dismiss_contact_notification(responder_id, requester_id)
2196            .await?;
2197    } else {
2198        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2199
2200        let notifications = db
2201            .respond_to_contact_request(responder_id, requester_id, accept)
2202            .await?;
2203        let requester_busy = db.is_user_busy(requester_id).await?;
2204        let responder_busy = db.is_user_busy(responder_id).await?;
2205
2206        let pool = session.connection_pool().await;
2207        // Update responder with new contact
2208        let mut update = proto::UpdateContacts::default();
2209        if accept {
2210            update
2211                .contacts
2212                .push(contact_for_user(requester_id, requester_busy, &pool));
2213        }
2214        update
2215            .remove_incoming_requests
2216            .push(requester_id.to_proto());
2217        for connection_id in pool.user_connection_ids(responder_id) {
2218            session.peer.send(connection_id, update.clone())?;
2219        }
2220
2221        // Update requester with new contact
2222        let mut update = proto::UpdateContacts::default();
2223        if accept {
2224            update
2225                .contacts
2226                .push(contact_for_user(responder_id, responder_busy, &pool));
2227        }
2228        update
2229            .remove_outgoing_requests
2230            .push(responder_id.to_proto());
2231
2232        for connection_id in pool.user_connection_ids(requester_id) {
2233            session.peer.send(connection_id, update.clone())?;
2234        }
2235
2236        send_notifications(&*pool, &session.peer, notifications);
2237    }
2238
2239    response.send(proto::Ack {})?;
2240    Ok(())
2241}
2242
2243/// Remove a contact.
2244async fn remove_contact(
2245    request: proto::RemoveContact,
2246    response: Response<proto::RemoveContact>,
2247    session: Session,
2248) -> Result<()> {
2249    let requester_id = session.user_id;
2250    let responder_id = UserId::from_proto(request.user_id);
2251    let db = session.db().await;
2252    let (contact_accepted, deleted_notification_id) =
2253        db.remove_contact(requester_id, responder_id).await?;
2254
2255    let pool = session.connection_pool().await;
2256    // Update outgoing contact requests of requester
2257    let mut update = proto::UpdateContacts::default();
2258    if contact_accepted {
2259        update.remove_contacts.push(responder_id.to_proto());
2260    } else {
2261        update
2262            .remove_outgoing_requests
2263            .push(responder_id.to_proto());
2264    }
2265    for connection_id in pool.user_connection_ids(requester_id) {
2266        session.peer.send(connection_id, update.clone())?;
2267    }
2268
2269    // Update incoming contact requests of responder
2270    let mut update = proto::UpdateContacts::default();
2271    if contact_accepted {
2272        update.remove_contacts.push(requester_id.to_proto());
2273    } else {
2274        update
2275            .remove_incoming_requests
2276            .push(requester_id.to_proto());
2277    }
2278    for connection_id in pool.user_connection_ids(responder_id) {
2279        session.peer.send(connection_id, update.clone())?;
2280        if let Some(notification_id) = deleted_notification_id {
2281            session.peer.send(
2282                connection_id,
2283                proto::DeleteNotification {
2284                    notification_id: notification_id.to_proto(),
2285                },
2286            )?;
2287        }
2288    }
2289
2290    response.send(proto::Ack {})?;
2291    Ok(())
2292}
2293
2294/// Creates a new channel.
2295async fn create_channel(
2296    request: proto::CreateChannel,
2297    response: Response<proto::CreateChannel>,
2298    session: Session,
2299) -> Result<()> {
2300    let db = session.db().await;
2301
2302    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2303    let channel = db
2304        .create_channel(&request.name, parent_id, session.user_id)
2305        .await?;
2306
2307    response.send(proto::CreateChannelResponse {
2308        channel: Some(channel.to_proto()),
2309        parent_id: request.parent_id,
2310    })?;
2311
2312    let participants_to_update;
2313    if let Some(parent) = parent_id {
2314        participants_to_update = db.new_participants_to_notify(parent).await?;
2315    } else {
2316        participants_to_update = vec![];
2317    }
2318
2319    let connection_pool = session.connection_pool().await;
2320    for (user_id, channels) in participants_to_update {
2321        let update = build_channels_update(channels, vec![]);
2322        for connection_id in connection_pool.user_connection_ids(user_id) {
2323            if user_id == session.user_id {
2324                continue;
2325            }
2326            session.peer.send(connection_id, update.clone())?;
2327        }
2328    }
2329
2330    Ok(())
2331}
2332
2333/// Delete a channel
2334async fn delete_channel(
2335    request: proto::DeleteChannel,
2336    response: Response<proto::DeleteChannel>,
2337    session: Session,
2338) -> Result<()> {
2339    let db = session.db().await;
2340
2341    let channel_id = request.channel_id;
2342    let (removed_channels, member_ids) = db
2343        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2344        .await?;
2345    response.send(proto::Ack {})?;
2346
2347    // Notify members of removed channels
2348    let mut update = proto::UpdateChannels::default();
2349    update
2350        .delete_channels
2351        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2352
2353    let connection_pool = session.connection_pool().await;
2354    for member_id in member_ids {
2355        for connection_id in connection_pool.user_connection_ids(member_id) {
2356            session.peer.send(connection_id, update.clone())?;
2357        }
2358    }
2359
2360    Ok(())
2361}
2362
2363/// Invite someone to join a channel.
2364async fn invite_channel_member(
2365    request: proto::InviteChannelMember,
2366    response: Response<proto::InviteChannelMember>,
2367    session: Session,
2368) -> Result<()> {
2369    let db = session.db().await;
2370    let channel_id = ChannelId::from_proto(request.channel_id);
2371    let invitee_id = UserId::from_proto(request.user_id);
2372    let InviteMemberResult {
2373        channel,
2374        notifications,
2375    } = db
2376        .invite_channel_member(
2377            channel_id,
2378            invitee_id,
2379            session.user_id,
2380            request.role().into(),
2381        )
2382        .await?;
2383
2384    let update = proto::UpdateChannels {
2385        channel_invitations: vec![channel.to_proto()],
2386        ..Default::default()
2387    };
2388
2389    let connection_pool = session.connection_pool().await;
2390    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2391        session.peer.send(connection_id, update.clone())?;
2392    }
2393
2394    send_notifications(&*connection_pool, &session.peer, notifications);
2395
2396    response.send(proto::Ack {})?;
2397    Ok(())
2398}
2399
2400/// remove someone from a channel
2401async fn remove_channel_member(
2402    request: proto::RemoveChannelMember,
2403    response: Response<proto::RemoveChannelMember>,
2404    session: Session,
2405) -> Result<()> {
2406    let db = session.db().await;
2407    let channel_id = ChannelId::from_proto(request.channel_id);
2408    let member_id = UserId::from_proto(request.user_id);
2409
2410    let RemoveChannelMemberResult {
2411        membership_update,
2412        notification_id,
2413    } = db
2414        .remove_channel_member(channel_id, member_id, session.user_id)
2415        .await?;
2416
2417    let connection_pool = &session.connection_pool().await;
2418    notify_membership_updated(
2419        &connection_pool,
2420        membership_update,
2421        member_id,
2422        &session.peer,
2423    );
2424    for connection_id in connection_pool.user_connection_ids(member_id) {
2425        if let Some(notification_id) = notification_id {
2426            session
2427                .peer
2428                .send(
2429                    connection_id,
2430                    proto::DeleteNotification {
2431                        notification_id: notification_id.to_proto(),
2432                    },
2433                )
2434                .trace_err();
2435        }
2436    }
2437
2438    response.send(proto::Ack {})?;
2439    Ok(())
2440}
2441
2442/// Toggle the channel between public and private
2443async fn set_channel_visibility(
2444    request: proto::SetChannelVisibility,
2445    response: Response<proto::SetChannelVisibility>,
2446    session: Session,
2447) -> Result<()> {
2448    let db = session.db().await;
2449    let channel_id = ChannelId::from_proto(request.channel_id);
2450    let visibility = request.visibility().into();
2451
2452    let SetChannelVisibilityResult {
2453        participants_to_update,
2454        participants_to_remove,
2455        channels_to_remove,
2456    } = db
2457        .set_channel_visibility(channel_id, visibility, session.user_id)
2458        .await?;
2459
2460    let connection_pool = session.connection_pool().await;
2461    for (user_id, channels) in participants_to_update {
2462        let update = build_channels_update(channels, vec![]);
2463        for connection_id in connection_pool.user_connection_ids(user_id) {
2464            session.peer.send(connection_id, update.clone())?;
2465        }
2466    }
2467    for user_id in participants_to_remove {
2468        let update = proto::UpdateChannels {
2469            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2470            ..Default::default()
2471        };
2472        for connection_id in connection_pool.user_connection_ids(user_id) {
2473            session.peer.send(connection_id, update.clone())?;
2474        }
2475    }
2476
2477    response.send(proto::Ack {})?;
2478    Ok(())
2479}
2480
2481/// Alter the role for a user in the channel
2482async fn set_channel_member_role(
2483    request: proto::SetChannelMemberRole,
2484    response: Response<proto::SetChannelMemberRole>,
2485    session: Session,
2486) -> Result<()> {
2487    let db = session.db().await;
2488    let channel_id = ChannelId::from_proto(request.channel_id);
2489    let member_id = UserId::from_proto(request.user_id);
2490    let result = db
2491        .set_channel_member_role(
2492            channel_id,
2493            session.user_id,
2494            member_id,
2495            request.role().into(),
2496        )
2497        .await?;
2498
2499    match result {
2500        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2501            let connection_pool = session.connection_pool().await;
2502            notify_membership_updated(
2503                &connection_pool,
2504                membership_update,
2505                member_id,
2506                &session.peer,
2507            )
2508        }
2509        db::SetMemberRoleResult::InviteUpdated(channel) => {
2510            let update = proto::UpdateChannels {
2511                channel_invitations: vec![channel.to_proto()],
2512                ..Default::default()
2513            };
2514
2515            for connection_id in session
2516                .connection_pool()
2517                .await
2518                .user_connection_ids(member_id)
2519            {
2520                session.peer.send(connection_id, update.clone())?;
2521            }
2522        }
2523    }
2524
2525    response.send(proto::Ack {})?;
2526    Ok(())
2527}
2528
2529/// Change the name of a channel
2530async fn rename_channel(
2531    request: proto::RenameChannel,
2532    response: Response<proto::RenameChannel>,
2533    session: Session,
2534) -> Result<()> {
2535    let db = session.db().await;
2536    let channel_id = ChannelId::from_proto(request.channel_id);
2537    let RenameChannelResult {
2538        channel,
2539        participants_to_update,
2540    } = db
2541        .rename_channel(channel_id, session.user_id, &request.name)
2542        .await?;
2543
2544    response.send(proto::RenameChannelResponse {
2545        channel: Some(channel.to_proto()),
2546    })?;
2547
2548    let connection_pool = session.connection_pool().await;
2549    for (user_id, channel) in participants_to_update {
2550        for connection_id in connection_pool.user_connection_ids(user_id) {
2551            let update = proto::UpdateChannels {
2552                channels: vec![channel.to_proto()],
2553                ..Default::default()
2554            };
2555
2556            session.peer.send(connection_id, update.clone())?;
2557        }
2558    }
2559
2560    Ok(())
2561}
2562
2563/// Move a channel to a new parent.
2564async fn move_channel(
2565    request: proto::MoveChannel,
2566    response: Response<proto::MoveChannel>,
2567    session: Session,
2568) -> Result<()> {
2569    let channel_id = ChannelId::from_proto(request.channel_id);
2570    let to = request.to.map(ChannelId::from_proto);
2571
2572    let result = session
2573        .db()
2574        .await
2575        .move_channel(channel_id, to, session.user_id)
2576        .await?;
2577
2578    if let Some(result) = result {
2579        let participants_to_update: HashMap<_, _> = session
2580            .db()
2581            .await
2582            .new_participants_to_notify(to.unwrap_or(channel_id))
2583            .await?
2584            .into_iter()
2585            .collect();
2586
2587        let mut moved_channels: HashSet<ChannelId> = HashSet::default();
2588        for id in result.descendent_ids {
2589            moved_channels.insert(id);
2590        }
2591        moved_channels.insert(channel_id);
2592
2593        let mut participants_to_remove: HashSet<UserId> = HashSet::default();
2594        for participant in result.previous_participants {
2595            if participant.kind == proto::channel_member::Kind::AncestorMember {
2596                if !participants_to_update.contains_key(&participant.user_id) {
2597                    participants_to_remove.insert(participant.user_id);
2598                }
2599            }
2600        }
2601
2602        let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2603
2604        let connection_pool = session.connection_pool().await;
2605        for (user_id, channels) in participants_to_update {
2606            let mut update = build_channels_update(channels, vec![]);
2607            update.delete_channels = moved_channels.clone();
2608            for connection_id in connection_pool.user_connection_ids(user_id) {
2609                session.peer.send(connection_id, update.clone())?;
2610            }
2611        }
2612
2613        for user_id in participants_to_remove {
2614            let update = proto::UpdateChannels {
2615                delete_channels: moved_channels.clone(),
2616                ..Default::default()
2617            };
2618            for connection_id in connection_pool.user_connection_ids(user_id) {
2619                session.peer.send(connection_id, update.clone())?;
2620            }
2621        }
2622    }
2623
2624    response.send(Ack {})?;
2625    Ok(())
2626}
2627
2628/// Get the list of channel members
2629async fn get_channel_members(
2630    request: proto::GetChannelMembers,
2631    response: Response<proto::GetChannelMembers>,
2632    session: Session,
2633) -> Result<()> {
2634    let db = session.db().await;
2635    let channel_id = ChannelId::from_proto(request.channel_id);
2636    let members = db
2637        .get_channel_participant_details(channel_id, session.user_id)
2638        .await?;
2639    response.send(proto::GetChannelMembersResponse { members })?;
2640    Ok(())
2641}
2642
2643/// Accept or decline a channel invitation.
2644async fn respond_to_channel_invite(
2645    request: proto::RespondToChannelInvite,
2646    response: Response<proto::RespondToChannelInvite>,
2647    session: Session,
2648) -> Result<()> {
2649    let db = session.db().await;
2650    let channel_id = ChannelId::from_proto(request.channel_id);
2651    let RespondToChannelInvite {
2652        membership_update,
2653        notifications,
2654    } = db
2655        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2656        .await?;
2657
2658    let connection_pool = session.connection_pool().await;
2659    if let Some(membership_update) = membership_update {
2660        notify_membership_updated(
2661            &connection_pool,
2662            membership_update,
2663            session.user_id,
2664            &session.peer,
2665        );
2666    } else {
2667        let update = proto::UpdateChannels {
2668            remove_channel_invitations: vec![channel_id.to_proto()],
2669            ..Default::default()
2670        };
2671
2672        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2673            session.peer.send(connection_id, update.clone())?;
2674        }
2675    };
2676
2677    send_notifications(&*connection_pool, &session.peer, notifications);
2678
2679    response.send(proto::Ack {})?;
2680
2681    Ok(())
2682}
2683
2684/// Join the channels' room
2685async fn join_channel(
2686    request: proto::JoinChannel,
2687    response: Response<proto::JoinChannel>,
2688    session: Session,
2689) -> Result<()> {
2690    let channel_id = ChannelId::from_proto(request.channel_id);
2691    join_channel_internal(channel_id, Box::new(response), session).await
2692}
2693
/// A response handle that can be answered with a `JoinRoomResponse`.
///
/// `join_channel_internal` is generic over this trait, so it can reply through
/// either a `JoinChannel` or a `JoinRoom` response; both are implemented below
/// and both are answered with the same message type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Spelled out with the concrete type so path resolution picks the
        // inherent `Response::send`, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Spelled out with the concrete type so path resolution picks the
        // inherent `Response::send`, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2707
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Leaves whatever room this session is currently in.
/// 2. Joins the channel's room in the database. This returns the joined room,
///    an optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, builds connection info:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`;
///    - every other role gets a room token with `can_publish: true`.
///    A token error is turned into `None` by `trace_err()?` (presumably after
///    logging), so the join still succeeds, just without LiveKit info.
/// 4. Replies with the joined room. Then pushes the membership update (if any)
///    to the user's connections and broadcasts the room to its participants.
/// 5. Broadcasts the channel's participant list to the channel members and
///    refreshes the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards so they are
    // dropped before `channel_updated` and `update_user_contacts` below, which
    // acquire the connection pool / database again.
    let joined_room = {
        // `leave_room_for_session` acquires `session.db()` itself, so it runs
        // before this block takes its own `db` guard (NOTE(review): confirm the
        // guard is non-reentrant; if so this ordering avoids a self-deadlock).
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join the call but cannot publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joining client before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell every channel member who is now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2788
2789/// Start editing the channel notes
2790async fn join_channel_buffer(
2791    request: proto::JoinChannelBuffer,
2792    response: Response<proto::JoinChannelBuffer>,
2793    session: Session,
2794) -> Result<()> {
2795    let db = session.db().await;
2796    let channel_id = ChannelId::from_proto(request.channel_id);
2797
2798    let open_response = db
2799        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2800        .await?;
2801
2802    let collaborators = open_response.collaborators.clone();
2803    response.send(open_response)?;
2804
2805    let update = UpdateChannelBufferCollaborators {
2806        channel_id: channel_id.to_proto(),
2807        collaborators: collaborators.clone(),
2808    };
2809    channel_buffer_updated(
2810        session.connection_id,
2811        collaborators
2812            .iter()
2813            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2814        &update,
2815        &session.peer,
2816    );
2817
2818    Ok(())
2819}
2820
2821/// Edit the channel notes
2822async fn update_channel_buffer(
2823    request: proto::UpdateChannelBuffer,
2824    session: Session,
2825) -> Result<()> {
2826    let db = session.db().await;
2827    let channel_id = ChannelId::from_proto(request.channel_id);
2828
2829    let (collaborators, non_collaborators, epoch, version) = db
2830        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2831        .await?;
2832
2833    channel_buffer_updated(
2834        session.connection_id,
2835        collaborators,
2836        &proto::UpdateChannelBuffer {
2837            channel_id: channel_id.to_proto(),
2838            operations: request.operations,
2839        },
2840        &session.peer,
2841    );
2842
2843    let pool = &*session.connection_pool().await;
2844
2845    broadcast(
2846        None,
2847        non_collaborators
2848            .iter()
2849            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2850        |peer_id| {
2851            session.peer.send(
2852                peer_id.into(),
2853                proto::UpdateChannels {
2854                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2855                        channel_id: channel_id.to_proto(),
2856                        epoch: epoch as u64,
2857                        version: version.clone(),
2858                    }],
2859                    ..Default::default()
2860                },
2861            )
2862        },
2863    );
2864
2865    Ok(())
2866}
2867
2868/// Rejoin the channel notes after a connection blip
2869async fn rejoin_channel_buffers(
2870    request: proto::RejoinChannelBuffers,
2871    response: Response<proto::RejoinChannelBuffers>,
2872    session: Session,
2873) -> Result<()> {
2874    let db = session.db().await;
2875    let buffers = db
2876        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2877        .await?;
2878
2879    for rejoined_buffer in &buffers {
2880        let collaborators_to_notify = rejoined_buffer
2881            .buffer
2882            .collaborators
2883            .iter()
2884            .filter_map(|c| Some(c.peer_id?.into()));
2885        channel_buffer_updated(
2886            session.connection_id,
2887            collaborators_to_notify,
2888            &proto::UpdateChannelBufferCollaborators {
2889                channel_id: rejoined_buffer.buffer.channel_id,
2890                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2891            },
2892            &session.peer,
2893        );
2894    }
2895
2896    response.send(proto::RejoinChannelBuffersResponse {
2897        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2898    })?;
2899
2900    Ok(())
2901}
2902
2903/// Stop editing the channel notes
2904async fn leave_channel_buffer(
2905    request: proto::LeaveChannelBuffer,
2906    response: Response<proto::LeaveChannelBuffer>,
2907    session: Session,
2908) -> Result<()> {
2909    let db = session.db().await;
2910    let channel_id = ChannelId::from_proto(request.channel_id);
2911
2912    let left_buffer = db
2913        .leave_channel_buffer(channel_id, session.connection_id)
2914        .await?;
2915
2916    response.send(Ack {})?;
2917
2918    channel_buffer_updated(
2919        session.connection_id,
2920        left_buffer.connections,
2921        &proto::UpdateChannelBufferCollaborators {
2922            channel_id: channel_id.to_proto(),
2923            collaborators: left_buffer.collaborators,
2924        },
2925        &session.peer,
2926    );
2927
2928    Ok(())
2929}
2930
2931fn channel_buffer_updated<T: EnvelopedMessage>(
2932    sender_id: ConnectionId,
2933    collaborators: impl IntoIterator<Item = ConnectionId>,
2934    message: &T,
2935    peer: &Peer,
2936) {
2937    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2938        peer.send(peer_id.into(), message.clone())
2939    });
2940}
2941
2942fn send_notifications(
2943    connection_pool: &ConnectionPool,
2944    peer: &Peer,
2945    notifications: db::NotificationBatch,
2946) {
2947    for (user_id, notification) in notifications {
2948        for connection_id in connection_pool.user_connection_ids(user_id) {
2949            if let Err(error) = peer.send(
2950                connection_id,
2951                proto::AddNotification {
2952                    notification: Some(notification.clone()),
2953                },
2954            ) {
2955                tracing::error!(
2956                    "failed to send notification to {:?} {}",
2957                    connection_id,
2958                    error
2959                );
2960            }
2961        }
2962    }
2963}
2964
2965/// Send a message to the channel
2966async fn send_channel_message(
2967    request: proto::SendChannelMessage,
2968    response: Response<proto::SendChannelMessage>,
2969    session: Session,
2970) -> Result<()> {
2971    // Validate the message body.
2972    let body = request.body.trim().to_string();
2973    if body.len() > MAX_MESSAGE_LEN {
2974        return Err(anyhow!("message is too long"))?;
2975    }
2976    if body.is_empty() {
2977        return Err(anyhow!("message can't be blank"))?;
2978    }
2979
2980    // TODO: adjust mentions if body is trimmed
2981
2982    let timestamp = OffsetDateTime::now_utc();
2983    let nonce = request
2984        .nonce
2985        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2986
2987    let channel_id = ChannelId::from_proto(request.channel_id);
2988    let CreatedChannelMessage {
2989        message_id,
2990        participant_connection_ids,
2991        channel_members,
2992        notifications,
2993    } = session
2994        .db()
2995        .await
2996        .create_channel_message(
2997            channel_id,
2998            session.user_id,
2999            &body,
3000            &request.mentions,
3001            timestamp,
3002            nonce.clone().into(),
3003        )
3004        .await?;
3005    let message = proto::ChannelMessage {
3006        sender_id: session.user_id.to_proto(),
3007        id: message_id.to_proto(),
3008        body,
3009        mentions: request.mentions,
3010        timestamp: timestamp.unix_timestamp() as u64,
3011        nonce: Some(nonce),
3012    };
3013    broadcast(
3014        Some(session.connection_id),
3015        participant_connection_ids,
3016        |connection| {
3017            session.peer.send(
3018                connection,
3019                proto::ChannelMessageSent {
3020                    channel_id: channel_id.to_proto(),
3021                    message: Some(message.clone()),
3022                },
3023            )
3024        },
3025    );
3026    response.send(proto::SendChannelMessageResponse {
3027        message: Some(message),
3028    })?;
3029
3030    let pool = &*session.connection_pool().await;
3031    broadcast(
3032        None,
3033        channel_members
3034            .iter()
3035            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3036        |peer_id| {
3037            session.peer.send(
3038                peer_id.into(),
3039                proto::UpdateChannels {
3040                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
3041                        channel_id: channel_id.to_proto(),
3042                        message_id: message_id.to_proto(),
3043                    }],
3044                    ..Default::default()
3045                },
3046            )
3047        },
3048    );
3049    send_notifications(pool, &session.peer, notifications);
3050
3051    Ok(())
3052}
3053
3054/// Delete a channel message
3055async fn remove_channel_message(
3056    request: proto::RemoveChannelMessage,
3057    response: Response<proto::RemoveChannelMessage>,
3058    session: Session,
3059) -> Result<()> {
3060    let channel_id = ChannelId::from_proto(request.channel_id);
3061    let message_id = MessageId::from_proto(request.message_id);
3062    let connection_ids = session
3063        .db()
3064        .await
3065        .remove_channel_message(channel_id, message_id, session.user_id)
3066        .await?;
3067    broadcast(Some(session.connection_id), connection_ids, |connection| {
3068        session.peer.send(connection, request.clone())
3069    });
3070    response.send(proto::Ack {})?;
3071    Ok(())
3072}
3073
3074/// Mark a channel message as read
3075async fn acknowledge_channel_message(
3076    request: proto::AckChannelMessage,
3077    session: Session,
3078) -> Result<()> {
3079    let channel_id = ChannelId::from_proto(request.channel_id);
3080    let message_id = MessageId::from_proto(request.message_id);
3081    let notifications = session
3082        .db()
3083        .await
3084        .observe_channel_message(channel_id, session.user_id, message_id)
3085        .await?;
3086    send_notifications(
3087        &*session.connection_pool().await,
3088        &session.peer,
3089        notifications,
3090    );
3091    Ok(())
3092}
3093
3094/// Mark a buffer version as synced
3095async fn acknowledge_buffer_version(
3096    request: proto::AckBufferOperation,
3097    session: Session,
3098) -> Result<()> {
3099    let buffer_id = BufferId::from_proto(request.buffer_id);
3100    session
3101        .db()
3102        .await
3103        .observe_buffer_version(
3104            buffer_id,
3105            session.user_id,
3106            request.epoch as i32,
3107            &request.version,
3108        )
3109        .await?;
3110    Ok(())
3111}
3112
3113/// Start receiving chat updates for a channel
3114async fn join_channel_chat(
3115    request: proto::JoinChannelChat,
3116    response: Response<proto::JoinChannelChat>,
3117    session: Session,
3118) -> Result<()> {
3119    let channel_id = ChannelId::from_proto(request.channel_id);
3120
3121    let db = session.db().await;
3122    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3123        .await?;
3124    let messages = db
3125        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3126        .await?;
3127    response.send(proto::JoinChannelChatResponse {
3128        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3129        messages,
3130    })?;
3131    Ok(())
3132}
3133
3134/// Stop receiving chat updates for a channel
3135async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3136    let channel_id = ChannelId::from_proto(request.channel_id);
3137    session
3138        .db()
3139        .await
3140        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3141        .await?;
3142    Ok(())
3143}
3144
3145/// Retrieve the chat history for a channel
3146async fn get_channel_messages(
3147    request: proto::GetChannelMessages,
3148    response: Response<proto::GetChannelMessages>,
3149    session: Session,
3150) -> Result<()> {
3151    let channel_id = ChannelId::from_proto(request.channel_id);
3152    let messages = session
3153        .db()
3154        .await
3155        .get_channel_messages(
3156            channel_id,
3157            session.user_id,
3158            MESSAGE_COUNT_PER_PAGE,
3159            Some(MessageId::from_proto(request.before_message_id)),
3160        )
3161        .await?;
3162    response.send(proto::GetChannelMessagesResponse {
3163        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3164        messages,
3165    })?;
3166    Ok(())
3167}
3168
3169/// Retrieve specific chat messages
3170async fn get_channel_messages_by_id(
3171    request: proto::GetChannelMessagesById,
3172    response: Response<proto::GetChannelMessagesById>,
3173    session: Session,
3174) -> Result<()> {
3175    let message_ids = request
3176        .message_ids
3177        .iter()
3178        .map(|id| MessageId::from_proto(*id))
3179        .collect::<Vec<_>>();
3180    let messages = session
3181        .db()
3182        .await
3183        .get_channel_messages_by_id(session.user_id, &message_ids)
3184        .await?;
3185    response.send(proto::GetChannelMessagesResponse {
3186        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3187        messages,
3188    })?;
3189    Ok(())
3190}
3191
3192/// Retrieve the current users notifications
3193async fn get_notifications(
3194    request: proto::GetNotifications,
3195    response: Response<proto::GetNotifications>,
3196    session: Session,
3197) -> Result<()> {
3198    let notifications = session
3199        .db()
3200        .await
3201        .get_notifications(
3202            session.user_id,
3203            NOTIFICATION_COUNT_PER_PAGE,
3204            request
3205                .before_id
3206                .map(|id| db::NotificationId::from_proto(id)),
3207        )
3208        .await?;
3209    response.send(proto::GetNotificationsResponse {
3210        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3211        notifications,
3212    })?;
3213    Ok(())
3214}
3215
3216/// Mark notifications as read
3217async fn mark_notification_as_read(
3218    request: proto::MarkNotificationRead,
3219    response: Response<proto::MarkNotificationRead>,
3220    session: Session,
3221) -> Result<()> {
3222    let database = &session.db().await;
3223    let notifications = database
3224        .mark_notification_as_read_by_id(
3225            session.user_id,
3226            NotificationId::from_proto(request.notification_id),
3227        )
3228        .await?;
3229    send_notifications(
3230        &*session.connection_pool().await,
3231        &session.peer,
3232        notifications,
3233    );
3234    response.send(proto::Ack {})?;
3235    Ok(())
3236}
3237
3238/// Get the current users information
3239async fn get_private_user_info(
3240    _request: proto::GetPrivateUserInfo,
3241    response: Response<proto::GetPrivateUserInfo>,
3242    session: Session,
3243) -> Result<()> {
3244    let db = session.db().await;
3245
3246    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3247    let user = db
3248        .get_user_by_id(session.user_id)
3249        .await?
3250        .ok_or_else(|| anyhow!("user not found"))?;
3251    let flags = db.get_user_flags(session.user_id).await?;
3252
3253    response.send(proto::GetPrivateUserInfoResponse {
3254        metrics_id,
3255        staff: user.admin,
3256        flags,
3257    })?;
3258    Ok(())
3259}
3260
3261fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3262    match message {
3263        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3264        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3265        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3266        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3267        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3268            code: frame.code.into(),
3269            reason: frame.reason,
3270        })),
3271    }
3272}
3273
3274fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3275    match message {
3276        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3277        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3278        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3279        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3280        AxumMessage::Close(frame) => {
3281            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3282                code: frame.code.into(),
3283                reason: frame.reason,
3284            }))
3285        }
3286    }
3287}
3288
3289fn notify_membership_updated(
3290    connection_pool: &ConnectionPool,
3291    result: MembershipUpdated,
3292    user_id: UserId,
3293    peer: &Peer,
3294) {
3295    let mut update = build_channels_update(result.new_channels, vec![]);
3296    update.delete_channels = result
3297        .removed_channels
3298        .into_iter()
3299        .map(|id| id.to_proto())
3300        .collect();
3301    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3302
3303    for connection_id in connection_pool.user_connection_ids(user_id) {
3304        peer.send(connection_id, update.clone()).trace_err();
3305    }
3306}
3307
3308fn build_channels_update(
3309    channels: ChannelsForUser,
3310    channel_invites: Vec<db::Channel>,
3311) -> proto::UpdateChannels {
3312    let mut update = proto::UpdateChannels::default();
3313
3314    for channel in channels.channels {
3315        update.channels.push(channel.to_proto());
3316    }
3317
3318    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3319    update.unseen_channel_messages = channels.channel_messages;
3320
3321    for (channel_id, participants) in channels.channel_participants {
3322        update
3323            .channel_participants
3324            .push(proto::ChannelParticipants {
3325                channel_id: channel_id.to_proto(),
3326                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3327            });
3328    }
3329
3330    for channel in channel_invites {
3331        update.channel_invitations.push(channel.to_proto());
3332    }
3333
3334    update
3335}
3336
3337fn build_initial_contacts_update(
3338    contacts: Vec<db::Contact>,
3339    pool: &ConnectionPool,
3340) -> proto::UpdateContacts {
3341    let mut update = proto::UpdateContacts::default();
3342
3343    for contact in contacts {
3344        match contact {
3345            db::Contact::Accepted { user_id, busy } => {
3346                update.contacts.push(contact_for_user(user_id, busy, &pool));
3347            }
3348            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3349            db::Contact::Incoming { user_id } => {
3350                update
3351                    .incoming_requests
3352                    .push(proto::IncomingContactRequest {
3353                        requester_id: user_id.to_proto(),
3354                    })
3355            }
3356        }
3357    }
3358
3359    update
3360}
3361
3362fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3363    proto::Contact {
3364        user_id: user_id.to_proto(),
3365        online: pool.is_user_online(user_id),
3366        busy,
3367    }
3368}
3369
3370fn room_updated(room: &proto::Room, peer: &Peer) {
3371    broadcast(
3372        None,
3373        room.participants
3374            .iter()
3375            .filter_map(|participant| Some(participant.peer_id?.into())),
3376        |peer_id| {
3377            peer.send(
3378                peer_id.into(),
3379                proto::RoomUpdated {
3380                    room: Some(room.clone()),
3381                },
3382            )
3383        },
3384    );
3385}
3386
3387fn channel_updated(
3388    channel_id: ChannelId,
3389    room: &proto::Room,
3390    channel_members: &[UserId],
3391    peer: &Peer,
3392    pool: &ConnectionPool,
3393) {
3394    let participants = room
3395        .participants
3396        .iter()
3397        .map(|p| p.user_id)
3398        .collect::<Vec<_>>();
3399
3400    broadcast(
3401        None,
3402        channel_members
3403            .iter()
3404            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3405        |peer_id| {
3406            peer.send(
3407                peer_id.into(),
3408                proto::UpdateChannels {
3409                    channel_participants: vec![proto::ChannelParticipants {
3410                        channel_id: channel_id.to_proto(),
3411                        participant_user_ids: participants.clone(),
3412                    }],
3413                    ..Default::default()
3414                },
3415            )
3416        },
3417    );
3418}
3419
3420async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3421    let db = session.db().await;
3422
3423    let contacts = db.get_contacts(user_id).await?;
3424    let busy = db.is_user_busy(user_id).await?;
3425
3426    let pool = session.connection_pool().await;
3427    let updated_contact = contact_for_user(user_id, busy, &pool);
3428    for contact in contacts {
3429        if let db::Contact::Accepted {
3430            user_id: contact_user_id,
3431            ..
3432        } = contact
3433        {
3434            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3435                session
3436                    .peer
3437                    .send(
3438                        contact_conn_id,
3439                        proto::UpdateContacts {
3440                            contacts: vec![updated_contact.clone()],
3441                            remove_contacts: Default::default(),
3442                            incoming_requests: Default::default(),
3443                            remove_incoming_requests: Default::default(),
3444                            outgoing_requests: Default::default(),
3445                            remove_outgoing_requests: Default::default(),
3446                        },
3447                    )
3448                    .trace_err();
3449            }
3450        }
3451    }
3452    Ok(())
3453}
3454
3455async fn leave_room_for_session(session: &Session) -> Result<()> {
3456    let mut contacts_to_update = HashSet::default();
3457
3458    let room_id;
3459    let canceled_calls_to_user_ids;
3460    let live_kit_room;
3461    let delete_live_kit_room;
3462    let room;
3463    let channel_members;
3464    let channel_id;
3465
3466    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3467        contacts_to_update.insert(session.user_id);
3468
3469        for project in left_room.left_projects.values() {
3470            project_left(project, session);
3471        }
3472
3473        room_id = RoomId::from_proto(left_room.room.id);
3474        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3475        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3476        delete_live_kit_room = left_room.deleted;
3477        room = mem::take(&mut left_room.room);
3478        channel_members = mem::take(&mut left_room.channel_members);
3479        channel_id = left_room.channel_id;
3480
3481        room_updated(&room, &session.peer);
3482    } else {
3483        return Ok(());
3484    }
3485
3486    if let Some(channel_id) = channel_id {
3487        channel_updated(
3488            channel_id,
3489            &room,
3490            &channel_members,
3491            &session.peer,
3492            &*session.connection_pool().await,
3493        );
3494    }
3495
3496    {
3497        let pool = session.connection_pool().await;
3498        for canceled_user_id in canceled_calls_to_user_ids {
3499            for connection_id in pool.user_connection_ids(canceled_user_id) {
3500                session
3501                    .peer
3502                    .send(
3503                        connection_id,
3504                        proto::CallCanceled {
3505                            room_id: room_id.to_proto(),
3506                        },
3507                    )
3508                    .trace_err();
3509            }
3510            contacts_to_update.insert(canceled_user_id);
3511        }
3512    }
3513
3514    for contact_user_id in contacts_to_update {
3515        update_user_contacts(contact_user_id, &session).await?;
3516    }
3517
3518    if let Some(live_kit) = session.live_kit_client.as_ref() {
3519        live_kit
3520            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3521            .await
3522            .trace_err();
3523
3524        if delete_live_kit_room {
3525            live_kit.delete_room(live_kit_room).await.trace_err();
3526        }
3527    }
3528
3529    Ok(())
3530}
3531
3532async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3533    let left_channel_buffers = session
3534        .db()
3535        .await
3536        .leave_channel_buffers(session.connection_id)
3537        .await?;
3538
3539    for left_buffer in left_channel_buffers {
3540        channel_buffer_updated(
3541            session.connection_id,
3542            left_buffer.connections,
3543            &proto::UpdateChannelBufferCollaborators {
3544                channel_id: left_buffer.channel_id.to_proto(),
3545                collaborators: left_buffer.collaborators,
3546            },
3547            &session.peer,
3548        );
3549    }
3550
3551    Ok(())
3552}
3553
3554fn project_left(project: &db::LeftProject, session: &Session) {
3555    for connection_id in &project.connection_ids {
3556        if project.host_user_id == session.user_id {
3557            session
3558                .peer
3559                .send(
3560                    *connection_id,
3561                    proto::UnshareProject {
3562                        project_id: project.id.to_proto(),
3563                    },
3564                )
3565                .trace_err();
3566        } else {
3567            session
3568                .peer
3569                .send(
3570                    *connection_id,
3571                    proto::RemoveProjectCollaborator {
3572                        project_id: project.id.to_proto(),
3573                        peer_id: Some(session.connection_id.into()),
3574                    },
3575                )
3576                .trace_err();
3577        }
3578    }
3579}
3580
3581pub trait ResultExt {
3582    type Ok;
3583
3584    fn trace_err(self) -> Option<Self::Ok>;
3585}
3586
3587impl<T, E> ResultExt for Result<T, E>
3588where
3589    E: std::fmt::Debug,
3590{
3591    type Ok = T;
3592
3593    fn trace_err(self) -> Option<T> {
3594        match self {
3595            Ok(value) => Some(value),
3596            Err(error) => {
3597                tracing::error!("{:?}", error);
3598                None
3599            }
3600        }
3601    }
3602}