1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use ::rpc::proto;
15use anyhow::Context as _;
16use axum::extract;
17use axum::{
18 Extension, Json, Router,
19 body::Body,
20 extract::{Path, Query},
21 headers::Header,
22 http::{self, HeaderName, Request, StatusCode},
23 middleware::{self, Next},
24 response::IntoResponse,
25 routing::{get, post},
26};
27use axum_extra::response::ErasedJson;
28use chrono::{DateTime, Utc};
29use serde::{Deserialize, Serialize};
30use std::sync::{Arc, OnceLock};
31use tower::ServiceBuilder;
32
33pub use extensions::fetch_extensions_from_blob_store_periodically;
34
35pub struct CloudflareIpCountryHeader(String);
36
37impl Header for CloudflareIpCountryHeader {
38 fn name() -> &'static HeaderName {
39 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
40 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
41 }
42
43 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
44 where
45 Self: Sized,
46 I: Iterator<Item = &'i axum::http::HeaderValue>,
47 {
48 let country_code = values
49 .next()
50 .ok_or_else(axum::headers::Error::invalid)?
51 .to_str()
52 .map_err(|_| axum::headers::Error::invalid())?;
53
54 Ok(Self(country_code.to_string()))
55 }
56
57 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
58 unimplemented!()
59 }
60}
61
62impl std::fmt::Display for CloudflareIpCountryHeader {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68pub struct SystemIdHeader(String);
69
70impl Header for SystemIdHeader {
71 fn name() -> &'static HeaderName {
72 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
73 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
74 }
75
76 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
77 where
78 Self: Sized,
79 I: Iterator<Item = &'i axum::http::HeaderValue>,
80 {
81 let system_id = values
82 .next()
83 .ok_or_else(axum::headers::Error::invalid)?
84 .to_str()
85 .map_err(|_| axum::headers::Error::invalid())?;
86
87 Ok(Self(system_id.to_string()))
88 }
89
90 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
91 unimplemented!()
92 }
93}
94
95impl std::fmt::Display for SystemIdHeader {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.0)
98 }
99}
100
101pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
102 Router::new()
103 .route("/users/look_up", get(look_up_user))
104 .route("/users/:id/access_tokens", post(create_access_token))
105 .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
106 .route("/users/:id/update_plan", post(update_plan))
107 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
108 .merge(contributors::router())
109 .layer(
110 ServiceBuilder::new()
111 .layer(Extension(rpc_server))
112 .layer(middleware::from_fn(validate_api_token)),
113 )
114}
115
116pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
117 let token = req
118 .headers()
119 .get(http::header::AUTHORIZATION)
120 .and_then(|header| header.to_str().ok())
121 .ok_or_else(|| {
122 Error::http(
123 StatusCode::BAD_REQUEST,
124 "missing authorization header".to_string(),
125 )
126 })?
127 .strip_prefix("token ")
128 .ok_or_else(|| {
129 Error::http(
130 StatusCode::BAD_REQUEST,
131 "invalid authorization header".to_string(),
132 )
133 })?;
134
135 let state = req.extensions().get::<Arc<AppState>>().unwrap();
136
137 if token != state.config.api_token {
138 Err(Error::http(
139 StatusCode::UNAUTHORIZED,
140 "invalid authorization token".to_string(),
141 ))?
142 }
143
144 Ok::<_, Error>(next.run(req).await)
145}
146
147#[derive(Debug, Deserialize)]
148struct LookUpUserParams {
149 identifier: String,
150}
151
152#[derive(Debug, Serialize)]
153struct LookUpUserResponse {
154 user: Option<User>,
155}
156
157async fn look_up_user(
158 Query(params): Query<LookUpUserParams>,
159 Extension(app): Extension<Arc<AppState>>,
160) -> Result<Json<LookUpUserResponse>> {
161 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
162 let user = if let Some(user) = user {
163 match user {
164 UserOrId::User(user) => Some(user),
165 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
166 }
167 } else {
168 None
169 };
170
171 Ok(Json(LookUpUserResponse { user }))
172}
173
174enum UserOrId {
175 User(User),
176 Id(UserId),
177}
178
179async fn resolve_identifier_to_user(
180 db: &Arc<Database>,
181 identifier: &str,
182) -> Result<Option<UserOrId>> {
183 if let Some(identifier) = identifier.parse::<i32>().ok() {
184 let user = db.get_user_by_id(UserId(identifier)).await?;
185
186 return Ok(user.map(UserOrId::User));
187 }
188
189 if identifier.starts_with("cus_") {
190 let billing_customer = db
191 .get_billing_customer_by_stripe_customer_id(&identifier)
192 .await?;
193
194 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
195 }
196
197 if identifier.starts_with("sub_") {
198 let billing_subscription = db
199 .get_billing_subscription_by_stripe_subscription_id(&identifier)
200 .await?;
201
202 if let Some(billing_subscription) = billing_subscription {
203 let billing_customer = db
204 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
205 .await?;
206
207 return Ok(
208 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
209 );
210 } else {
211 return Ok(None);
212 }
213 }
214
215 if identifier.contains('@') {
216 let user = db.get_user_by_email(identifier).await?;
217
218 return Ok(user.map(UserOrId::User));
219 }
220
221 if let Some(user) = db.get_user_by_github_login(identifier).await? {
222 return Ok(Some(UserOrId::User(user)));
223 }
224
225 Ok(None)
226}
227
228#[derive(Deserialize, Debug)]
229struct CreateUserParams {
230 github_user_id: i32,
231 github_login: String,
232 email_address: String,
233 email_confirmation_code: Option<String>,
234 #[serde(default)]
235 admin: bool,
236 #[serde(default)]
237 invite_count: i32,
238}
239
240async fn get_rpc_server_snapshot(
241 Extension(rpc_server): Extension<Arc<rpc::Server>>,
242) -> Result<ErasedJson> {
243 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
244}
245
246#[derive(Deserialize)]
247struct CreateAccessTokenQueryParams {
248 public_key: String,
249 impersonate: Option<String>,
250}
251
252#[derive(Serialize)]
253struct CreateAccessTokenResponse {
254 user_id: UserId,
255 encrypted_access_token: String,
256}
257
258async fn create_access_token(
259 Path(user_id): Path<UserId>,
260 Query(params): Query<CreateAccessTokenQueryParams>,
261 Extension(app): Extension<Arc<AppState>>,
262) -> Result<Json<CreateAccessTokenResponse>> {
263 let user = app
264 .db
265 .get_user_by_id(user_id)
266 .await?
267 .context("user not found")?;
268
269 let mut impersonated_user_id = None;
270 if let Some(impersonate) = params.impersonate {
271 if user.admin {
272 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
273 impersonated_user_id = Some(impersonated_user.id);
274 } else {
275 return Err(Error::http(
276 StatusCode::UNPROCESSABLE_ENTITY,
277 format!("user {impersonate} does not exist"),
278 ));
279 }
280 } else {
281 return Err(Error::http(
282 StatusCode::UNAUTHORIZED,
283 "you do not have permission to impersonate other users".to_string(),
284 ));
285 }
286 }
287
288 let access_token =
289 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
290 let encrypted_access_token =
291 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
292
293 Ok(Json(CreateAccessTokenResponse {
294 user_id: impersonated_user_id.unwrap_or(user_id),
295 encrypted_access_token,
296 }))
297}
298
299#[derive(Serialize)]
300struct RefreshLlmTokensResponse {}
301
302async fn refresh_llm_tokens(
303 Path(user_id): Path<UserId>,
304 Extension(rpc_server): Extension<Arc<rpc::Server>>,
305) -> Result<Json<RefreshLlmTokensResponse>> {
306 rpc_server.refresh_llm_tokens_for_user(user_id).await;
307
308 Ok(Json(RefreshLlmTokensResponse {}))
309}
310
311#[derive(Debug, Serialize, Deserialize)]
312struct UpdatePlanBody {
313 pub plan: cloud_llm_client::Plan,
314 pub subscription_period: SubscriptionPeriod,
315 pub usage: cloud_llm_client::CurrentUsage,
316 pub trial_started_at: Option<DateTime<Utc>>,
317 pub is_usage_based_billing_enabled: bool,
318 pub is_account_too_young: bool,
319 pub has_overdue_invoices: bool,
320}
321
322#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
323struct SubscriptionPeriod {
324 pub started_at: DateTime<Utc>,
325 pub ended_at: DateTime<Utc>,
326}
327
328#[derive(Serialize)]
329struct UpdatePlanResponse {}
330
331async fn update_plan(
332 Path(user_id): Path<UserId>,
333 Extension(rpc_server): Extension<Arc<rpc::Server>>,
334 extract::Json(body): extract::Json<UpdatePlanBody>,
335) -> Result<Json<UpdatePlanResponse>> {
336 let plan = match body.plan {
337 cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
338 cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
339 cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
340 };
341
342 let update_user_plan = proto::UpdateUserPlan {
343 plan: plan.into(),
344 trial_started_at: body
345 .trial_started_at
346 .map(|trial_started_at| trial_started_at.timestamp() as u64),
347 is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
348 usage: Some(proto::SubscriptionUsage {
349 model_requests_usage_amount: body.usage.model_requests.used,
350 model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
351 edit_predictions_usage_amount: body.usage.edit_predictions.used,
352 edit_predictions_usage_limit: Some(usage_limit_to_proto(
353 body.usage.edit_predictions.limit,
354 )),
355 }),
356 subscription_period: Some(proto::SubscriptionPeriod {
357 started_at: body.subscription_period.started_at.timestamp() as u64,
358 ended_at: body.subscription_period.ended_at.timestamp() as u64,
359 }),
360 account_too_young: Some(body.is_account_too_young),
361 has_overdue_invoices: Some(body.has_overdue_invoices),
362 };
363
364 rpc_server
365 .update_plan_for_user(user_id, update_user_plan)
366 .await?;
367
368 Ok(Json(UpdatePlanResponse {}))
369}
370
371fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
372 proto::UsageLimit {
373 variant: Some(match limit {
374 cloud_llm_client::UsageLimit::Limited(limit) => {
375 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
376 limit: limit as u32,
377 })
378 }
379 cloud_llm_client::UsageLimit::Unlimited => {
380 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
381 }
382 }),
383 }
384}