rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
   7        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
   8        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
   9    },
  10    executor::Executor,
  11    AppState, Error, Result,
  12};
  13use anyhow::anyhow;
  14use async_tungstenite::tungstenite::{
  15    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  16};
  17use axum::{
  18    body::Body,
  19    extract::{
  20        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  21        ConnectInfo, WebSocketUpgrade,
  22    },
  23    headers::{Header, HeaderName},
  24    http::StatusCode,
  25    middleware,
  26    response::IntoResponse,
  27    routing::get,
  28    Extension, Router, TypedHeader,
  29};
  30use collections::{HashMap, HashSet};
  31pub use connection_pool::ConnectionPool;
  32use futures::{
  33    channel::oneshot,
  34    future::{self, BoxFuture},
  35    stream::FuturesUnordered,
  36    FutureExt, SinkExt, StreamExt, TryStreamExt,
  37};
  38use lazy_static::lazy_static;
  39use prometheus::{register_int_gauge, IntGauge};
  40use rpc::{
  41    proto::{
  42        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  43        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  44    },
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46};
  47use serde::{Serialize, Serializer};
  48use std::{
  49    any::TypeId,
  50    fmt,
  51    future::Future,
  52    marker::PhantomData,
  53    mem,
  54    net::SocketAddr,
  55    ops::{Deref, DerefMut},
  56    rc::Rc,
  57    sync::{
  58        atomic::{AtomicBool, Ordering::SeqCst},
  59        Arc,
  60    },
  61    time::{Duration, Instant},
  62};
  63use time::OffsetDateTime;
  64use tokio::sync::{watch, Semaphore};
  65use tower::ServiceBuilder;
  66use tracing::{field, info_span, instrument, Instrument};
  67use util::SemanticVersion;
  68
  69pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  70pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  71
  72const MESSAGE_COUNT_PER_PAGE: usize = 100;
  73const MAX_MESSAGE_LEN: usize = 1024;
  74const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  75
  76lazy_static! {
  77    static ref METRIC_CONNECTIONS: IntGauge =
  78        register_int_gauge!("connections", "number of connections").unwrap();
  79    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  80        "shared_projects",
  81        "number of open projects with one or more guests"
  82    )
  83    .unwrap();
  84}
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
/// Per-connection context; `handle_connection` builds one per connection and
/// passes a clone of it to every message handler.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle. The mutex sits behind an `Arc`, so every clone of this
    /// session shares it and `Session::db` serializes database access for one
    /// connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Copied from `AppState::live_kit_client`; `None` when that is unset.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Not read by any visible code; the leading underscore silences the
    /// unused-field lint.
    _executor: Executor,
}
 114
 115impl Session {
 116    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 117        #[cfg(test)]
 118        tokio::task::yield_now().await;
 119        let guard = self.db.lock().await;
 120        #[cfg(test)]
 121        tokio::task::yield_now().await;
 122        guard
 123    }
 124
 125    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 126        #[cfg(test)]
 127        tokio::task::yield_now().await;
 128        let guard = self.connection_pool.lock();
 129        ConnectionPoolGuard {
 130            guard,
 131            _not_send: PhantomData,
 132        }
 133    }
 134}
 135
 136impl fmt::Debug for Session {
 137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 138        f.debug_struct("Session")
 139            .field("user_id", &self.user_id)
 140            .field("connection_id", &self.connection_id)
 141            .finish()
 142    }
 143}
 144
 145struct DbHandle(Arc<Database>);
 146
 147impl Deref for DbHandle {
 148    type Target = Database;
 149
 150    fn deref(&self) -> &Self::Target {
 151        self.0.as_ref()
 152    }
 153}
 154
/// The RPC server: owns the peer, the shared connection pool, and the table of
/// message handlers keyed by payload `TypeId`.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers inserted by `add_handler`, keyed by `TypeId::of::<M>()` and
    /// looked up per incoming message in `handle_connection`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// `teardown()` sends on this; each connection loop subscribes and stops
    /// when it observes a change.
    teardown: watch::Sender<()>,
}
 164
/// Lock guard over the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (and `!Sync`).
/// Handler futures must be `Send` (see the `Fut` bound on `add_handler`), so a
/// handler that holds this guard across an `.await` fails to compile — the
/// blocking `parking_lot` lock cannot be held while such a handler is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 169
/// Serializable view of the server's peer and connection pool.
///
/// The pool is captured as a live [`ConnectionPoolGuard`], so the pool's lock
/// stays held for as long as the snapshot exists; keep snapshots short-lived.
/// Field order here is the order fields appear in the serialized output.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 176
 177pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 178where
 179    S: Serializer,
 180    T: Deref<Target = U>,
 181    U: Serialize,
 182{
 183    Serialize::serialize(value.deref(), serializer)
 184}
 185
 186impl Server {
 187    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 188        let mut server = Self {
 189            id: parking_lot::Mutex::new(id),
 190            peer: Peer::new(id.0 as u32),
 191            app_state,
 192            executor,
 193            connection_pool: Default::default(),
 194            handlers: Default::default(),
 195            teardown: watch::channel(()).0,
 196        };
 197
 198        server
 199            .add_request_handler(ping)
 200            .add_request_handler(create_room)
 201            .add_request_handler(join_room)
 202            .add_request_handler(rejoin_room)
 203            .add_request_handler(leave_room)
 204            .add_request_handler(set_room_participant_role)
 205            .add_request_handler(call)
 206            .add_request_handler(cancel_call)
 207            .add_message_handler(decline_call)
 208            .add_request_handler(update_participant_location)
 209            .add_request_handler(share_project)
 210            .add_message_handler(unshare_project)
 211            .add_request_handler(join_project)
 212            .add_message_handler(leave_project)
 213            .add_request_handler(update_project)
 214            .add_request_handler(update_worktree)
 215            .add_message_handler(start_language_server)
 216            .add_message_handler(update_language_server)
 217            .add_message_handler(update_diagnostic_summary)
 218            .add_message_handler(update_worktree_settings)
 219            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 220            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 221            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 222            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 223            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 224            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 225            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 226            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 227            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 228            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 229            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 230            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 231            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 232            .add_request_handler(
 233                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 234            )
 235            .add_request_handler(
 236                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 237            )
 238            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 239            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 240            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 241            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 242            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 243            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 244            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 245            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 246            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 247            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 248            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 249            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 250            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 251            .add_message_handler(create_buffer_for_peer)
 252            .add_request_handler(update_buffer)
 253            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 254            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 255            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 256            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 257            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 258            .add_request_handler(get_users)
 259            .add_request_handler(fuzzy_search_users)
 260            .add_request_handler(request_contact)
 261            .add_request_handler(remove_contact)
 262            .add_request_handler(respond_to_contact_request)
 263            .add_request_handler(create_channel)
 264            .add_request_handler(delete_channel)
 265            .add_request_handler(invite_channel_member)
 266            .add_request_handler(remove_channel_member)
 267            .add_request_handler(set_channel_member_role)
 268            .add_request_handler(set_channel_visibility)
 269            .add_request_handler(rename_channel)
 270            .add_request_handler(join_channel_buffer)
 271            .add_request_handler(leave_channel_buffer)
 272            .add_message_handler(update_channel_buffer)
 273            .add_request_handler(rejoin_channel_buffers)
 274            .add_request_handler(get_channel_members)
 275            .add_request_handler(respond_to_channel_invite)
 276            .add_request_handler(join_channel)
 277            .add_request_handler(join_channel_chat)
 278            .add_message_handler(leave_channel_chat)
 279            .add_request_handler(send_channel_message)
 280            .add_request_handler(remove_channel_message)
 281            .add_request_handler(get_channel_messages)
 282            .add_request_handler(get_channel_messages_by_id)
 283            .add_request_handler(get_notifications)
 284            .add_request_handler(mark_notification_as_read)
 285            .add_request_handler(move_channel)
 286            .add_request_handler(follow)
 287            .add_message_handler(unfollow)
 288            .add_message_handler(update_followers)
 289            .add_request_handler(get_private_user_info)
 290            .add_message_handler(acknowledge_channel_message)
 291            .add_message_handler(acknowledge_buffer_version);
 292
 293        Arc::new(server)
 294    }
 295
    /// Spawns a detached background task that cleans up state left behind by
    /// stale servers, then returns `Ok(())` immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT`. It then asks
    /// `stale_server_resource_ids` (for this environment and `server_id`) for
    /// stale room and channel ids, and:
    /// - for each channel, calls `clear_stale_channel_buffer_collaborators` and
    ///   sends the refreshed collaborator list to every connection it returns;
    /// - for each room, calls `clear_stale_room_participants`, then
    ///   `room_updated` / `channel_updated`. It sends `CallCanceled` to every
    ///   connection of users whose calls were canceled and pushes a refreshed
    ///   contact entry for each affected user to their accepted contacts'
    ///   connections. When a LiveKit client is configured and the room has no
    ///   participants left, it deletes the LiveKit room;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Failures inside the task go through `.trace_err()` and the affected step
    /// is skipped (the `if let Some(..)` patterns), so one failure does not abort
    /// the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Users whose contact status may have changed: removed
                            // participants and callees whose calls were canceled.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 447
 448    pub fn teardown(&self) {
 449        self.peer.teardown();
 450        self.connection_pool.lock().reset();
 451        let _ = self.teardown.send(());
 452    }
 453
 454    #[cfg(test)]
 455    pub fn reset(&self, id: ServerId) {
 456        self.teardown();
 457        *self.id.lock() = id;
 458        self.peer.reset(id.0 as u32);
 459    }
 460
 461    #[cfg(test)]
 462    pub fn id(&self) -> ServerId {
 463        *self.id.lock()
 464    }
 465
 466    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 467    where
 468        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 469        Fut: 'static + Send + Future<Output = Result<()>>,
 470        M: EnvelopedMessage,
 471    {
 472        let prev_handler = self.handlers.insert(
 473            TypeId::of::<M>(),
 474            Box::new(move |envelope, session| {
 475                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 476                let span = info_span!(
 477                    "handle message",
 478                    payload_type = envelope.payload_type_name()
 479                );
 480                span.in_scope(|| {
 481                    tracing::info!(
 482                        payload_type = envelope.payload_type_name(),
 483                        "message received"
 484                    );
 485                });
 486                let start_time = Instant::now();
 487                let future = (handler)(*envelope, session);
 488                async move {
 489                    let result = future.await;
 490                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 491                    match result {
 492                        Err(error) => {
 493                            tracing::error!(%error, ?duration_ms, "error handling message")
 494                        }
 495                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 496                    }
 497                }
 498                .instrument(span)
 499                .boxed()
 500            }),
 501        );
 502        if prev_handler.is_some() {
 503            panic!("registered a handler for the same message twice");
 504        }
 505        self
 506    }
 507
 508    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 509    where
 510        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 511        Fut: 'static + Send + Future<Output = Result<()>>,
 512        M: EnvelopedMessage,
 513    {
 514        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 515        self
 516    }
 517
 518    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 519    where
 520        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 521        Fut: Send + Future<Output = Result<()>>,
 522        M: RequestMessage,
 523    {
 524        let handler = Arc::new(handler);
 525        self.add_handler(move |envelope, session| {
 526            let receipt = envelope.receipt();
 527            let handler = handler.clone();
 528            async move {
 529                let peer = session.peer.clone();
 530                let responded = Arc::new(AtomicBool::default());
 531                let response = Response {
 532                    peer: peer.clone(),
 533                    responded: responded.clone(),
 534                    receipt,
 535                };
 536                match (handler)(envelope.payload, response, session).await {
 537                    Ok(()) => {
 538                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 539                            Ok(())
 540                        } else {
 541                            Err(anyhow!("handler did not send a response"))?
 542                        }
 543                    }
 544                    Err(error) => {
 545                        let proto_err = match &error {
 546                            Error::Internal(err) => err.to_proto(),
 547                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 548                        };
 549                        peer.respond_with_error(receipt, proto_err)?;
 550                        Err(error)
 551                    }
 552                }
 553            }
 554        })
 555    }
 556
    /// Returns a future that serves one client connection for `user` until it
    /// closes, its I/O fails, or the server is torn down.
    ///
    /// Setup, in order:
    /// 1. register the connection with the peer;
    /// 2. send `Hello` and report the connection id through `send_connection_id`
    ///    if provided;
    /// 3. on the user's first-ever connection, send `ShowContacts` and mark them
    ///    as connected once;
    /// 4. load contacts, channels and channel invites concurrently, add the
    ///    connection to the pool, and send the initial contact/channel updates;
    /// 5. forward any pending incoming call and call `update_user_contacts`.
    ///
    /// It then dispatches incoming messages to the registered handlers until the
    /// loop ends. After a connection close or an I/O error it calls
    /// `connection_lost`. On teardown it returns `Ok(())` right away without
    /// calling `connection_lost` (`Server::teardown` resets the pool itself).
    ///
    /// NOTE(review): any `?` error after `pool.add_connection` (sending the
    /// initial updates, `incoming_call_for_user`, `update_user_contacts`) returns
    /// before `connection_lost` runs, so the pool entry for this connection is
    /// presumably left behind — confirm whether it is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed before the future starts so a teardown sent afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // The pool lock is held while the initial updates are built and sent.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user.channel_memberships))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            //
            // The semaphore caps in-flight handlers (foreground and background) at 256: a permit
            // is acquired before the next message is read and released when its handler finishes.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown, then I/O completion, then progress on in-flight
                // foreground handlers, then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are dropped (cancelled); detached background
            // handlers are unaffected.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 697
 698    pub async fn invite_code_redeemed(
 699        self: &Arc<Self>,
 700        inviter_id: UserId,
 701        invitee_id: UserId,
 702    ) -> Result<()> {
 703        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 704            if let Some(code) = &user.invite_code {
 705                let pool = self.connection_pool.lock();
 706                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 707                for connection_id in pool.user_connection_ids(inviter_id) {
 708                    self.peer.send(
 709                        connection_id,
 710                        proto::UpdateContacts {
 711                            contacts: vec![invitee_contact.clone()],
 712                            ..Default::default()
 713                        },
 714                    )?;
 715                    self.peer.send(
 716                        connection_id,
 717                        proto::UpdateInviteInfo {
 718                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 719                            count: user.invite_count as u32,
 720                        },
 721                    )?;
 722                }
 723            }
 724        }
 725        Ok(())
 726    }
 727
 728    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 729        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 730            if let Some(invite_code) = &user.invite_code {
 731                let pool = self.connection_pool.lock();
 732                for connection_id in pool.user_connection_ids(user_id) {
 733                    self.peer.send(
 734                        connection_id,
 735                        proto::UpdateInviteInfo {
 736                            url: format!(
 737                                "{}{}",
 738                                self.app_state.config.invite_link_prefix, invite_code
 739                            ),
 740                            count: user.invite_count as u32,
 741                        },
 742                    )?;
 743                }
 744            }
 745        }
 746        Ok(())
 747    }
 748
 749    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 750        ServerSnapshot {
 751            connection_pool: ConnectionPoolGuard {
 752                guard: self.connection_pool.lock(),
 753                _not_send: PhantomData,
 754            },
 755            peer: &self.peer,
 756        }
 757    }
 758}
 759
 760impl<'a> Deref for ConnectionPoolGuard<'a> {
 761    type Target = ConnectionPool;
 762
 763    fn deref(&self) -> &Self::Target {
 764        &*self.guard
 765    }
 766}
 767
 768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 769    fn deref_mut(&mut self) -> &mut Self::Target {
 770        &mut *self.guard
 771    }
 772}
 773
 774impl<'a> Drop for ConnectionPoolGuard<'a> {
 775    fn drop(&mut self) {
 776        #[cfg(test)]
 777        self.check_invariants();
 778    }
 779}
 780
 781fn broadcast<F>(
 782    sender_id: Option<ConnectionId>,
 783    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 784    mut f: F,
 785) where
 786    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 787{
 788    for receiver_id in receiver_ids {
 789        if Some(receiver_id) != sender_id {
 790            if let Err(error) = f(receiver_id) {
 791                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 792            }
 793        }
 794    }
 795}
 796
// Names of the custom request headers read by the `/rpc` websocket handler.
lazy_static! {
    // RPC wire-protocol version; `handle_websocket_request` rejects values that
    // differ from `rpc::PROTOCOL_VERSION`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
    // Client application version, decoded as a `SemanticVersion` by `AppVersionHeader`.
    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
}
 801
 802pub struct ProtocolVersion(u32);
 803
 804impl Header for ProtocolVersion {
 805    fn name() -> &'static HeaderName {
 806        &ZED_PROTOCOL_VERSION
 807    }
 808
 809    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 810    where
 811        Self: Sized,
 812        I: Iterator<Item = &'i axum::http::HeaderValue>,
 813    {
 814        let version = values
 815            .next()
 816            .ok_or_else(axum::headers::Error::invalid)?
 817            .to_str()
 818            .map_err(|_| axum::headers::Error::invalid())?
 819            .parse()
 820            .map_err(|_| axum::headers::Error::invalid())?;
 821        Ok(Self(version))
 822    }
 823
 824    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 825        values.extend([self.0.to_string().parse().unwrap()]);
 826    }
 827}
 828
 829pub struct AppVersionHeader(SemanticVersion);
 830impl Header for AppVersionHeader {
 831    fn name() -> &'static HeaderName {
 832        &ZED_APP_VERSION
 833    }
 834
 835    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 836    where
 837        Self: Sized,
 838        I: Iterator<Item = &'i axum::http::HeaderValue>,
 839    {
 840        let version = values
 841            .next()
 842            .ok_or_else(axum::headers::Error::invalid)?
 843            .to_str()
 844            .map_err(|_| axum::headers::Error::invalid())?
 845            .parse()
 846            .map_err(|_| axum::headers::Error::invalid())?;
 847        Ok(Self(version))
 848    }
 849
 850    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 851        values.extend([self.0.to_string().parse().unwrap()]);
 852    }
 853}
 854
 855pub fn routes(server: Arc<Server>) -> Router<Body> {
 856    Router::new()
 857        .route("/rpc", get(handle_websocket_request))
 858        .layer(
 859            ServiceBuilder::new()
 860                .layer(Extension(server.app_state.clone()))
 861                .layer(middleware::from_fn(auth::validate_header)),
 862        )
 863        .route("/metrics", get(handle_metrics))
 864        .layer(Extension(server))
 865}
 866
 867pub async fn handle_websocket_request(
 868    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 869    _app_version_header: Option<TypedHeader<AppVersionHeader>>,
 870    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 871    Extension(server): Extension<Arc<Server>>,
 872    Extension(user): Extension<User>,
 873    Extension(impersonator): Extension<Impersonator>,
 874    ws: WebSocketUpgrade,
 875) -> axum::response::Response {
 876    if protocol_version != rpc::PROTOCOL_VERSION {
 877        return (
 878            StatusCode::UPGRADE_REQUIRED,
 879            "client must be upgraded".to_string(),
 880        )
 881            .into_response();
 882    }
 883
 884    let socket_address = socket_address.to_string();
 885    ws.on_upgrade(move |socket| {
 886        use util::ResultExt;
 887        let socket = socket
 888            .map_ok(to_tungstenite_message)
 889            .err_into()
 890            .with(|message| async move { Ok(to_axum_message(message)) });
 891        let connection = Connection::new(Box::pin(socket));
 892        async move {
 893            server
 894                .handle_connection(
 895                    connection,
 896                    socket_address,
 897                    user,
 898                    impersonator.0,
 899                    None,
 900                    Executor::Production,
 901                )
 902                .await
 903                .log_err();
 904        }
 905    })
 906}
 907
 908pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 909    let connections = server
 910        .connection_pool
 911        .lock()
 912        .connections()
 913        .filter(|connection| !connection.admin)
 914        .count();
 915
 916    METRIC_CONNECTIONS.set(connections as _);
 917
 918    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 919    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 920
 921    let encoder = prometheus::TextEncoder::new();
 922    let metric_families = prometheus::gather();
 923    let encoded_metrics = encoder
 924        .encode_to_string(&metric_families)
 925        .map_err(|err| anyhow!("{}", err))?;
 926    Ok(encoded_metrics)
 927}
 928
/// Cleans up after a client connection drops.
///
/// Immediately disconnects the peer and removes the connection from the
/// connection pool; a pool error is propagated. It then records the lost
/// connection in the database, logging any error rather than propagating it.
/// After that it waits for whichever happens first:
///
/// * `RECONNECT_TIMEOUT` elapses on `executor`: the session leaves its room and
///   channel buffers (errors logged). If the user has no other online
///   connection, any pending call is declined. Contacts are then updated.
/// * `teardown` changes: the cleanup above is skipped.
///
/// `select_biased!` polls the timeout branch first, so it wins when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a failure here is logged, not returned.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The reconnect timeout elapsed before teardown: release what this session held.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call once no connection for this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown was signalled first: skip the per-user cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 974
 975/// Acknowledges a ping from a client, used to keep the connection alive.
 976async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 977    response.send(proto::Ack {})?;
 978    Ok(())
 979}
 980
 981/// Creates a new room for calling (outside of channels)
 982async fn create_room(
 983    _request: proto::CreateRoom,
 984    response: Response<proto::CreateRoom>,
 985    session: Session,
 986) -> Result<()> {
 987    let live_kit_room = nanoid::nanoid!(30);
 988
 989    let live_kit_connection_info = {
 990        let live_kit_room = live_kit_room.clone();
 991        let live_kit = session.live_kit_client.as_ref();
 992
 993        util::async_maybe!({
 994            let live_kit = live_kit?;
 995
 996            let token = live_kit
 997                .room_token(&live_kit_room, &session.user_id.to_string())
 998                .trace_err()?;
 999
1000            Some(proto::LiveKitConnectionInfo {
1001                server_url: live_kit.url().into(),
1002                token,
1003                can_publish: true,
1004            })
1005        })
1006    }
1007    .await;
1008
1009    let room = session
1010        .db()
1011        .await
1012        .create_room(
1013            session.user_id,
1014            session.connection_id,
1015            &live_kit_room,
1016            &session.zed_environment,
1017        )
1018        .await?;
1019
1020    response.send(proto::CreateRoomResponse {
1021        room: Some(room.clone()),
1022        live_kit_connection_info,
1023    })?;
1024
1025    update_user_contacts(session.user_id, &session).await?;
1026    Ok(())
1027}
1028
1029/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1030async fn join_room(
1031    request: proto::JoinRoom,
1032    response: Response<proto::JoinRoom>,
1033    session: Session,
1034) -> Result<()> {
1035    let room_id = RoomId::from_proto(request.id);
1036
1037    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1038
1039    if let Some(channel_id) = channel_id {
1040        return join_channel_internal(channel_id, Box::new(response), session).await;
1041    }
1042
1043    let joined_room = {
1044        let room = session
1045            .db()
1046            .await
1047            .join_room(
1048                room_id,
1049                session.user_id,
1050                session.connection_id,
1051                session.zed_environment.as_ref(),
1052            )
1053            .await?;
1054        room_updated(&room.room, &session.peer);
1055        room.into_inner()
1056    };
1057
1058    for connection_id in session
1059        .connection_pool()
1060        .await
1061        .user_connection_ids(session.user_id)
1062    {
1063        session
1064            .peer
1065            .send(
1066                connection_id,
1067                proto::CallCanceled {
1068                    room_id: room_id.to_proto(),
1069                },
1070            )
1071            .trace_err();
1072    }
1073
1074    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1075        if let Some(token) = live_kit
1076            .room_token(
1077                &joined_room.room.live_kit_room,
1078                &session.user_id.to_string(),
1079            )
1080            .trace_err()
1081        {
1082            Some(proto::LiveKitConnectionInfo {
1083                server_url: live_kit.url().into(),
1084                token,
1085                can_publish: true,
1086            })
1087        } else {
1088            None
1089        }
1090    } else {
1091        None
1092    };
1093
1094    response.send(proto::JoinRoomResponse {
1095        room: Some(joined_room.room),
1096        channel_id: None,
1097        live_kit_connection_info,
1098    })?;
1099
1100    update_user_contacts(session.user_id, &session).await?;
1101    Ok(())
1102}
1103
1104/// Rejoin room is used to reconnect to a room after connection errors.
1105async fn rejoin_room(
1106    request: proto::RejoinRoom,
1107    response: Response<proto::RejoinRoom>,
1108    session: Session,
1109) -> Result<()> {
1110    let room;
1111    let channel_id;
1112    let channel_members;
1113    {
1114        let mut rejoined_room = session
1115            .db()
1116            .await
1117            .rejoin_room(request, session.user_id, session.connection_id)
1118            .await?;
1119
1120        response.send(proto::RejoinRoomResponse {
1121            room: Some(rejoined_room.room.clone()),
1122            reshared_projects: rejoined_room
1123                .reshared_projects
1124                .iter()
1125                .map(|project| proto::ResharedProject {
1126                    id: project.id.to_proto(),
1127                    collaborators: project
1128                        .collaborators
1129                        .iter()
1130                        .map(|collaborator| collaborator.to_proto())
1131                        .collect(),
1132                })
1133                .collect(),
1134            rejoined_projects: rejoined_room
1135                .rejoined_projects
1136                .iter()
1137                .map(|rejoined_project| proto::RejoinedProject {
1138                    id: rejoined_project.id.to_proto(),
1139                    worktrees: rejoined_project
1140                        .worktrees
1141                        .iter()
1142                        .map(|worktree| proto::WorktreeMetadata {
1143                            id: worktree.id,
1144                            root_name: worktree.root_name.clone(),
1145                            visible: worktree.visible,
1146                            abs_path: worktree.abs_path.clone(),
1147                        })
1148                        .collect(),
1149                    collaborators: rejoined_project
1150                        .collaborators
1151                        .iter()
1152                        .map(|collaborator| collaborator.to_proto())
1153                        .collect(),
1154                    language_servers: rejoined_project.language_servers.clone(),
1155                })
1156                .collect(),
1157        })?;
1158        room_updated(&rejoined_room.room, &session.peer);
1159
1160        for project in &rejoined_room.reshared_projects {
1161            for collaborator in &project.collaborators {
1162                session
1163                    .peer
1164                    .send(
1165                        collaborator.connection_id,
1166                        proto::UpdateProjectCollaborator {
1167                            project_id: project.id.to_proto(),
1168                            old_peer_id: Some(project.old_connection_id.into()),
1169                            new_peer_id: Some(session.connection_id.into()),
1170                        },
1171                    )
1172                    .trace_err();
1173            }
1174
1175            broadcast(
1176                Some(session.connection_id),
1177                project
1178                    .collaborators
1179                    .iter()
1180                    .map(|collaborator| collaborator.connection_id),
1181                |connection_id| {
1182                    session.peer.forward_send(
1183                        session.connection_id,
1184                        connection_id,
1185                        proto::UpdateProject {
1186                            project_id: project.id.to_proto(),
1187                            worktrees: project.worktrees.clone(),
1188                        },
1189                    )
1190                },
1191            );
1192        }
1193
1194        for project in &rejoined_room.rejoined_projects {
1195            for collaborator in &project.collaborators {
1196                session
1197                    .peer
1198                    .send(
1199                        collaborator.connection_id,
1200                        proto::UpdateProjectCollaborator {
1201                            project_id: project.id.to_proto(),
1202                            old_peer_id: Some(project.old_connection_id.into()),
1203                            new_peer_id: Some(session.connection_id.into()),
1204                        },
1205                    )
1206                    .trace_err();
1207            }
1208        }
1209
1210        for project in &mut rejoined_room.rejoined_projects {
1211            for worktree in mem::take(&mut project.worktrees) {
1212                #[cfg(any(test, feature = "test-support"))]
1213                const MAX_CHUNK_SIZE: usize = 2;
1214                #[cfg(not(any(test, feature = "test-support")))]
1215                const MAX_CHUNK_SIZE: usize = 256;
1216
1217                // Stream this worktree's entries.
1218                let message = proto::UpdateWorktree {
1219                    project_id: project.id.to_proto(),
1220                    worktree_id: worktree.id,
1221                    abs_path: worktree.abs_path.clone(),
1222                    root_name: worktree.root_name,
1223                    updated_entries: worktree.updated_entries,
1224                    removed_entries: worktree.removed_entries,
1225                    scan_id: worktree.scan_id,
1226                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1227                    updated_repositories: worktree.updated_repositories,
1228                    removed_repositories: worktree.removed_repositories,
1229                };
1230                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1231                    session.peer.send(session.connection_id, update.clone())?;
1232                }
1233
1234                // Stream this worktree's diagnostics.
1235                for summary in worktree.diagnostic_summaries {
1236                    session.peer.send(
1237                        session.connection_id,
1238                        proto::UpdateDiagnosticSummary {
1239                            project_id: project.id.to_proto(),
1240                            worktree_id: worktree.id,
1241                            summary: Some(summary),
1242                        },
1243                    )?;
1244                }
1245
1246                for settings_file in worktree.settings_files {
1247                    session.peer.send(
1248                        session.connection_id,
1249                        proto::UpdateWorktreeSettings {
1250                            project_id: project.id.to_proto(),
1251                            worktree_id: worktree.id,
1252                            path: settings_file.path,
1253                            content: Some(settings_file.content),
1254                        },
1255                    )?;
1256                }
1257            }
1258
1259            for language_server in &project.language_servers {
1260                session.peer.send(
1261                    session.connection_id,
1262                    proto::UpdateLanguageServer {
1263                        project_id: project.id.to_proto(),
1264                        language_server_id: language_server.id,
1265                        variant: Some(
1266                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1267                                proto::LspDiskBasedDiagnosticsUpdated {},
1268                            ),
1269                        ),
1270                    },
1271                )?;
1272            }
1273        }
1274
1275        let rejoined_room = rejoined_room.into_inner();
1276
1277        room = rejoined_room.room;
1278        channel_id = rejoined_room.channel_id;
1279        channel_members = rejoined_room.channel_members;
1280    }
1281
1282    if let Some(channel_id) = channel_id {
1283        channel_updated(
1284            channel_id,
1285            &room,
1286            &channel_members,
1287            &session.peer,
1288            &*session.connection_pool().await,
1289        );
1290    }
1291
1292    update_user_contacts(session.user_id, &session).await?;
1293    Ok(())
1294}
1295
1296/// leave room disconnects from the room.
1297async fn leave_room(
1298    _: proto::LeaveRoom,
1299    response: Response<proto::LeaveRoom>,
1300    session: Session,
1301) -> Result<()> {
1302    leave_room_for_session(&session).await?;
1303    response.send(proto::Ack {})?;
1304    Ok(())
1305}
1306
1307/// Updates the permissions of someone else in the room.
1308async fn set_room_participant_role(
1309    request: proto::SetRoomParticipantRole,
1310    response: Response<proto::SetRoomParticipantRole>,
1311    session: Session,
1312) -> Result<()> {
1313    let (live_kit_room, can_publish) = {
1314        let room = session
1315            .db()
1316            .await
1317            .set_room_participant_role(
1318                session.user_id,
1319                RoomId::from_proto(request.room_id),
1320                UserId::from_proto(request.user_id),
1321                ChannelRole::from(request.role()),
1322            )
1323            .await?;
1324
1325        let live_kit_room = room.live_kit_room.clone();
1326        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1327        room_updated(&room, &session.peer);
1328        (live_kit_room, can_publish)
1329    };
1330
1331    if let Some(live_kit) = session.live_kit_client.as_ref() {
1332        live_kit
1333            .update_participant(
1334                live_kit_room.clone(),
1335                request.user_id.to_string(),
1336                live_kit_server::proto::ParticipantPermission {
1337                    can_subscribe: true,
1338                    can_publish,
1339                    can_publish_data: can_publish,
1340                    hidden: false,
1341                    recorder: false,
1342                },
1343            )
1344            .await
1345            .trace_err();
1346    }
1347
1348    response.send(proto::Ack {})?;
1349    Ok(())
1350}
1351
1352/// Call someone else into the current room
1353async fn call(
1354    request: proto::Call,
1355    response: Response<proto::Call>,
1356    session: Session,
1357) -> Result<()> {
1358    let room_id = RoomId::from_proto(request.room_id);
1359    let calling_user_id = session.user_id;
1360    let calling_connection_id = session.connection_id;
1361    let called_user_id = UserId::from_proto(request.called_user_id);
1362    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1363    if !session
1364        .db()
1365        .await
1366        .has_contact(calling_user_id, called_user_id)
1367        .await?
1368    {
1369        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1370    }
1371
1372    let incoming_call = {
1373        let (room, incoming_call) = &mut *session
1374            .db()
1375            .await
1376            .call(
1377                room_id,
1378                calling_user_id,
1379                calling_connection_id,
1380                called_user_id,
1381                initial_project_id,
1382            )
1383            .await?;
1384        room_updated(&room, &session.peer);
1385        mem::take(incoming_call)
1386    };
1387    update_user_contacts(called_user_id, &session).await?;
1388
1389    let mut calls = session
1390        .connection_pool()
1391        .await
1392        .user_connection_ids(called_user_id)
1393        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1394        .collect::<FuturesUnordered<_>>();
1395
1396    while let Some(call_response) = calls.next().await {
1397        match call_response.as_ref() {
1398            Ok(_) => {
1399                response.send(proto::Ack {})?;
1400                return Ok(());
1401            }
1402            Err(_) => {
1403                call_response.trace_err();
1404            }
1405        }
1406    }
1407
1408    {
1409        let room = session
1410            .db()
1411            .await
1412            .call_failed(room_id, called_user_id)
1413            .await?;
1414        room_updated(&room, &session.peer);
1415    }
1416    update_user_contacts(called_user_id, &session).await?;
1417
1418    Err(anyhow!("failed to ring user"))?
1419}
1420
1421/// Cancel an outgoing call.
1422async fn cancel_call(
1423    request: proto::CancelCall,
1424    response: Response<proto::CancelCall>,
1425    session: Session,
1426) -> Result<()> {
1427    let called_user_id = UserId::from_proto(request.called_user_id);
1428    let room_id = RoomId::from_proto(request.room_id);
1429    {
1430        let room = session
1431            .db()
1432            .await
1433            .cancel_call(room_id, session.connection_id, called_user_id)
1434            .await?;
1435        room_updated(&room, &session.peer);
1436    }
1437
1438    for connection_id in session
1439        .connection_pool()
1440        .await
1441        .user_connection_ids(called_user_id)
1442    {
1443        session
1444            .peer
1445            .send(
1446                connection_id,
1447                proto::CallCanceled {
1448                    room_id: room_id.to_proto(),
1449                },
1450            )
1451            .trace_err();
1452    }
1453    response.send(proto::Ack {})?;
1454
1455    update_user_contacts(called_user_id, &session).await?;
1456    Ok(())
1457}
1458
1459/// Decline an incoming call.
1460async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1461    let room_id = RoomId::from_proto(message.room_id);
1462    {
1463        let room = session
1464            .db()
1465            .await
1466            .decline_call(Some(room_id), session.user_id)
1467            .await?
1468            .ok_or_else(|| anyhow!("failed to decline call"))?;
1469        room_updated(&room, &session.peer);
1470    }
1471
1472    for connection_id in session
1473        .connection_pool()
1474        .await
1475        .user_connection_ids(session.user_id)
1476    {
1477        session
1478            .peer
1479            .send(
1480                connection_id,
1481                proto::CallCanceled {
1482                    room_id: room_id.to_proto(),
1483                },
1484            )
1485            .trace_err();
1486    }
1487    update_user_contacts(session.user_id, &session).await?;
1488    Ok(())
1489}
1490
1491/// Updates other participants in the room with your current location.
1492async fn update_participant_location(
1493    request: proto::UpdateParticipantLocation,
1494    response: Response<proto::UpdateParticipantLocation>,
1495    session: Session,
1496) -> Result<()> {
1497    let room_id = RoomId::from_proto(request.room_id);
1498    let location = request
1499        .location
1500        .ok_or_else(|| anyhow!("invalid location"))?;
1501
1502    let db = session.db().await;
1503    let room = db
1504        .update_room_participant_location(room_id, session.connection_id, location)
1505        .await?;
1506
1507    room_updated(&room, &session.peer);
1508    response.send(proto::Ack {})?;
1509    Ok(())
1510}
1511
1512/// Share a project into the room.
1513async fn share_project(
1514    request: proto::ShareProject,
1515    response: Response<proto::ShareProject>,
1516    session: Session,
1517) -> Result<()> {
1518    let (project_id, room) = &*session
1519        .db()
1520        .await
1521        .share_project(
1522            RoomId::from_proto(request.room_id),
1523            session.connection_id,
1524            &request.worktrees,
1525        )
1526        .await?;
1527    response.send(proto::ShareProjectResponse {
1528        project_id: project_id.to_proto(),
1529    })?;
1530    room_updated(&room, &session.peer);
1531
1532    Ok(())
1533}
1534
1535/// Unshare a project from the room.
1536async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1537    let project_id = ProjectId::from_proto(message.project_id);
1538
1539    let (room, guest_connection_ids) = &*session
1540        .db()
1541        .await
1542        .unshare_project(project_id, session.connection_id)
1543        .await?;
1544
1545    broadcast(
1546        Some(session.connection_id),
1547        guest_connection_ids.iter().copied(),
1548        |conn_id| session.peer.send(conn_id, message.clone()),
1549    );
1550    room_updated(&room, &session.peer);
1551
1552    Ok(())
1553}
1554
1555/// Join someone elses shared project.
1556async fn join_project(
1557    request: proto::JoinProject,
1558    response: Response<proto::JoinProject>,
1559    session: Session,
1560) -> Result<()> {
1561    let project_id = ProjectId::from_proto(request.project_id);
1562    let guest_user_id = session.user_id;
1563
1564    tracing::info!(%project_id, "join project");
1565
1566    let (project, replica_id) = &mut *session
1567        .db()
1568        .await
1569        .join_project(project_id, session.connection_id)
1570        .await?;
1571
1572    let collaborators = project
1573        .collaborators
1574        .iter()
1575        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1576        .map(|collaborator| collaborator.to_proto())
1577        .collect::<Vec<_>>();
1578
1579    let worktrees = project
1580        .worktrees
1581        .iter()
1582        .map(|(id, worktree)| proto::WorktreeMetadata {
1583            id: *id,
1584            root_name: worktree.root_name.clone(),
1585            visible: worktree.visible,
1586            abs_path: worktree.abs_path.clone(),
1587        })
1588        .collect::<Vec<_>>();
1589
1590    for collaborator in &collaborators {
1591        session
1592            .peer
1593            .send(
1594                collaborator.peer_id.unwrap().into(),
1595                proto::AddProjectCollaborator {
1596                    project_id: project_id.to_proto(),
1597                    collaborator: Some(proto::Collaborator {
1598                        peer_id: Some(session.connection_id.into()),
1599                        replica_id: replica_id.0 as u32,
1600                        user_id: guest_user_id.to_proto(),
1601                    }),
1602                },
1603            )
1604            .trace_err();
1605    }
1606
1607    // First, we send the metadata associated with each worktree.
1608    response.send(proto::JoinProjectResponse {
1609        worktrees: worktrees.clone(),
1610        replica_id: replica_id.0 as u32,
1611        collaborators: collaborators.clone(),
1612        language_servers: project.language_servers.clone(),
1613    })?;
1614
1615    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1616        #[cfg(any(test, feature = "test-support"))]
1617        const MAX_CHUNK_SIZE: usize = 2;
1618        #[cfg(not(any(test, feature = "test-support")))]
1619        const MAX_CHUNK_SIZE: usize = 256;
1620
1621        // Stream this worktree's entries.
1622        let message = proto::UpdateWorktree {
1623            project_id: project_id.to_proto(),
1624            worktree_id,
1625            abs_path: worktree.abs_path.clone(),
1626            root_name: worktree.root_name,
1627            updated_entries: worktree.entries,
1628            removed_entries: Default::default(),
1629            scan_id: worktree.scan_id,
1630            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1631            updated_repositories: worktree.repository_entries.into_values().collect(),
1632            removed_repositories: Default::default(),
1633        };
1634        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1635            session.peer.send(session.connection_id, update.clone())?;
1636        }
1637
1638        // Stream this worktree's diagnostics.
1639        for summary in worktree.diagnostic_summaries {
1640            session.peer.send(
1641                session.connection_id,
1642                proto::UpdateDiagnosticSummary {
1643                    project_id: project_id.to_proto(),
1644                    worktree_id: worktree.id,
1645                    summary: Some(summary),
1646                },
1647            )?;
1648        }
1649
1650        for settings_file in worktree.settings_files {
1651            session.peer.send(
1652                session.connection_id,
1653                proto::UpdateWorktreeSettings {
1654                    project_id: project_id.to_proto(),
1655                    worktree_id: worktree.id,
1656                    path: settings_file.path,
1657                    content: Some(settings_file.content),
1658                },
1659            )?;
1660        }
1661    }
1662
1663    for language_server in &project.language_servers {
1664        session.peer.send(
1665            session.connection_id,
1666            proto::UpdateLanguageServer {
1667                project_id: project_id.to_proto(),
1668                language_server_id: language_server.id,
1669                variant: Some(
1670                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1671                        proto::LspDiskBasedDiagnosticsUpdated {},
1672                    ),
1673                ),
1674            },
1675        )?;
1676    }
1677
1678    Ok(())
1679}
1680
1681/// Leave someone elses shared project.
1682async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1683    let sender_id = session.connection_id;
1684    let project_id = ProjectId::from_proto(request.project_id);
1685
1686    let (room, project) = &*session
1687        .db()
1688        .await
1689        .leave_project(project_id, sender_id)
1690        .await?;
1691    tracing::info!(
1692        %project_id,
1693        host_user_id = %project.host_user_id,
1694        host_connection_id = ?project.host_connection_id,
1695        "leave project"
1696    );
1697
1698    project_left(&project, &session);
1699    room_updated(&room, &session.peer);
1700
1701    Ok(())
1702}
1703
1704/// Updates other participants with changes to the project
1705async fn update_project(
1706    request: proto::UpdateProject,
1707    response: Response<proto::UpdateProject>,
1708    session: Session,
1709) -> Result<()> {
1710    let project_id = ProjectId::from_proto(request.project_id);
1711    let (room, guest_connection_ids) = &*session
1712        .db()
1713        .await
1714        .update_project(project_id, session.connection_id, &request.worktrees)
1715        .await?;
1716    broadcast(
1717        Some(session.connection_id),
1718        guest_connection_ids.iter().copied(),
1719        |connection_id| {
1720            session
1721                .peer
1722                .forward_send(session.connection_id, connection_id, request.clone())
1723        },
1724    );
1725    room_updated(&room, &session.peer);
1726    response.send(proto::Ack {})?;
1727
1728    Ok(())
1729}
1730
1731/// Updates other participants with changes to the worktree
1732async fn update_worktree(
1733    request: proto::UpdateWorktree,
1734    response: Response<proto::UpdateWorktree>,
1735    session: Session,
1736) -> Result<()> {
1737    let guest_connection_ids = session
1738        .db()
1739        .await
1740        .update_worktree(&request, session.connection_id)
1741        .await?;
1742
1743    broadcast(
1744        Some(session.connection_id),
1745        guest_connection_ids.iter().copied(),
1746        |connection_id| {
1747            session
1748                .peer
1749                .forward_send(session.connection_id, connection_id, request.clone())
1750        },
1751    );
1752    response.send(proto::Ack {})?;
1753    Ok(())
1754}
1755
1756/// Updates other participants with changes to the diagnostics
1757async fn update_diagnostic_summary(
1758    message: proto::UpdateDiagnosticSummary,
1759    session: Session,
1760) -> Result<()> {
1761    let guest_connection_ids = session
1762        .db()
1763        .await
1764        .update_diagnostic_summary(&message, session.connection_id)
1765        .await?;
1766
1767    broadcast(
1768        Some(session.connection_id),
1769        guest_connection_ids.iter().copied(),
1770        |connection_id| {
1771            session
1772                .peer
1773                .forward_send(session.connection_id, connection_id, message.clone())
1774        },
1775    );
1776
1777    Ok(())
1778}
1779
1780/// Updates other participants with changes to the worktree settings
1781async fn update_worktree_settings(
1782    message: proto::UpdateWorktreeSettings,
1783    session: Session,
1784) -> Result<()> {
1785    let guest_connection_ids = session
1786        .db()
1787        .await
1788        .update_worktree_settings(&message, session.connection_id)
1789        .await?;
1790
1791    broadcast(
1792        Some(session.connection_id),
1793        guest_connection_ids.iter().copied(),
1794        |connection_id| {
1795            session
1796                .peer
1797                .forward_send(session.connection_id, connection_id, message.clone())
1798        },
1799    );
1800
1801    Ok(())
1802}
1803
1804/// Notify other participants that a  language server has started.
1805async fn start_language_server(
1806    request: proto::StartLanguageServer,
1807    session: Session,
1808) -> Result<()> {
1809    let guest_connection_ids = session
1810        .db()
1811        .await
1812        .start_language_server(&request, session.connection_id)
1813        .await?;
1814
1815    broadcast(
1816        Some(session.connection_id),
1817        guest_connection_ids.iter().copied(),
1818        |connection_id| {
1819            session
1820                .peer
1821                .forward_send(session.connection_id, connection_id, request.clone())
1822        },
1823    );
1824    Ok(())
1825}
1826
1827/// Notify other participants that a language server has changed.
1828async fn update_language_server(
1829    request: proto::UpdateLanguageServer,
1830    session: Session,
1831) -> Result<()> {
1832    let project_id = ProjectId::from_proto(request.project_id);
1833    let project_connection_ids = session
1834        .db()
1835        .await
1836        .project_connection_ids(project_id, session.connection_id)
1837        .await?;
1838    broadcast(
1839        Some(session.connection_id),
1840        project_connection_ids.iter().copied(),
1841        |connection_id| {
1842            session
1843                .peer
1844                .forward_send(session.connection_id, connection_id, request.clone())
1845        },
1846    );
1847    Ok(())
1848}
1849
1850/// forward a project request to the host. These requests should be read only
1851/// as guests are allowed to send them.
1852async fn forward_read_only_project_request<T>(
1853    request: T,
1854    response: Response<T>,
1855    session: Session,
1856) -> Result<()>
1857where
1858    T: EntityMessage + RequestMessage,
1859{
1860    let project_id = ProjectId::from_proto(request.remote_entity_id());
1861    let host_connection_id = session
1862        .db()
1863        .await
1864        .host_for_read_only_project_request(project_id, session.connection_id)
1865        .await?;
1866    let payload = session
1867        .peer
1868        .forward_request(session.connection_id, host_connection_id, request)
1869        .await?;
1870    response.send(payload)?;
1871    Ok(())
1872}
1873
1874/// forward a project request to the host. These requests are disallowed
1875/// for guests.
1876async fn forward_mutating_project_request<T>(
1877    request: T,
1878    response: Response<T>,
1879    session: Session,
1880) -> Result<()>
1881where
1882    T: EntityMessage + RequestMessage,
1883{
1884    let project_id = ProjectId::from_proto(request.remote_entity_id());
1885    let host_connection_id = session
1886        .db()
1887        .await
1888        .host_for_mutating_project_request(project_id, session.connection_id)
1889        .await?;
1890    let payload = session
1891        .peer
1892        .forward_request(session.connection_id, host_connection_id, request)
1893        .await?;
1894    response.send(payload)?;
1895    Ok(())
1896}
1897
1898/// Notify other participants that a new buffer has been created
1899async fn create_buffer_for_peer(
1900    request: proto::CreateBufferForPeer,
1901    session: Session,
1902) -> Result<()> {
1903    session
1904        .db()
1905        .await
1906        .check_user_is_project_host(
1907            ProjectId::from_proto(request.project_id),
1908            session.connection_id,
1909        )
1910        .await?;
1911    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1912    session
1913        .peer
1914        .forward_send(session.connection_id, peer_id.into(), request)?;
1915    Ok(())
1916}
1917
1918/// Notify other participants that a buffer has been updated. This is
1919/// allowed for guests as long as the update is limited to selections.
1920async fn update_buffer(
1921    request: proto::UpdateBuffer,
1922    response: Response<proto::UpdateBuffer>,
1923    session: Session,
1924) -> Result<()> {
1925    let project_id = ProjectId::from_proto(request.project_id);
1926    let mut guest_connection_ids;
1927    let mut host_connection_id = None;
1928
1929    let mut requires_write_permission = false;
1930
1931    for op in request.operations.iter() {
1932        match op.variant {
1933            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1934            Some(_) => requires_write_permission = true,
1935        }
1936    }
1937
1938    {
1939        let collaborators = session
1940            .db()
1941            .await
1942            .project_collaborators_for_buffer_update(
1943                project_id,
1944                session.connection_id,
1945                requires_write_permission,
1946            )
1947            .await?;
1948        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1949        for collaborator in collaborators.iter() {
1950            if collaborator.is_host {
1951                host_connection_id = Some(collaborator.connection_id);
1952            } else {
1953                guest_connection_ids.push(collaborator.connection_id);
1954            }
1955        }
1956    }
1957    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1958
1959    broadcast(
1960        Some(session.connection_id),
1961        guest_connection_ids,
1962        |connection_id| {
1963            session
1964                .peer
1965                .forward_send(session.connection_id, connection_id, request.clone())
1966        },
1967    );
1968    if host_connection_id != session.connection_id {
1969        session
1970            .peer
1971            .forward_request(session.connection_id, host_connection_id, request.clone())
1972            .await?;
1973    }
1974
1975    response.send(proto::Ack {})?;
1976    Ok(())
1977}
1978
1979/// Notify other participants that a project has been updated.
1980async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1981    request: T,
1982    session: Session,
1983) -> Result<()> {
1984    let project_id = ProjectId::from_proto(request.remote_entity_id());
1985    let project_connection_ids = session
1986        .db()
1987        .await
1988        .project_connection_ids(project_id, session.connection_id)
1989        .await?;
1990
1991    broadcast(
1992        Some(session.connection_id),
1993        project_connection_ids.iter().copied(),
1994        |connection_id| {
1995            session
1996                .peer
1997                .forward_send(session.connection_id, connection_id, request.clone())
1998        },
1999    );
2000    Ok(())
2001}
2002
2003/// Start following another user in a call.
2004async fn follow(
2005    request: proto::Follow,
2006    response: Response<proto::Follow>,
2007    session: Session,
2008) -> Result<()> {
2009    let room_id = RoomId::from_proto(request.room_id);
2010    let project_id = request.project_id.map(ProjectId::from_proto);
2011    let leader_id = request
2012        .leader_id
2013        .ok_or_else(|| anyhow!("invalid leader id"))?
2014        .into();
2015    let follower_id = session.connection_id;
2016
2017    session
2018        .db()
2019        .await
2020        .check_room_participants(room_id, leader_id, session.connection_id)
2021        .await?;
2022
2023    let response_payload = session
2024        .peer
2025        .forward_request(session.connection_id, leader_id, request)
2026        .await?;
2027    response.send(response_payload)?;
2028
2029    if let Some(project_id) = project_id {
2030        let room = session
2031            .db()
2032            .await
2033            .follow(room_id, project_id, leader_id, follower_id)
2034            .await?;
2035        room_updated(&room, &session.peer);
2036    }
2037
2038    Ok(())
2039}
2040
2041/// Stop following another user in a call.
2042async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2043    let room_id = RoomId::from_proto(request.room_id);
2044    let project_id = request.project_id.map(ProjectId::from_proto);
2045    let leader_id = request
2046        .leader_id
2047        .ok_or_else(|| anyhow!("invalid leader id"))?
2048        .into();
2049    let follower_id = session.connection_id;
2050
2051    session
2052        .db()
2053        .await
2054        .check_room_participants(room_id, leader_id, session.connection_id)
2055        .await?;
2056
2057    session
2058        .peer
2059        .forward_send(session.connection_id, leader_id, request)?;
2060
2061    if let Some(project_id) = project_id {
2062        let room = session
2063            .db()
2064            .await
2065            .unfollow(room_id, project_id, leader_id, follower_id)
2066            .await?;
2067        room_updated(&room, &session.peer);
2068    }
2069
2070    Ok(())
2071}
2072
2073/// Notify everyone following you of your current location.
2074async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2075    let room_id = RoomId::from_proto(request.room_id);
2076    let database = session.db.lock().await;
2077
2078    let connection_ids = if let Some(project_id) = request.project_id {
2079        let project_id = ProjectId::from_proto(project_id);
2080        database
2081            .project_connection_ids(project_id, session.connection_id)
2082            .await?
2083    } else {
2084        database
2085            .room_connection_ids(room_id, session.connection_id)
2086            .await?
2087    };
2088
2089    // For now, don't send view update messages back to that view's current leader.
2090    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2091        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2092        _ => None,
2093    });
2094
2095    for follower_peer_id in request.follower_ids.iter().copied() {
2096        let follower_connection_id = follower_peer_id.into();
2097        if Some(follower_peer_id) != connection_id_to_omit
2098            && connection_ids.contains(&follower_connection_id)
2099        {
2100            session.peer.forward_send(
2101                session.connection_id,
2102                follower_connection_id,
2103                request.clone(),
2104            )?;
2105        }
2106    }
2107    Ok(())
2108}
2109
2110/// Get public data about users.
2111async fn get_users(
2112    request: proto::GetUsers,
2113    response: Response<proto::GetUsers>,
2114    session: Session,
2115) -> Result<()> {
2116    let user_ids = request
2117        .user_ids
2118        .into_iter()
2119        .map(UserId::from_proto)
2120        .collect();
2121    let users = session
2122        .db()
2123        .await
2124        .get_users_by_ids(user_ids)
2125        .await?
2126        .into_iter()
2127        .map(|user| proto::User {
2128            id: user.id.to_proto(),
2129            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2130            github_login: user.github_login,
2131        })
2132        .collect();
2133    response.send(proto::UsersResponse { users })?;
2134    Ok(())
2135}
2136
2137/// Search for users (to invite) buy Github login
2138async fn fuzzy_search_users(
2139    request: proto::FuzzySearchUsers,
2140    response: Response<proto::FuzzySearchUsers>,
2141    session: Session,
2142) -> Result<()> {
2143    let query = request.query;
2144    let users = match query.len() {
2145        0 => vec![],
2146        1 | 2 => session
2147            .db()
2148            .await
2149            .get_user_by_github_login(&query)
2150            .await?
2151            .into_iter()
2152            .collect(),
2153        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2154    };
2155    let users = users
2156        .into_iter()
2157        .filter(|user| user.id != session.user_id)
2158        .map(|user| proto::User {
2159            id: user.id.to_proto(),
2160            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2161            github_login: user.github_login,
2162        })
2163        .collect();
2164    response.send(proto::UsersResponse { users })?;
2165    Ok(())
2166}
2167
2168/// Send a contact request to another user.
2169async fn request_contact(
2170    request: proto::RequestContact,
2171    response: Response<proto::RequestContact>,
2172    session: Session,
2173) -> Result<()> {
2174    let requester_id = session.user_id;
2175    let responder_id = UserId::from_proto(request.responder_id);
2176    if requester_id == responder_id {
2177        return Err(anyhow!("cannot add yourself as a contact"))?;
2178    }
2179
2180    let notifications = session
2181        .db()
2182        .await
2183        .send_contact_request(requester_id, responder_id)
2184        .await?;
2185
2186    // Update outgoing contact requests of requester
2187    let mut update = proto::UpdateContacts::default();
2188    update.outgoing_requests.push(responder_id.to_proto());
2189    for connection_id in session
2190        .connection_pool()
2191        .await
2192        .user_connection_ids(requester_id)
2193    {
2194        session.peer.send(connection_id, update.clone())?;
2195    }
2196
2197    // Update incoming contact requests of responder
2198    let mut update = proto::UpdateContacts::default();
2199    update
2200        .incoming_requests
2201        .push(proto::IncomingContactRequest {
2202            requester_id: requester_id.to_proto(),
2203        });
2204    let connection_pool = session.connection_pool().await;
2205    for connection_id in connection_pool.user_connection_ids(responder_id) {
2206        session.peer.send(connection_id, update.clone())?;
2207    }
2208
2209    send_notifications(&*connection_pool, &session.peer, notifications);
2210
2211    response.send(proto::Ack {})?;
2212    Ok(())
2213}
2214
2215/// Accept or decline a contact request
2216async fn respond_to_contact_request(
2217    request: proto::RespondToContactRequest,
2218    response: Response<proto::RespondToContactRequest>,
2219    session: Session,
2220) -> Result<()> {
2221    let responder_id = session.user_id;
2222    let requester_id = UserId::from_proto(request.requester_id);
2223    let db = session.db().await;
2224    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2225        db.dismiss_contact_notification(responder_id, requester_id)
2226            .await?;
2227    } else {
2228        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2229
2230        let notifications = db
2231            .respond_to_contact_request(responder_id, requester_id, accept)
2232            .await?;
2233        let requester_busy = db.is_user_busy(requester_id).await?;
2234        let responder_busy = db.is_user_busy(responder_id).await?;
2235
2236        let pool = session.connection_pool().await;
2237        // Update responder with new contact
2238        let mut update = proto::UpdateContacts::default();
2239        if accept {
2240            update
2241                .contacts
2242                .push(contact_for_user(requester_id, requester_busy, &pool));
2243        }
2244        update
2245            .remove_incoming_requests
2246            .push(requester_id.to_proto());
2247        for connection_id in pool.user_connection_ids(responder_id) {
2248            session.peer.send(connection_id, update.clone())?;
2249        }
2250
2251        // Update requester with new contact
2252        let mut update = proto::UpdateContacts::default();
2253        if accept {
2254            update
2255                .contacts
2256                .push(contact_for_user(responder_id, responder_busy, &pool));
2257        }
2258        update
2259            .remove_outgoing_requests
2260            .push(responder_id.to_proto());
2261
2262        for connection_id in pool.user_connection_ids(requester_id) {
2263            session.peer.send(connection_id, update.clone())?;
2264        }
2265
2266        send_notifications(&*pool, &session.peer, notifications);
2267    }
2268
2269    response.send(proto::Ack {})?;
2270    Ok(())
2271}
2272
2273/// Remove a contact.
2274async fn remove_contact(
2275    request: proto::RemoveContact,
2276    response: Response<proto::RemoveContact>,
2277    session: Session,
2278) -> Result<()> {
2279    let requester_id = session.user_id;
2280    let responder_id = UserId::from_proto(request.user_id);
2281    let db = session.db().await;
2282    let (contact_accepted, deleted_notification_id) =
2283        db.remove_contact(requester_id, responder_id).await?;
2284
2285    let pool = session.connection_pool().await;
2286    // Update outgoing contact requests of requester
2287    let mut update = proto::UpdateContacts::default();
2288    if contact_accepted {
2289        update.remove_contacts.push(responder_id.to_proto());
2290    } else {
2291        update
2292            .remove_outgoing_requests
2293            .push(responder_id.to_proto());
2294    }
2295    for connection_id in pool.user_connection_ids(requester_id) {
2296        session.peer.send(connection_id, update.clone())?;
2297    }
2298
2299    // Update incoming contact requests of responder
2300    let mut update = proto::UpdateContacts::default();
2301    if contact_accepted {
2302        update.remove_contacts.push(requester_id.to_proto());
2303    } else {
2304        update
2305            .remove_incoming_requests
2306            .push(requester_id.to_proto());
2307    }
2308    for connection_id in pool.user_connection_ids(responder_id) {
2309        session.peer.send(connection_id, update.clone())?;
2310        if let Some(notification_id) = deleted_notification_id {
2311            session.peer.send(
2312                connection_id,
2313                proto::DeleteNotification {
2314                    notification_id: notification_id.to_proto(),
2315                },
2316            )?;
2317        }
2318    }
2319
2320    response.send(proto::Ack {})?;
2321    Ok(())
2322}
2323
2324/// Creates a new channel.
2325async fn create_channel(
2326    request: proto::CreateChannel,
2327    response: Response<proto::CreateChannel>,
2328    session: Session,
2329) -> Result<()> {
2330    let db = session.db().await;
2331
2332    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2333    let (channel, owner, channel_members) = db
2334        .create_channel(&request.name, parent_id, session.user_id)
2335        .await?;
2336
2337    response.send(proto::CreateChannelResponse {
2338        channel: Some(channel.to_proto()),
2339        parent_id: request.parent_id,
2340    })?;
2341
2342    let connection_pool = session.connection_pool().await;
2343    if let Some(owner) = owner {
2344        let update = proto::UpdateUserChannels {
2345            channel_memberships: vec![proto::ChannelMembership {
2346                channel_id: owner.channel_id.to_proto(),
2347                role: owner.role.into(),
2348            }],
2349            ..Default::default()
2350        };
2351        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2352            session.peer.send(connection_id, update.clone())?;
2353        }
2354    }
2355
2356    for channel_member in channel_members {
2357        if !channel_member.role.can_see_channel(channel.visibility) {
2358            continue;
2359        }
2360
2361        let update = proto::UpdateChannels {
2362            channels: vec![channel.to_proto()],
2363            ..Default::default()
2364        };
2365        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2366            session.peer.send(connection_id, update.clone())?;
2367        }
2368    }
2369
2370    Ok(())
2371}
2372
2373/// Delete a channel
2374async fn delete_channel(
2375    request: proto::DeleteChannel,
2376    response: Response<proto::DeleteChannel>,
2377    session: Session,
2378) -> Result<()> {
2379    let db = session.db().await;
2380
2381    let channel_id = request.channel_id;
2382    let (removed_channels, member_ids) = db
2383        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2384        .await?;
2385    response.send(proto::Ack {})?;
2386
2387    // Notify members of removed channels
2388    let mut update = proto::UpdateChannels::default();
2389    update
2390        .delete_channels
2391        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2392
2393    let connection_pool = session.connection_pool().await;
2394    for member_id in member_ids {
2395        for connection_id in connection_pool.user_connection_ids(member_id) {
2396            session.peer.send(connection_id, update.clone())?;
2397        }
2398    }
2399
2400    Ok(())
2401}
2402
2403/// Invite someone to join a channel.
2404async fn invite_channel_member(
2405    request: proto::InviteChannelMember,
2406    response: Response<proto::InviteChannelMember>,
2407    session: Session,
2408) -> Result<()> {
2409    let db = session.db().await;
2410    let channel_id = ChannelId::from_proto(request.channel_id);
2411    let invitee_id = UserId::from_proto(request.user_id);
2412    let InviteMemberResult {
2413        channel,
2414        notifications,
2415    } = db
2416        .invite_channel_member(
2417            channel_id,
2418            invitee_id,
2419            session.user_id,
2420            request.role().into(),
2421        )
2422        .await?;
2423
2424    let update = proto::UpdateChannels {
2425        channel_invitations: vec![channel.to_proto()],
2426        ..Default::default()
2427    };
2428
2429    let connection_pool = session.connection_pool().await;
2430    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2431        session.peer.send(connection_id, update.clone())?;
2432    }
2433
2434    send_notifications(&*connection_pool, &session.peer, notifications);
2435
2436    response.send(proto::Ack {})?;
2437    Ok(())
2438}
2439
2440/// remove someone from a channel
2441async fn remove_channel_member(
2442    request: proto::RemoveChannelMember,
2443    response: Response<proto::RemoveChannelMember>,
2444    session: Session,
2445) -> Result<()> {
2446    let db = session.db().await;
2447    let channel_id = ChannelId::from_proto(request.channel_id);
2448    let member_id = UserId::from_proto(request.user_id);
2449
2450    let RemoveChannelMemberResult {
2451        membership_update,
2452        notification_id,
2453    } = db
2454        .remove_channel_member(channel_id, member_id, session.user_id)
2455        .await?;
2456
2457    let connection_pool = &session.connection_pool().await;
2458    notify_membership_updated(
2459        &connection_pool,
2460        membership_update,
2461        member_id,
2462        &session.peer,
2463    );
2464    for connection_id in connection_pool.user_connection_ids(member_id) {
2465        if let Some(notification_id) = notification_id {
2466            session
2467                .peer
2468                .send(
2469                    connection_id,
2470                    proto::DeleteNotification {
2471                        notification_id: notification_id.to_proto(),
2472                    },
2473                )
2474                .trace_err();
2475        }
2476    }
2477
2478    response.send(proto::Ack {})?;
2479    Ok(())
2480}
2481
2482/// Toggle the channel between public and private.
2483/// Care is taken to maintain the invariant that public channels only descend from public channels,
2484/// (though members-only channels can appear at any point in the hierarchy).
2485async fn set_channel_visibility(
2486    request: proto::SetChannelVisibility,
2487    response: Response<proto::SetChannelVisibility>,
2488    session: Session,
2489) -> Result<()> {
2490    let db = session.db().await;
2491    let channel_id = ChannelId::from_proto(request.channel_id);
2492    let visibility = request.visibility().into();
2493
2494    let (channel, channel_members) = db
2495        .set_channel_visibility(channel_id, visibility, session.user_id)
2496        .await?;
2497
2498    let connection_pool = session.connection_pool().await;
2499    for member in channel_members {
2500        let update = if member.role.can_see_channel(channel.visibility) {
2501            proto::UpdateChannels {
2502                channels: vec![channel.to_proto()],
2503                ..Default::default()
2504            }
2505        } else {
2506            proto::UpdateChannels {
2507                delete_channels: vec![channel.id.to_proto()],
2508                ..Default::default()
2509            }
2510        };
2511
2512        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2513            session.peer.send(connection_id, update.clone())?;
2514        }
2515    }
2516
2517    response.send(proto::Ack {})?;
2518    Ok(())
2519}
2520
2521/// Alter the role for a user in the channel.
2522async fn set_channel_member_role(
2523    request: proto::SetChannelMemberRole,
2524    response: Response<proto::SetChannelMemberRole>,
2525    session: Session,
2526) -> Result<()> {
2527    let db = session.db().await;
2528    let channel_id = ChannelId::from_proto(request.channel_id);
2529    let member_id = UserId::from_proto(request.user_id);
2530    let result = db
2531        .set_channel_member_role(
2532            channel_id,
2533            session.user_id,
2534            member_id,
2535            request.role().into(),
2536        )
2537        .await?;
2538
2539    match result {
2540        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2541            let connection_pool = session.connection_pool().await;
2542            notify_membership_updated(
2543                &connection_pool,
2544                membership_update,
2545                member_id,
2546                &session.peer,
2547            )
2548        }
2549        db::SetMemberRoleResult::InviteUpdated(channel) => {
2550            let update = proto::UpdateChannels {
2551                channel_invitations: vec![channel.to_proto()],
2552                ..Default::default()
2553            };
2554
2555            for connection_id in session
2556                .connection_pool()
2557                .await
2558                .user_connection_ids(member_id)
2559            {
2560                session.peer.send(connection_id, update.clone())?;
2561            }
2562        }
2563    }
2564
2565    response.send(proto::Ack {})?;
2566    Ok(())
2567}
2568
2569/// Change the name of a channel
2570async fn rename_channel(
2571    request: proto::RenameChannel,
2572    response: Response<proto::RenameChannel>,
2573    session: Session,
2574) -> Result<()> {
2575    let db = session.db().await;
2576    let channel_id = ChannelId::from_proto(request.channel_id);
2577    let (channel, channel_members) = db
2578        .rename_channel(channel_id, session.user_id, &request.name)
2579        .await?;
2580
2581    response.send(proto::RenameChannelResponse {
2582        channel: Some(channel.to_proto()),
2583    })?;
2584
2585    let connection_pool = session.connection_pool().await;
2586    for channel_member in channel_members {
2587        if !channel_member.role.can_see_channel(channel.visibility) {
2588            continue;
2589        }
2590        let update = proto::UpdateChannels {
2591            channels: vec![channel.to_proto()],
2592            ..Default::default()
2593        };
2594        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2595            session.peer.send(connection_id, update.clone())?;
2596        }
2597    }
2598
2599    Ok(())
2600}
2601
2602/// Move a channel to a new parent.
2603async fn move_channel(
2604    request: proto::MoveChannel,
2605    response: Response<proto::MoveChannel>,
2606    session: Session,
2607) -> Result<()> {
2608    let channel_id = ChannelId::from_proto(request.channel_id);
2609    let to = ChannelId::from_proto(request.to);
2610
2611    let (channels, channel_members) = session
2612        .db()
2613        .await
2614        .move_channel(channel_id, to, session.user_id)
2615        .await?;
2616
2617    let connection_pool = session.connection_pool().await;
2618    for member in channel_members {
2619        let channels = channels
2620            .iter()
2621            .filter_map(|channel| {
2622                if member.role.can_see_channel(channel.visibility) {
2623                    Some(channel.to_proto())
2624                } else {
2625                    None
2626                }
2627            })
2628            .collect::<Vec<_>>();
2629        if channels.is_empty() {
2630            continue;
2631        }
2632
2633        let update = proto::UpdateChannels {
2634            channels,
2635            ..Default::default()
2636        };
2637
2638        for connection_id in connection_pool.user_connection_ids(member.user_id) {
2639            session.peer.send(connection_id, update.clone())?;
2640        }
2641    }
2642
2643    response.send(Ack {})?;
2644    Ok(())
2645}
2646
2647/// Get the list of channel members
2648async fn get_channel_members(
2649    request: proto::GetChannelMembers,
2650    response: Response<proto::GetChannelMembers>,
2651    session: Session,
2652) -> Result<()> {
2653    let db = session.db().await;
2654    let channel_id = ChannelId::from_proto(request.channel_id);
2655    let members = db
2656        .get_channel_participant_details(channel_id, session.user_id)
2657        .await?;
2658    response.send(proto::GetChannelMembersResponse { members })?;
2659    Ok(())
2660}
2661
2662/// Accept or decline a channel invitation.
2663async fn respond_to_channel_invite(
2664    request: proto::RespondToChannelInvite,
2665    response: Response<proto::RespondToChannelInvite>,
2666    session: Session,
2667) -> Result<()> {
2668    let db = session.db().await;
2669    let channel_id = ChannelId::from_proto(request.channel_id);
2670    let RespondToChannelInvite {
2671        membership_update,
2672        notifications,
2673    } = db
2674        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2675        .await?;
2676
2677    let connection_pool = session.connection_pool().await;
2678    if let Some(membership_update) = membership_update {
2679        notify_membership_updated(
2680            &connection_pool,
2681            membership_update,
2682            session.user_id,
2683            &session.peer,
2684        );
2685    } else {
2686        let update = proto::UpdateChannels {
2687            remove_channel_invitations: vec![channel_id.to_proto()],
2688            ..Default::default()
2689        };
2690
2691        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2692            session.peer.send(connection_id, update.clone())?;
2693        }
2694    };
2695
2696    send_notifications(&*connection_pool, &session.peer, notifications);
2697
2698    response.send(proto::Ack {})?;
2699
2700    Ok(())
2701}
2702
2703/// Join the channels' room
2704async fn join_channel(
2705    request: proto::JoinChannel,
2706    response: Response<proto::JoinChannel>,
2707    session: Session,
2708) -> Result<()> {
2709    let channel_id = ChannelId::from_proto(request.channel_id);
2710    join_channel_internal(channel_id, Box::new(response), session).await
2711}
2712
/// Abstracts over the response handles that can answer a channel join.
///
/// Both impls below reply with a `proto::JoinRoomResponse`, which lets
/// `join_channel_internal` serve either a `JoinChannel` or a `JoinRoom`
/// request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl calls `Response::<_>::send` through a fully qualified type path,
// presumably to make explicit that the inherent `Response::send` (not this
// trait's `send`) is the target.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2726
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Leave the session's current room via `leave_room_for_session`.
/// 2. Join the channel in the database, which yields the joined room, an
///    optional membership update, and the user's role in the channel.
/// 3. Reply with a `JoinRoomResponse`. It includes LiveKit connection info
///    when a LiveKit client is configured and a token could be generated.
/// 4. Send any membership update to the user's connections and broadcast the
///    new room state to the room's participants.
/// 5. Send the channel's participant list to the channel members' connections
///    and refresh this user's entry for their contacts (`update_user_contacts`).
///
/// `response` can be any `JoinChannelInternalResponse` implementor, so the
/// same logic can answer either a `JoinChannel` or a `JoinRoom` request.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that uses `db` and the first `connection_pool` lives in this
    // block, so both are dropped before the connection pool is acquired again
    // for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Only produced when a LiveKit client is configured. Guests get a guest
        // token with `can_publish: false`; every other role gets a room token
        // with `can_publish: true`. A token error is traced and yields `None`,
        // so the response then carries no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // If joining changed the user's channel membership, tell all of the
        // user's connections about it.
        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Push the new room state to everyone in the room.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Let the channel's members see the updated participant list.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2807
2808/// Start editing the channel notes
2809async fn join_channel_buffer(
2810    request: proto::JoinChannelBuffer,
2811    response: Response<proto::JoinChannelBuffer>,
2812    session: Session,
2813) -> Result<()> {
2814    let db = session.db().await;
2815    let channel_id = ChannelId::from_proto(request.channel_id);
2816
2817    let open_response = db
2818        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2819        .await?;
2820
2821    let collaborators = open_response.collaborators.clone();
2822    response.send(open_response)?;
2823
2824    let update = UpdateChannelBufferCollaborators {
2825        channel_id: channel_id.to_proto(),
2826        collaborators: collaborators.clone(),
2827    };
2828    channel_buffer_updated(
2829        session.connection_id,
2830        collaborators
2831            .iter()
2832            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2833        &update,
2834        &session.peer,
2835    );
2836
2837    Ok(())
2838}
2839
2840/// Edit the channel notes
2841async fn update_channel_buffer(
2842    request: proto::UpdateChannelBuffer,
2843    session: Session,
2844) -> Result<()> {
2845    let db = session.db().await;
2846    let channel_id = ChannelId::from_proto(request.channel_id);
2847
2848    let (collaborators, non_collaborators, epoch, version) = db
2849        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2850        .await?;
2851
2852    channel_buffer_updated(
2853        session.connection_id,
2854        collaborators,
2855        &proto::UpdateChannelBuffer {
2856            channel_id: channel_id.to_proto(),
2857            operations: request.operations,
2858        },
2859        &session.peer,
2860    );
2861
2862    let pool = &*session.connection_pool().await;
2863
2864    broadcast(
2865        None,
2866        non_collaborators
2867            .iter()
2868            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2869        |peer_id| {
2870            session.peer.send(
2871                peer_id.into(),
2872                proto::UpdateChannels {
2873                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2874                        channel_id: channel_id.to_proto(),
2875                        epoch: epoch as u64,
2876                        version: version.clone(),
2877                    }],
2878                    ..Default::default()
2879                },
2880            )
2881        },
2882    );
2883
2884    Ok(())
2885}
2886
2887/// Rejoin the channel notes after a connection blip
2888async fn rejoin_channel_buffers(
2889    request: proto::RejoinChannelBuffers,
2890    response: Response<proto::RejoinChannelBuffers>,
2891    session: Session,
2892) -> Result<()> {
2893    let db = session.db().await;
2894    let buffers = db
2895        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2896        .await?;
2897
2898    for rejoined_buffer in &buffers {
2899        let collaborators_to_notify = rejoined_buffer
2900            .buffer
2901            .collaborators
2902            .iter()
2903            .filter_map(|c| Some(c.peer_id?.into()));
2904        channel_buffer_updated(
2905            session.connection_id,
2906            collaborators_to_notify,
2907            &proto::UpdateChannelBufferCollaborators {
2908                channel_id: rejoined_buffer.buffer.channel_id,
2909                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2910            },
2911            &session.peer,
2912        );
2913    }
2914
2915    response.send(proto::RejoinChannelBuffersResponse {
2916        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2917    })?;
2918
2919    Ok(())
2920}
2921
2922/// Stop editing the channel notes
2923async fn leave_channel_buffer(
2924    request: proto::LeaveChannelBuffer,
2925    response: Response<proto::LeaveChannelBuffer>,
2926    session: Session,
2927) -> Result<()> {
2928    let db = session.db().await;
2929    let channel_id = ChannelId::from_proto(request.channel_id);
2930
2931    let left_buffer = db
2932        .leave_channel_buffer(channel_id, session.connection_id)
2933        .await?;
2934
2935    response.send(Ack {})?;
2936
2937    channel_buffer_updated(
2938        session.connection_id,
2939        left_buffer.connections,
2940        &proto::UpdateChannelBufferCollaborators {
2941            channel_id: channel_id.to_proto(),
2942            collaborators: left_buffer.collaborators,
2943        },
2944        &session.peer,
2945    );
2946
2947    Ok(())
2948}
2949
2950fn channel_buffer_updated<T: EnvelopedMessage>(
2951    sender_id: ConnectionId,
2952    collaborators: impl IntoIterator<Item = ConnectionId>,
2953    message: &T,
2954    peer: &Peer,
2955) {
2956    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2957        peer.send(peer_id.into(), message.clone())
2958    });
2959}
2960
2961fn send_notifications(
2962    connection_pool: &ConnectionPool,
2963    peer: &Peer,
2964    notifications: db::NotificationBatch,
2965) {
2966    for (user_id, notification) in notifications {
2967        for connection_id in connection_pool.user_connection_ids(user_id) {
2968            if let Err(error) = peer.send(
2969                connection_id,
2970                proto::AddNotification {
2971                    notification: Some(notification.clone()),
2972                },
2973            ) {
2974                tracing::error!(
2975                    "failed to send notification to {:?} {}",
2976                    connection_id,
2977                    error
2978                );
2979            }
2980        }
2981    }
2982}
2983
2984/// Send a message to the channel
2985async fn send_channel_message(
2986    request: proto::SendChannelMessage,
2987    response: Response<proto::SendChannelMessage>,
2988    session: Session,
2989) -> Result<()> {
2990    // Validate the message body.
2991    let body = request.body.trim().to_string();
2992    if body.len() > MAX_MESSAGE_LEN {
2993        return Err(anyhow!("message is too long"))?;
2994    }
2995    if body.is_empty() {
2996        return Err(anyhow!("message can't be blank"))?;
2997    }
2998
2999    // TODO: adjust mentions if body is trimmed
3000
3001    let timestamp = OffsetDateTime::now_utc();
3002    let nonce = request
3003        .nonce
3004        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3005
3006    let channel_id = ChannelId::from_proto(request.channel_id);
3007    let CreatedChannelMessage {
3008        message_id,
3009        participant_connection_ids,
3010        channel_members,
3011        notifications,
3012    } = session
3013        .db()
3014        .await
3015        .create_channel_message(
3016            channel_id,
3017            session.user_id,
3018            &body,
3019            &request.mentions,
3020            timestamp,
3021            nonce.clone().into(),
3022            match request.reply_to_message_id {
3023                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3024                None => None,
3025            },
3026        )
3027        .await?;
3028    let message = proto::ChannelMessage {
3029        sender_id: session.user_id.to_proto(),
3030        id: message_id.to_proto(),
3031        body,
3032        mentions: request.mentions,
3033        timestamp: timestamp.unix_timestamp() as u64,
3034        nonce: Some(nonce),
3035        reply_to_message_id: request.reply_to_message_id,
3036    };
3037    broadcast(
3038        Some(session.connection_id),
3039        participant_connection_ids,
3040        |connection| {
3041            session.peer.send(
3042                connection,
3043                proto::ChannelMessageSent {
3044                    channel_id: channel_id.to_proto(),
3045                    message: Some(message.clone()),
3046                },
3047            )
3048        },
3049    );
3050    response.send(proto::SendChannelMessageResponse {
3051        message: Some(message),
3052    })?;
3053
3054    let pool = &*session.connection_pool().await;
3055    broadcast(
3056        None,
3057        channel_members
3058            .iter()
3059            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3060        |peer_id| {
3061            session.peer.send(
3062                peer_id.into(),
3063                proto::UpdateChannels {
3064                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3065                        channel_id: channel_id.to_proto(),
3066                        message_id: message_id.to_proto(),
3067                    }],
3068                    ..Default::default()
3069                },
3070            )
3071        },
3072    );
3073    send_notifications(pool, &session.peer, notifications);
3074
3075    Ok(())
3076}
3077
3078/// Delete a channel message
3079async fn remove_channel_message(
3080    request: proto::RemoveChannelMessage,
3081    response: Response<proto::RemoveChannelMessage>,
3082    session: Session,
3083) -> Result<()> {
3084    let channel_id = ChannelId::from_proto(request.channel_id);
3085    let message_id = MessageId::from_proto(request.message_id);
3086    let connection_ids = session
3087        .db()
3088        .await
3089        .remove_channel_message(channel_id, message_id, session.user_id)
3090        .await?;
3091    broadcast(Some(session.connection_id), connection_ids, |connection| {
3092        session.peer.send(connection, request.clone())
3093    });
3094    response.send(proto::Ack {})?;
3095    Ok(())
3096}
3097
3098/// Mark a channel message as read
3099async fn acknowledge_channel_message(
3100    request: proto::AckChannelMessage,
3101    session: Session,
3102) -> Result<()> {
3103    let channel_id = ChannelId::from_proto(request.channel_id);
3104    let message_id = MessageId::from_proto(request.message_id);
3105    let notifications = session
3106        .db()
3107        .await
3108        .observe_channel_message(channel_id, session.user_id, message_id)
3109        .await?;
3110    send_notifications(
3111        &*session.connection_pool().await,
3112        &session.peer,
3113        notifications,
3114    );
3115    Ok(())
3116}
3117
3118/// Mark a buffer version as synced
3119async fn acknowledge_buffer_version(
3120    request: proto::AckBufferOperation,
3121    session: Session,
3122) -> Result<()> {
3123    let buffer_id = BufferId::from_proto(request.buffer_id);
3124    session
3125        .db()
3126        .await
3127        .observe_buffer_version(
3128            buffer_id,
3129            session.user_id,
3130            request.epoch as i32,
3131            &request.version,
3132        )
3133        .await?;
3134    Ok(())
3135}
3136
3137/// Start receiving chat updates for a channel
3138async fn join_channel_chat(
3139    request: proto::JoinChannelChat,
3140    response: Response<proto::JoinChannelChat>,
3141    session: Session,
3142) -> Result<()> {
3143    let channel_id = ChannelId::from_proto(request.channel_id);
3144
3145    let db = session.db().await;
3146    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3147        .await?;
3148    let messages = db
3149        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3150        .await?;
3151    response.send(proto::JoinChannelChatResponse {
3152        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3153        messages,
3154    })?;
3155    Ok(())
3156}
3157
3158/// Stop receiving chat updates for a channel
3159async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3160    let channel_id = ChannelId::from_proto(request.channel_id);
3161    session
3162        .db()
3163        .await
3164        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3165        .await?;
3166    Ok(())
3167}
3168
3169/// Retrieve the chat history for a channel
3170async fn get_channel_messages(
3171    request: proto::GetChannelMessages,
3172    response: Response<proto::GetChannelMessages>,
3173    session: Session,
3174) -> Result<()> {
3175    let channel_id = ChannelId::from_proto(request.channel_id);
3176    let messages = session
3177        .db()
3178        .await
3179        .get_channel_messages(
3180            channel_id,
3181            session.user_id,
3182            MESSAGE_COUNT_PER_PAGE,
3183            Some(MessageId::from_proto(request.before_message_id)),
3184        )
3185        .await?;
3186    response.send(proto::GetChannelMessagesResponse {
3187        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3188        messages,
3189    })?;
3190    Ok(())
3191}
3192
3193/// Retrieve specific chat messages
3194async fn get_channel_messages_by_id(
3195    request: proto::GetChannelMessagesById,
3196    response: Response<proto::GetChannelMessagesById>,
3197    session: Session,
3198) -> Result<()> {
3199    let message_ids = request
3200        .message_ids
3201        .iter()
3202        .map(|id| MessageId::from_proto(*id))
3203        .collect::<Vec<_>>();
3204    let messages = session
3205        .db()
3206        .await
3207        .get_channel_messages_by_id(session.user_id, &message_ids)
3208        .await?;
3209    response.send(proto::GetChannelMessagesResponse {
3210        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3211        messages,
3212    })?;
3213    Ok(())
3214}
3215
3216/// Retrieve the current users notifications
3217async fn get_notifications(
3218    request: proto::GetNotifications,
3219    response: Response<proto::GetNotifications>,
3220    session: Session,
3221) -> Result<()> {
3222    let notifications = session
3223        .db()
3224        .await
3225        .get_notifications(
3226            session.user_id,
3227            NOTIFICATION_COUNT_PER_PAGE,
3228            request
3229                .before_id
3230                .map(|id| db::NotificationId::from_proto(id)),
3231        )
3232        .await?;
3233    response.send(proto::GetNotificationsResponse {
3234        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3235        notifications,
3236    })?;
3237    Ok(())
3238}
3239
3240/// Mark notifications as read
3241async fn mark_notification_as_read(
3242    request: proto::MarkNotificationRead,
3243    response: Response<proto::MarkNotificationRead>,
3244    session: Session,
3245) -> Result<()> {
3246    let database = &session.db().await;
3247    let notifications = database
3248        .mark_notification_as_read_by_id(
3249            session.user_id,
3250            NotificationId::from_proto(request.notification_id),
3251        )
3252        .await?;
3253    send_notifications(
3254        &*session.connection_pool().await,
3255        &session.peer,
3256        notifications,
3257    );
3258    response.send(proto::Ack {})?;
3259    Ok(())
3260}
3261
3262/// Get the current users information
3263async fn get_private_user_info(
3264    _request: proto::GetPrivateUserInfo,
3265    response: Response<proto::GetPrivateUserInfo>,
3266    session: Session,
3267) -> Result<()> {
3268    let db = session.db().await;
3269
3270    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3271    let user = db
3272        .get_user_by_id(session.user_id)
3273        .await?
3274        .ok_or_else(|| anyhow!("user not found"))?;
3275    let flags = db.get_user_flags(session.user_id).await?;
3276
3277    response.send(proto::GetPrivateUserInfoResponse {
3278        metrics_id,
3279        staff: user.admin,
3280        flags,
3281    })?;
3282    Ok(())
3283}
3284
3285fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3286    match message {
3287        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3288        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3289        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3290        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3291        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3292            code: frame.code.into(),
3293            reason: frame.reason,
3294        })),
3295    }
3296}
3297
3298fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3299    match message {
3300        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3301        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3302        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3303        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3304        AxumMessage::Close(frame) => {
3305            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3306                code: frame.code.into(),
3307                reason: frame.reason,
3308            }))
3309        }
3310    }
3311}
3312
3313fn notify_membership_updated(
3314    connection_pool: &ConnectionPool,
3315    result: MembershipUpdated,
3316    user_id: UserId,
3317    peer: &Peer,
3318) {
3319    let user_channels_update = proto::UpdateUserChannels {
3320        channel_memberships: result
3321            .new_channels
3322            .channel_memberships
3323            .iter()
3324            .map(|cm| proto::ChannelMembership {
3325                channel_id: cm.channel_id.to_proto(),
3326                role: cm.role.into(),
3327            })
3328            .collect(),
3329        ..Default::default()
3330    };
3331    let mut update = build_channels_update(result.new_channels, vec![]);
3332    update.delete_channels = result
3333        .removed_channels
3334        .into_iter()
3335        .map(|id| id.to_proto())
3336        .collect();
3337    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3338
3339    for connection_id in connection_pool.user_connection_ids(user_id) {
3340        peer.send(connection_id, user_channels_update.clone())
3341            .trace_err();
3342        peer.send(connection_id, update.clone()).trace_err();
3343    }
3344}
3345
3346fn build_update_user_channels(
3347    memberships: &Vec<db::channel_member::Model>,
3348) -> proto::UpdateUserChannels {
3349    proto::UpdateUserChannels {
3350        channel_memberships: memberships
3351            .iter()
3352            .map(|m| proto::ChannelMembership {
3353                channel_id: m.channel_id.to_proto(),
3354                role: m.role.into(),
3355            })
3356            .collect(),
3357        ..Default::default()
3358    }
3359}
3360
3361fn build_channels_update(
3362    channels: ChannelsForUser,
3363    channel_invites: Vec<db::Channel>,
3364) -> proto::UpdateChannels {
3365    let mut update = proto::UpdateChannels::default();
3366
3367    for channel in channels.channels {
3368        update.channels.push(channel.to_proto());
3369    }
3370
3371    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3372    update.latest_channel_message_ids = channels.latest_channel_messages;
3373
3374    for (channel_id, participants) in channels.channel_participants {
3375        update
3376            .channel_participants
3377            .push(proto::ChannelParticipants {
3378                channel_id: channel_id.to_proto(),
3379                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3380            });
3381    }
3382
3383    for channel in channel_invites {
3384        update.channel_invitations.push(channel.to_proto());
3385    }
3386
3387    update
3388}
3389
3390fn build_initial_contacts_update(
3391    contacts: Vec<db::Contact>,
3392    pool: &ConnectionPool,
3393) -> proto::UpdateContacts {
3394    let mut update = proto::UpdateContacts::default();
3395
3396    for contact in contacts {
3397        match contact {
3398            db::Contact::Accepted { user_id, busy } => {
3399                update.contacts.push(contact_for_user(user_id, busy, &pool));
3400            }
3401            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3402            db::Contact::Incoming { user_id } => {
3403                update
3404                    .incoming_requests
3405                    .push(proto::IncomingContactRequest {
3406                        requester_id: user_id.to_proto(),
3407                    })
3408            }
3409        }
3410    }
3411
3412    update
3413}
3414
3415fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3416    proto::Contact {
3417        user_id: user_id.to_proto(),
3418        online: pool.is_user_online(user_id),
3419        busy,
3420    }
3421}
3422
3423fn room_updated(room: &proto::Room, peer: &Peer) {
3424    broadcast(
3425        None,
3426        room.participants
3427            .iter()
3428            .filter_map(|participant| Some(participant.peer_id?.into())),
3429        |peer_id| {
3430            peer.send(
3431                peer_id.into(),
3432                proto::RoomUpdated {
3433                    room: Some(room.clone()),
3434                },
3435            )
3436        },
3437    );
3438}
3439
3440fn channel_updated(
3441    channel_id: ChannelId,
3442    room: &proto::Room,
3443    channel_members: &[UserId],
3444    peer: &Peer,
3445    pool: &ConnectionPool,
3446) {
3447    let participants = room
3448        .participants
3449        .iter()
3450        .map(|p| p.user_id)
3451        .collect::<Vec<_>>();
3452
3453    broadcast(
3454        None,
3455        channel_members
3456            .iter()
3457            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3458        |peer_id| {
3459            peer.send(
3460                peer_id.into(),
3461                proto::UpdateChannels {
3462                    channel_participants: vec![proto::ChannelParticipants {
3463                        channel_id: channel_id.to_proto(),
3464                        participant_user_ids: participants.clone(),
3465                    }],
3466                    ..Default::default()
3467                },
3468            )
3469        },
3470    );
3471}
3472
3473async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3474    let db = session.db().await;
3475
3476    let contacts = db.get_contacts(user_id).await?;
3477    let busy = db.is_user_busy(user_id).await?;
3478
3479    let pool = session.connection_pool().await;
3480    let updated_contact = contact_for_user(user_id, busy, &pool);
3481    for contact in contacts {
3482        if let db::Contact::Accepted {
3483            user_id: contact_user_id,
3484            ..
3485        } = contact
3486        {
3487            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3488                session
3489                    .peer
3490                    .send(
3491                        contact_conn_id,
3492                        proto::UpdateContacts {
3493                            contacts: vec![updated_contact.clone()],
3494                            remove_contacts: Default::default(),
3495                            incoming_requests: Default::default(),
3496                            remove_incoming_requests: Default::default(),
3497                            outgoing_requests: Default::default(),
3498                            remove_outgoing_requests: Default::default(),
3499                        },
3500                    )
3501                    .trace_err();
3502            }
3503        }
3504    }
3505    Ok(())
3506}
3507
3508async fn leave_room_for_session(session: &Session) -> Result<()> {
3509    let mut contacts_to_update = HashSet::default();
3510
3511    let room_id;
3512    let canceled_calls_to_user_ids;
3513    let live_kit_room;
3514    let delete_live_kit_room;
3515    let room;
3516    let channel_members;
3517    let channel_id;
3518
3519    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3520        contacts_to_update.insert(session.user_id);
3521
3522        for project in left_room.left_projects.values() {
3523            project_left(project, session);
3524        }
3525
3526        room_id = RoomId::from_proto(left_room.room.id);
3527        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3528        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3529        delete_live_kit_room = left_room.deleted;
3530        room = mem::take(&mut left_room.room);
3531        channel_members = mem::take(&mut left_room.channel_members);
3532        channel_id = left_room.channel_id;
3533
3534        room_updated(&room, &session.peer);
3535    } else {
3536        return Ok(());
3537    }
3538
3539    if let Some(channel_id) = channel_id {
3540        channel_updated(
3541            channel_id,
3542            &room,
3543            &channel_members,
3544            &session.peer,
3545            &*session.connection_pool().await,
3546        );
3547    }
3548
3549    {
3550        let pool = session.connection_pool().await;
3551        for canceled_user_id in canceled_calls_to_user_ids {
3552            for connection_id in pool.user_connection_ids(canceled_user_id) {
3553                session
3554                    .peer
3555                    .send(
3556                        connection_id,
3557                        proto::CallCanceled {
3558                            room_id: room_id.to_proto(),
3559                        },
3560                    )
3561                    .trace_err();
3562            }
3563            contacts_to_update.insert(canceled_user_id);
3564        }
3565    }
3566
3567    for contact_user_id in contacts_to_update {
3568        update_user_contacts(contact_user_id, &session).await?;
3569    }
3570
3571    if let Some(live_kit) = session.live_kit_client.as_ref() {
3572        live_kit
3573            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3574            .await
3575            .trace_err();
3576
3577        if delete_live_kit_room {
3578            live_kit.delete_room(live_kit_room).await.trace_err();
3579        }
3580    }
3581
3582    Ok(())
3583}
3584
3585async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3586    let left_channel_buffers = session
3587        .db()
3588        .await
3589        .leave_channel_buffers(session.connection_id)
3590        .await?;
3591
3592    for left_buffer in left_channel_buffers {
3593        channel_buffer_updated(
3594            session.connection_id,
3595            left_buffer.connections,
3596            &proto::UpdateChannelBufferCollaborators {
3597                channel_id: left_buffer.channel_id.to_proto(),
3598                collaborators: left_buffer.collaborators,
3599            },
3600            &session.peer,
3601        );
3602    }
3603
3604    Ok(())
3605}
3606
3607fn project_left(project: &db::LeftProject, session: &Session) {
3608    for connection_id in &project.connection_ids {
3609        if project.host_user_id == session.user_id {
3610            session
3611                .peer
3612                .send(
3613                    *connection_id,
3614                    proto::UnshareProject {
3615                        project_id: project.id.to_proto(),
3616                    },
3617                )
3618                .trace_err();
3619        } else {
3620            session
3621                .peer
3622                .send(
3623                    *connection_id,
3624                    proto::RemoveProjectCollaborator {
3625                        project_id: project.id.to_proto(),
3626                        peer_id: Some(session.connection_id.into()),
3627                    },
3628                )
3629                .trace_err();
3630        }
3631    }
3632}
3633
3634pub trait ResultExt {
3635    type Ok;
3636
3637    fn trace_err(self) -> Option<Self::Ok>;
3638}
3639
3640impl<T, E> ResultExt for Result<T, E>
3641where
3642    E: std::fmt::Debug,
3643{
3644    type Ok = T;
3645
3646    fn trace_err(self) -> Option<T> {
3647        match self {
3648            Ok(value) => Some(value),
3649            Err(error) => {
3650                tracing::error!("{:?}", error);
3651                None
3652            }
3653        }
3654    }
3655}