1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use ::rpc::proto;
15use anyhow::Context as _;
16use axum::extract;
17use axum::{
18 Extension, Json, Router,
19 body::Body,
20 extract::{Path, Query},
21 headers::Header,
22 http::{self, HeaderName, Request, StatusCode},
23 middleware::{self, Next},
24 response::IntoResponse,
25 routing::{get, post},
26};
27use axum_extra::response::ErasedJson;
28use chrono::{DateTime, Utc};
29use serde::{Deserialize, Serialize};
30use std::sync::{Arc, OnceLock};
31use tower::ServiceBuilder;
32
33pub use extensions::fetch_extensions_from_blob_store_periodically;
34
35pub struct CloudflareIpCountryHeader(String);
36
37impl Header for CloudflareIpCountryHeader {
38 fn name() -> &'static HeaderName {
39 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
40 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
41 }
42
43 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
44 where
45 Self: Sized,
46 I: Iterator<Item = &'i axum::http::HeaderValue>,
47 {
48 let country_code = values
49 .next()
50 .ok_or_else(axum::headers::Error::invalid)?
51 .to_str()
52 .map_err(|_| axum::headers::Error::invalid())?;
53
54 Ok(Self(country_code.to_string()))
55 }
56
57 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
58 unimplemented!()
59 }
60}
61
62impl std::fmt::Display for CloudflareIpCountryHeader {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68pub struct SystemIdHeader(String);
69
70impl Header for SystemIdHeader {
71 fn name() -> &'static HeaderName {
72 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
73 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
74 }
75
76 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
77 where
78 Self: Sized,
79 I: Iterator<Item = &'i axum::http::HeaderValue>,
80 {
81 let system_id = values
82 .next()
83 .ok_or_else(axum::headers::Error::invalid)?
84 .to_str()
85 .map_err(|_| axum::headers::Error::invalid())?;
86
87 Ok(Self(system_id.to_string()))
88 }
89
90 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
91 unimplemented!()
92 }
93}
94
95impl std::fmt::Display for SystemIdHeader {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.0)
98 }
99}
100
101pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
102 Router::new()
103 .route("/user", get(update_or_create_authenticated_user))
104 .route("/users/look_up", get(look_up_user))
105 .route("/users/:id/access_tokens", post(create_access_token))
106 .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
107 .route("/users/:id/update_plan", post(update_plan))
108 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
109 .merge(billing::router())
110 .merge(contributors::router())
111 .layer(
112 ServiceBuilder::new()
113 .layer(Extension(rpc_server))
114 .layer(middleware::from_fn(validate_api_token)),
115 )
116}
117
118pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
119 let token = req
120 .headers()
121 .get(http::header::AUTHORIZATION)
122 .and_then(|header| header.to_str().ok())
123 .ok_or_else(|| {
124 Error::http(
125 StatusCode::BAD_REQUEST,
126 "missing authorization header".to_string(),
127 )
128 })?
129 .strip_prefix("token ")
130 .ok_or_else(|| {
131 Error::http(
132 StatusCode::BAD_REQUEST,
133 "invalid authorization header".to_string(),
134 )
135 })?;
136
137 let state = req.extensions().get::<Arc<AppState>>().unwrap();
138
139 if token != state.config.api_token {
140 Err(Error::http(
141 StatusCode::UNAUTHORIZED,
142 "invalid authorization token".to_string(),
143 ))?
144 }
145
146 Ok::<_, Error>(next.run(req).await)
147}
148
149#[derive(Debug, Deserialize)]
150struct AuthenticatedUserParams {
151 github_user_id: i32,
152 github_login: String,
153 github_email: Option<String>,
154 github_name: Option<String>,
155 github_user_created_at: chrono::DateTime<chrono::Utc>,
156}
157
158#[derive(Debug, Serialize)]
159struct AuthenticatedUserResponse {
160 user: User,
161 metrics_id: String,
162 feature_flags: Vec<String>,
163}
164
165async fn update_or_create_authenticated_user(
166 Query(params): Query<AuthenticatedUserParams>,
167 Extension(app): Extension<Arc<AppState>>,
168) -> Result<Json<AuthenticatedUserResponse>> {
169 let initial_channel_id = app.config.auto_join_channel_id;
170
171 let user = app
172 .db
173 .update_or_create_user_by_github_account(
174 ¶ms.github_login,
175 params.github_user_id,
176 params.github_email.as_deref(),
177 params.github_name.as_deref(),
178 params.github_user_created_at,
179 initial_channel_id,
180 )
181 .await?;
182 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
183 let feature_flags = app.db.get_user_flags(user.id).await?;
184 Ok(Json(AuthenticatedUserResponse {
185 user,
186 metrics_id,
187 feature_flags,
188 }))
189}
190
191#[derive(Debug, Deserialize)]
192struct LookUpUserParams {
193 identifier: String,
194}
195
196#[derive(Debug, Serialize)]
197struct LookUpUserResponse {
198 user: Option<User>,
199}
200
201async fn look_up_user(
202 Query(params): Query<LookUpUserParams>,
203 Extension(app): Extension<Arc<AppState>>,
204) -> Result<Json<LookUpUserResponse>> {
205 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
206 let user = if let Some(user) = user {
207 match user {
208 UserOrId::User(user) => Some(user),
209 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
210 }
211 } else {
212 None
213 };
214
215 Ok(Json(LookUpUserResponse { user }))
216}
217
218enum UserOrId {
219 User(User),
220 Id(UserId),
221}
222
223async fn resolve_identifier_to_user(
224 db: &Arc<Database>,
225 identifier: &str,
226) -> Result<Option<UserOrId>> {
227 if let Some(identifier) = identifier.parse::<i32>().ok() {
228 let user = db.get_user_by_id(UserId(identifier)).await?;
229
230 return Ok(user.map(UserOrId::User));
231 }
232
233 if identifier.starts_with("cus_") {
234 let billing_customer = db
235 .get_billing_customer_by_stripe_customer_id(&identifier)
236 .await?;
237
238 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
239 }
240
241 if identifier.starts_with("sub_") {
242 let billing_subscription = db
243 .get_billing_subscription_by_stripe_subscription_id(&identifier)
244 .await?;
245
246 if let Some(billing_subscription) = billing_subscription {
247 let billing_customer = db
248 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
249 .await?;
250
251 return Ok(
252 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
253 );
254 } else {
255 return Ok(None);
256 }
257 }
258
259 if identifier.contains('@') {
260 let user = db.get_user_by_email(identifier).await?;
261
262 return Ok(user.map(UserOrId::User));
263 }
264
265 if let Some(user) = db.get_user_by_github_login(identifier).await? {
266 return Ok(Some(UserOrId::User(user)));
267 }
268
269 Ok(None)
270}
271
272#[derive(Deserialize, Debug)]
273struct CreateUserParams {
274 github_user_id: i32,
275 github_login: String,
276 email_address: String,
277 email_confirmation_code: Option<String>,
278 #[serde(default)]
279 admin: bool,
280 #[serde(default)]
281 invite_count: i32,
282}
283
284async fn get_rpc_server_snapshot(
285 Extension(rpc_server): Extension<Arc<rpc::Server>>,
286) -> Result<ErasedJson> {
287 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
288}
289
290#[derive(Deserialize)]
291struct CreateAccessTokenQueryParams {
292 public_key: String,
293 impersonate: Option<String>,
294}
295
296#[derive(Serialize)]
297struct CreateAccessTokenResponse {
298 user_id: UserId,
299 encrypted_access_token: String,
300}
301
302async fn create_access_token(
303 Path(user_id): Path<UserId>,
304 Query(params): Query<CreateAccessTokenQueryParams>,
305 Extension(app): Extension<Arc<AppState>>,
306) -> Result<Json<CreateAccessTokenResponse>> {
307 let user = app
308 .db
309 .get_user_by_id(user_id)
310 .await?
311 .context("user not found")?;
312
313 let mut impersonated_user_id = None;
314 if let Some(impersonate) = params.impersonate {
315 if user.admin {
316 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
317 impersonated_user_id = Some(impersonated_user.id);
318 } else {
319 return Err(Error::http(
320 StatusCode::UNPROCESSABLE_ENTITY,
321 format!("user {impersonate} does not exist"),
322 ));
323 }
324 } else {
325 return Err(Error::http(
326 StatusCode::UNAUTHORIZED,
327 "you do not have permission to impersonate other users".to_string(),
328 ));
329 }
330 }
331
332 let access_token =
333 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
334 let encrypted_access_token =
335 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
336
337 Ok(Json(CreateAccessTokenResponse {
338 user_id: impersonated_user_id.unwrap_or(user_id),
339 encrypted_access_token,
340 }))
341}
342
343#[derive(Serialize)]
344struct RefreshLlmTokensResponse {}
345
346async fn refresh_llm_tokens(
347 Path(user_id): Path<UserId>,
348 Extension(rpc_server): Extension<Arc<rpc::Server>>,
349) -> Result<Json<RefreshLlmTokensResponse>> {
350 rpc_server.refresh_llm_tokens_for_user(user_id).await;
351
352 Ok(Json(RefreshLlmTokensResponse {}))
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356struct UpdatePlanBody {
357 pub plan: zed_llm_client::Plan,
358 pub subscription_period: SubscriptionPeriod,
359 pub usage: zed_llm_client::CurrentUsage,
360 pub trial_started_at: Option<DateTime<Utc>>,
361 pub is_usage_based_billing_enabled: bool,
362 pub is_account_too_young: bool,
363 pub has_overdue_invoices: bool,
364}
365
366#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
367struct SubscriptionPeriod {
368 pub started_at: DateTime<Utc>,
369 pub ended_at: DateTime<Utc>,
370}
371
372#[derive(Serialize)]
373struct UpdatePlanResponse {}
374
375async fn update_plan(
376 Path(user_id): Path<UserId>,
377 Extension(rpc_server): Extension<Arc<rpc::Server>>,
378 extract::Json(body): extract::Json<UpdatePlanBody>,
379) -> Result<Json<UpdatePlanResponse>> {
380 let plan = match body.plan {
381 zed_llm_client::Plan::ZedFree => proto::Plan::Free,
382 zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
383 zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
384 };
385
386 let update_user_plan = proto::UpdateUserPlan {
387 plan: plan.into(),
388 trial_started_at: body
389 .trial_started_at
390 .map(|trial_started_at| trial_started_at.timestamp() as u64),
391 is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
392 usage: Some(proto::SubscriptionUsage {
393 model_requests_usage_amount: body.usage.model_requests.used,
394 model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
395 edit_predictions_usage_amount: body.usage.edit_predictions.used,
396 edit_predictions_usage_limit: Some(usage_limit_to_proto(
397 body.usage.edit_predictions.limit,
398 )),
399 }),
400 subscription_period: Some(proto::SubscriptionPeriod {
401 started_at: body.subscription_period.started_at.timestamp() as u64,
402 ended_at: body.subscription_period.ended_at.timestamp() as u64,
403 }),
404 account_too_young: Some(body.is_account_too_young),
405 has_overdue_invoices: Some(body.has_overdue_invoices),
406 };
407
408 rpc_server
409 .update_plan_for_user(user_id, update_user_plan)
410 .await?;
411
412 Ok(Json(UpdatePlanResponse {}))
413}
414
415fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
416 proto::UsageLimit {
417 variant: Some(match limit {
418 zed_llm_client::UsageLimit::Limited(limit) => {
419 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
420 limit: limit as u32,
421 })
422 }
423 zed_llm_client::UsageLimit::Unlimited => {
424 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
425 }
426 }),
427 }
428}