collab: Attach GeoIP country code to RPC sessions (#15814)

Marshall Bowers created

This PR updates collab to attach the user's GeoIP country code to their
RPC session.

We source the country code from the
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

Release Notes:

- N/A

Change summary

crates/collab/src/api.rs               | 38 ++++++++++++++++++++++++++-
crates/collab/src/api/events.rs        | 30 +--------------------
crates/collab/src/rpc.rs               | 13 ++++++++
crates/collab/src/tests/test_server.rs |  2 +
4 files changed, 52 insertions(+), 31 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -14,7 +14,8 @@ use anyhow::anyhow;
 use axum::{
     body::Body,
     extract::{Path, Query},
-    http::{self, Request, StatusCode},
+    headers::Header,
+    http::{self, HeaderName, Request, StatusCode},
     middleware::{self, Next},
     response::IntoResponse,
     routing::{get, post},
@@ -22,11 +23,44 @@ use axum::{
 };
 use axum_extra::response::ErasedJson;
 use serde::{Deserialize, Serialize};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 use tower::ServiceBuilder;
 
 pub use extensions::fetch_extensions_from_blob_store_periodically;
 
+pub struct CloudflareIpCountryHeader(String);
+
+impl Header for CloudflareIpCountryHeader {
+    fn name() -> &'static HeaderName {
+        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
+        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
+    }
+
+    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
+    where
+        Self: Sized,
+        I: Iterator<Item = &'i axum::http::HeaderValue>,
+    {
+        let country_code = values
+            .next()
+            .ok_or_else(axum::headers::Error::invalid)?
+            .to_str()
+            .map_err(|_| axum::headers::Error::invalid())?;
+
+        Ok(Self(country_code.to_string()))
+    }
+
+    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
+        unimplemented!()
+    }
+}
+
+impl std::fmt::Display for CloudflareIpCountryHeader {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
     Router::new()
         .route("/user", get(get_authenticated_user))

crates/collab/src/api/events.rs 🔗

@@ -1,4 +1,5 @@
 use super::ips_file::IpsFile;
+use crate::api::CloudflareIpCountryHeader;
 use crate::{api::slack, AppState, Error, Result};
 use anyhow::{anyhow, Context};
 use aws_sdk_s3::primitives::ByteStream;
@@ -59,33 +60,6 @@ impl Header for ZedChecksumHeader {
     }
 }
 
-pub struct CloudflareIpCountryHeader(String);
-
-impl Header for CloudflareIpCountryHeader {
-    fn name() -> &'static HeaderName {
-        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
-        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
-    }
-
-    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
-    where
-        Self: Sized,
-        I: Iterator<Item = &'i axum::http::HeaderValue>,
-    {
-        let country_code = values
-            .next()
-            .ok_or_else(axum::headers::Error::invalid)?
-            .to_str()
-            .map_err(|_| axum::headers::Error::invalid())?;
-
-        Ok(Self(country_code.to_string()))
-    }
-
-    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
-        unimplemented!()
-    }
-}
-
 pub async fn post_crash(
     Extension(app): Extension<Arc<AppState>>,
     headers: HeaderMap,
@@ -413,7 +387,7 @@ pub async fn post_events(
     let Some(last_event) = request_body.events.last() else {
         return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
     };
-    let country_code = country_code_header.map(|h| h.0 .0);
+    let country_code = country_code_header.map(|h| h.to_string());
 
     let first_event_at = chrono::Utc::now()
         - chrono::Duration::milliseconds(last_event.milliseconds_since_first_event);

crates/collab/src/rpc.rs 🔗

@@ -1,5 +1,6 @@
 mod connection_pool;
 
+use crate::api::CloudflareIpCountryHeader;
 use crate::{
     auth,
     db::{
@@ -152,6 +153,9 @@ struct Session {
     supermaven_client: Option<Arc<SupermavenAdminApi>>,
     http_client: Arc<IsahcHttpClient>,
     rate_limiter: Arc<RateLimiter>,
+    /// The GeoIP country code for the user.
+    #[allow(unused)]
+    geoip_country_code: Option<String>,
     _executor: Executor,
 }
 
@@ -984,6 +988,7 @@ impl Server {
         address: String,
         principal: Principal,
         zed_version: ZedVersion,
+        geoip_country_code: Option<String>,
         send_connection_id: Option<oneshot::Sender<ConnectionId>>,
         executor: Executor,
     ) -> impl Future<Output = ()> {
@@ -1009,7 +1014,10 @@ impl Server {
                     let executor = executor.clone();
                     move |duration| executor.sleep(duration)
                 });
-            tracing::Span::current().record("connection_id", format!("{}", connection_id));
+            tracing::Span::current()
+                .record("connection_id", format!("{}", connection_id))
+                .record("geoip_country_code", geoip_country_code.clone());
+
             tracing::info!("connection opened");
 
             let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
@@ -1039,6 +1047,7 @@ impl Server {
                 live_kit_client: this.app_state.live_kit_client.clone(),
                 http_client,
                 rate_limiter: this.app_state.rate_limiter.clone(),
+                geoip_country_code,
                 _executor: executor.clone(),
                 supermaven_client,
             };
@@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request(
     ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
     Extension(server): Extension<Arc<Server>>,
     Extension(principal): Extension<Principal>,
+    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     ws: WebSocketUpgrade,
 ) -> axum::response::Response {
     if protocol_version != rpc::PROTOCOL_VERSION {
@@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request(
                     socket_address,
                     principal,
                     version,
+                    country_code_header.map(|header| header.to_string()),
                     None,
                     Executor::Production,
                 )

crates/collab/src/tests/test_server.rs 🔗

@@ -244,6 +244,7 @@ impl TestServer {
                                 client_name,
                                 Principal::User(user),
                                 ZedVersion(SemanticVersion::new(1, 0, 0)),
+                                None,
                                 Some(connection_id_tx),
                                 Executor::Deterministic(cx.background_executor().clone()),
                             ))
@@ -377,6 +378,7 @@ impl TestServer {
                                 "dev-server".to_string(),
                                 Principal::DevServer(dev_server),
                                 ZedVersion(SemanticVersion::new(1, 0, 0)),
+                                None,
                                 Some(connection_id_tx),
                                 Executor::Deterministic(cx.background_executor().clone()),
                             ))