1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use anyhow::Context as _;
15use axum::{
16 Extension, Json, Router,
17 body::Body,
18 extract::{Path, Query},
19 headers::Header,
20 http::{self, HeaderName, Request, StatusCode},
21 middleware::{self, Next},
22 response::IntoResponse,
23 routing::{get, post},
24};
25use axum_extra::response::ErasedJson;
26use serde::{Deserialize, Serialize};
27use std::sync::{Arc, OnceLock};
28use tower::ServiceBuilder;
29
30pub use extensions::fetch_extensions_from_blob_store_periodically;
31
32pub struct CloudflareIpCountryHeader(String);
33
34impl Header for CloudflareIpCountryHeader {
35 fn name() -> &'static HeaderName {
36 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
37 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
38 }
39
40 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
41 where
42 Self: Sized,
43 I: Iterator<Item = &'i axum::http::HeaderValue>,
44 {
45 let country_code = values
46 .next()
47 .ok_or_else(axum::headers::Error::invalid)?
48 .to_str()
49 .map_err(|_| axum::headers::Error::invalid())?;
50
51 Ok(Self(country_code.to_string()))
52 }
53
54 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
55 unimplemented!()
56 }
57}
58
59impl std::fmt::Display for CloudflareIpCountryHeader {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65pub struct SystemIdHeader(String);
66
67impl Header for SystemIdHeader {
68 fn name() -> &'static HeaderName {
69 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
70 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
71 }
72
73 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
74 where
75 Self: Sized,
76 I: Iterator<Item = &'i axum::http::HeaderValue>,
77 {
78 let system_id = values
79 .next()
80 .ok_or_else(axum::headers::Error::invalid)?
81 .to_str()
82 .map_err(|_| axum::headers::Error::invalid())?;
83
84 Ok(Self(system_id.to_string()))
85 }
86
87 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
88 unimplemented!()
89 }
90}
91
92impl std::fmt::Display for SystemIdHeader {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
98pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
99 Router::new()
100 .route("/user", get(update_or_create_authenticated_user))
101 .route("/users/look_up", get(look_up_user))
102 .route("/users/:id/access_tokens", post(create_access_token))
103 .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
104 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
105 .merge(billing::router())
106 .merge(contributors::router())
107 .layer(
108 ServiceBuilder::new()
109 .layer(Extension(rpc_server))
110 .layer(middleware::from_fn(validate_api_token)),
111 )
112}
113
114pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
115 let token = req
116 .headers()
117 .get(http::header::AUTHORIZATION)
118 .and_then(|header| header.to_str().ok())
119 .ok_or_else(|| {
120 Error::http(
121 StatusCode::BAD_REQUEST,
122 "missing authorization header".to_string(),
123 )
124 })?
125 .strip_prefix("token ")
126 .ok_or_else(|| {
127 Error::http(
128 StatusCode::BAD_REQUEST,
129 "invalid authorization header".to_string(),
130 )
131 })?;
132
133 let state = req.extensions().get::<Arc<AppState>>().unwrap();
134
135 if token != state.config.api_token {
136 Err(Error::http(
137 StatusCode::UNAUTHORIZED,
138 "invalid authorization token".to_string(),
139 ))?
140 }
141
142 Ok::<_, Error>(next.run(req).await)
143}
144
145#[derive(Debug, Deserialize)]
146struct AuthenticatedUserParams {
147 github_user_id: i32,
148 github_login: String,
149 github_email: Option<String>,
150 github_name: Option<String>,
151 github_user_created_at: chrono::DateTime<chrono::Utc>,
152}
153
154#[derive(Debug, Serialize)]
155struct AuthenticatedUserResponse {
156 user: User,
157 metrics_id: String,
158 feature_flags: Vec<String>,
159}
160
161async fn update_or_create_authenticated_user(
162 Query(params): Query<AuthenticatedUserParams>,
163 Extension(app): Extension<Arc<AppState>>,
164) -> Result<Json<AuthenticatedUserResponse>> {
165 let initial_channel_id = app.config.auto_join_channel_id;
166
167 let user = app
168 .db
169 .update_or_create_user_by_github_account(
170 ¶ms.github_login,
171 params.github_user_id,
172 params.github_email.as_deref(),
173 params.github_name.as_deref(),
174 params.github_user_created_at,
175 initial_channel_id,
176 )
177 .await?;
178 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
179 let feature_flags = app.db.get_user_flags(user.id).await?;
180 Ok(Json(AuthenticatedUserResponse {
181 user,
182 metrics_id,
183 feature_flags,
184 }))
185}
186
187#[derive(Debug, Deserialize)]
188struct LookUpUserParams {
189 identifier: String,
190}
191
192#[derive(Debug, Serialize)]
193struct LookUpUserResponse {
194 user: Option<User>,
195}
196
197async fn look_up_user(
198 Query(params): Query<LookUpUserParams>,
199 Extension(app): Extension<Arc<AppState>>,
200) -> Result<Json<LookUpUserResponse>> {
201 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
202 let user = if let Some(user) = user {
203 match user {
204 UserOrId::User(user) => Some(user),
205 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
206 }
207 } else {
208 None
209 };
210
211 Ok(Json(LookUpUserResponse { user }))
212}
213
214enum UserOrId {
215 User(User),
216 Id(UserId),
217}
218
219async fn resolve_identifier_to_user(
220 db: &Arc<Database>,
221 identifier: &str,
222) -> Result<Option<UserOrId>> {
223 if let Some(identifier) = identifier.parse::<i32>().ok() {
224 let user = db.get_user_by_id(UserId(identifier)).await?;
225
226 return Ok(user.map(UserOrId::User));
227 }
228
229 if identifier.starts_with("cus_") {
230 let billing_customer = db
231 .get_billing_customer_by_stripe_customer_id(&identifier)
232 .await?;
233
234 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
235 }
236
237 if identifier.starts_with("sub_") {
238 let billing_subscription = db
239 .get_billing_subscription_by_stripe_subscription_id(&identifier)
240 .await?;
241
242 if let Some(billing_subscription) = billing_subscription {
243 let billing_customer = db
244 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
245 .await?;
246
247 return Ok(
248 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
249 );
250 } else {
251 return Ok(None);
252 }
253 }
254
255 if identifier.contains('@') {
256 let user = db.get_user_by_email(identifier).await?;
257
258 return Ok(user.map(UserOrId::User));
259 }
260
261 if let Some(user) = db.get_user_by_github_login(identifier).await? {
262 return Ok(Some(UserOrId::User(user)));
263 }
264
265 Ok(None)
266}
267
268#[derive(Deserialize, Debug)]
269struct CreateUserParams {
270 github_user_id: i32,
271 github_login: String,
272 email_address: String,
273 email_confirmation_code: Option<String>,
274 #[serde(default)]
275 admin: bool,
276 #[serde(default)]
277 invite_count: i32,
278}
279
280async fn get_rpc_server_snapshot(
281 Extension(rpc_server): Extension<Arc<rpc::Server>>,
282) -> Result<ErasedJson> {
283 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
284}
285
286#[derive(Deserialize)]
287struct CreateAccessTokenQueryParams {
288 public_key: String,
289 impersonate: Option<String>,
290}
291
292#[derive(Serialize)]
293struct CreateAccessTokenResponse {
294 user_id: UserId,
295 encrypted_access_token: String,
296}
297
298async fn create_access_token(
299 Path(user_id): Path<UserId>,
300 Query(params): Query<CreateAccessTokenQueryParams>,
301 Extension(app): Extension<Arc<AppState>>,
302) -> Result<Json<CreateAccessTokenResponse>> {
303 let user = app
304 .db
305 .get_user_by_id(user_id)
306 .await?
307 .context("user not found")?;
308
309 let mut impersonated_user_id = None;
310 if let Some(impersonate) = params.impersonate {
311 if user.admin {
312 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
313 impersonated_user_id = Some(impersonated_user.id);
314 } else {
315 return Err(Error::http(
316 StatusCode::UNPROCESSABLE_ENTITY,
317 format!("user {impersonate} does not exist"),
318 ));
319 }
320 } else {
321 return Err(Error::http(
322 StatusCode::UNAUTHORIZED,
323 "you do not have permission to impersonate other users".to_string(),
324 ));
325 }
326 }
327
328 let access_token =
329 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
330 let encrypted_access_token =
331 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
332
333 Ok(Json(CreateAccessTokenResponse {
334 user_id: impersonated_user_id.unwrap_or(user_id),
335 encrypted_access_token,
336 }))
337}
338
339#[derive(Serialize)]
340struct RefreshLlmTokensResponse {}
341
342async fn refresh_llm_tokens(
343 Path(user_id): Path<UserId>,
344 Extension(rpc_server): Extension<Arc<rpc::Server>>,
345) -> Result<Json<RefreshLlmTokensResponse>> {
346 rpc_server.refresh_llm_tokens_for_user(user_id).await;
347
348 Ok(Json(RefreshLlmTokensResponse {}))
349}