Detailed changes
@@ -14,7 +14,8 @@ use anyhow::anyhow;
use axum::{
body::Body,
extract::{Path, Query},
- http::{self, Request, StatusCode},
+ headers::Header,
+ http::{self, HeaderName, Request, StatusCode},
middleware::{self, Next},
response::IntoResponse,
routing::{get, post},
@@ -22,11 +23,44 @@ use axum::{
};
use axum_extra::response::ErasedJson;
use serde::{Deserialize, Serialize};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
use tower::ServiceBuilder;
pub use extensions::fetch_extensions_from_blob_store_periodically;
+pub struct CloudflareIpCountryHeader(String);
+
+impl Header for CloudflareIpCountryHeader {
+ fn name() -> &'static HeaderName {
+ static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
+ CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
+ }
+
+ fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
+ where
+ Self: Sized,
+ I: Iterator<Item = &'i axum::http::HeaderValue>,
+ {
+ let country_code = values
+ .next()
+ .ok_or_else(axum::headers::Error::invalid)?
+ .to_str()
+ .map_err(|_| axum::headers::Error::invalid())?;
+
+ Ok(Self(country_code.to_string()))
+ }
+
+ fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
+ unimplemented!()
+ }
+}
+
+impl std::fmt::Display for CloudflareIpCountryHeader {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
Router::new()
.route("/user", get(get_authenticated_user))
@@ -1,4 +1,5 @@
use super::ips_file::IpsFile;
+use crate::api::CloudflareIpCountryHeader;
use crate::{api::slack, AppState, Error, Result};
use anyhow::{anyhow, Context};
use aws_sdk_s3::primitives::ByteStream;
@@ -59,33 +60,6 @@ impl Header for ZedChecksumHeader {
}
}
-pub struct CloudflareIpCountryHeader(String);
-
-impl Header for CloudflareIpCountryHeader {
- fn name() -> &'static HeaderName {
- static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
- CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
- }
-
- fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
- where
- Self: Sized,
- I: Iterator<Item = &'i axum::http::HeaderValue>,
- {
- let country_code = values
- .next()
- .ok_or_else(axum::headers::Error::invalid)?
- .to_str()
- .map_err(|_| axum::headers::Error::invalid())?;
-
- Ok(Self(country_code.to_string()))
- }
-
- fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
- unimplemented!()
- }
-}
-
pub async fn post_crash(
Extension(app): Extension<Arc<AppState>>,
headers: HeaderMap,
@@ -413,7 +387,7 @@ pub async fn post_events(
let Some(last_event) = request_body.events.last() else {
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
};
- let country_code = country_code_header.map(|h| h.0 .0);
+ let country_code = country_code_header.map(|h| h.to_string());
let first_event_at = chrono::Utc::now()
- chrono::Duration::milliseconds(last_event.milliseconds_since_first_event);
@@ -1,5 +1,6 @@
mod connection_pool;
+use crate::api::CloudflareIpCountryHeader;
use crate::{
auth,
db::{
@@ -152,6 +153,9 @@ struct Session {
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
+ /// The GeoIP country code for the user.
+ #[allow(unused)]
+ geoip_country_code: Option<String>,
_executor: Executor,
}
@@ -984,6 +988,7 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
+ geoip_country_code: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
@@ -1009,7 +1014,10 @@ impl Server {
let executor = executor.clone();
move |duration| executor.sleep(duration)
});
- tracing::Span::current().record("connection_id", format!("{}", connection_id));
+ tracing::Span::current()
+ .record("connection_id", format!("{}", connection_id))
+ .record("geoip_country_code", geoip_country_code.clone());
+
tracing::info!("connection opened");
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
@@ -1039,6 +1047,7 @@ impl Server {
live_kit_client: this.app_state.live_kit_client.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
+ geoip_country_code,
_executor: executor.clone(),
supermaven_client,
};
@@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
+ country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
@@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
+ country_code_header.map(|header| header.to_string()),
None,
Executor::Production,
)
@@ -244,6 +244,7 @@ impl TestServer {
client_name,
Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)),
+ None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
))
@@ -377,6 +378,7 @@ impl TestServer {
"dev-server".to_string(),
Principal::DevServer(dev_server),
ZedVersion(SemanticVersion::new(1, 0, 0)),
+ None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
))