1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use ::rpc::proto;
15use anyhow::Context as _;
16use axum::extract;
17use axum::{
18 Extension, Json, Router,
19 body::Body,
20 extract::{Path, Query},
21 headers::Header,
22 http::{self, HeaderName, Request, StatusCode},
23 middleware::{self, Next},
24 response::IntoResponse,
25 routing::{get, post},
26};
27use axum_extra::response::ErasedJson;
28use chrono::{DateTime, Utc};
29use serde::{Deserialize, Serialize};
30use std::sync::{Arc, OnceLock};
31use tower::ServiceBuilder;
32
33pub use extensions::fetch_extensions_from_blob_store_periodically;
34
35pub struct CloudflareIpCountryHeader(String);
36
37impl Header for CloudflareIpCountryHeader {
38 fn name() -> &'static HeaderName {
39 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
40 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
41 }
42
43 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
44 where
45 Self: Sized,
46 I: Iterator<Item = &'i axum::http::HeaderValue>,
47 {
48 let country_code = values
49 .next()
50 .ok_or_else(axum::headers::Error::invalid)?
51 .to_str()
52 .map_err(|_| axum::headers::Error::invalid())?;
53
54 Ok(Self(country_code.to_string()))
55 }
56
57 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
58 unimplemented!()
59 }
60}
61
62impl std::fmt::Display for CloudflareIpCountryHeader {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68pub struct SystemIdHeader(String);
69
70impl Header for SystemIdHeader {
71 fn name() -> &'static HeaderName {
72 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
73 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
74 }
75
76 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
77 where
78 Self: Sized,
79 I: Iterator<Item = &'i axum::http::HeaderValue>,
80 {
81 let system_id = values
82 .next()
83 .ok_or_else(axum::headers::Error::invalid)?
84 .to_str()
85 .map_err(|_| axum::headers::Error::invalid())?;
86
87 Ok(Self(system_id.to_string()))
88 }
89
90 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
91 unimplemented!()
92 }
93}
94
95impl std::fmt::Display for SystemIdHeader {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.0)
98 }
99}
100
101pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
102 Router::new()
103 .route("/user", get(legacy_update_or_create_authenticated_user))
104 .route("/users/look_up", get(look_up_user))
105 .route("/users/:id/access_tokens", post(create_access_token))
106 .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
107 .route("/users/:id/update_plan", post(update_plan))
108 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
109 .merge(contributors::router())
110 .layer(
111 ServiceBuilder::new()
112 .layer(Extension(rpc_server))
113 .layer(middleware::from_fn(validate_api_token)),
114 )
115}
116
117pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
118 let token = req
119 .headers()
120 .get(http::header::AUTHORIZATION)
121 .and_then(|header| header.to_str().ok())
122 .ok_or_else(|| {
123 Error::http(
124 StatusCode::BAD_REQUEST,
125 "missing authorization header".to_string(),
126 )
127 })?
128 .strip_prefix("token ")
129 .ok_or_else(|| {
130 Error::http(
131 StatusCode::BAD_REQUEST,
132 "invalid authorization header".to_string(),
133 )
134 })?;
135
136 let state = req.extensions().get::<Arc<AppState>>().unwrap();
137
138 if token != state.config.api_token {
139 Err(Error::http(
140 StatusCode::UNAUTHORIZED,
141 "invalid authorization token".to_string(),
142 ))?
143 }
144
145 Ok::<_, Error>(next.run(req).await)
146}
147
148#[derive(Debug, Deserialize)]
149struct AuthenticatedUserParams {
150 github_user_id: i32,
151 github_login: String,
152 github_email: Option<String>,
153 github_name: Option<String>,
154 github_user_created_at: chrono::DateTime<chrono::Utc>,
155}
156
157#[derive(Debug, Serialize)]
158struct AuthenticatedUserResponse {
159 user: User,
160 metrics_id: String,
161 feature_flags: Vec<String>,
162}
163
164/// This is a legacy endpoint that is no longer used in production.
165///
166/// It currently only exists to be used when developing Collab locally.
167async fn legacy_update_or_create_authenticated_user(
168 Query(params): Query<AuthenticatedUserParams>,
169 Extension(app): Extension<Arc<AppState>>,
170) -> Result<Json<AuthenticatedUserResponse>> {
171 let initial_channel_id = app.config.auto_join_channel_id;
172
173 let user = app
174 .db
175 .update_or_create_user_by_github_account(
176 ¶ms.github_login,
177 params.github_user_id,
178 params.github_email.as_deref(),
179 params.github_name.as_deref(),
180 params.github_user_created_at,
181 initial_channel_id,
182 )
183 .await?;
184 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
185 let feature_flags = app.db.get_user_flags(user.id).await?;
186 Ok(Json(AuthenticatedUserResponse {
187 user,
188 metrics_id,
189 feature_flags,
190 }))
191}
192
193#[derive(Debug, Deserialize)]
194struct LookUpUserParams {
195 identifier: String,
196}
197
198#[derive(Debug, Serialize)]
199struct LookUpUserResponse {
200 user: Option<User>,
201}
202
203async fn look_up_user(
204 Query(params): Query<LookUpUserParams>,
205 Extension(app): Extension<Arc<AppState>>,
206) -> Result<Json<LookUpUserResponse>> {
207 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
208 let user = if let Some(user) = user {
209 match user {
210 UserOrId::User(user) => Some(user),
211 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
212 }
213 } else {
214 None
215 };
216
217 Ok(Json(LookUpUserResponse { user }))
218}
219
220enum UserOrId {
221 User(User),
222 Id(UserId),
223}
224
225async fn resolve_identifier_to_user(
226 db: &Arc<Database>,
227 identifier: &str,
228) -> Result<Option<UserOrId>> {
229 if let Some(identifier) = identifier.parse::<i32>().ok() {
230 let user = db.get_user_by_id(UserId(identifier)).await?;
231
232 return Ok(user.map(UserOrId::User));
233 }
234
235 if identifier.starts_with("cus_") {
236 let billing_customer = db
237 .get_billing_customer_by_stripe_customer_id(&identifier)
238 .await?;
239
240 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
241 }
242
243 if identifier.starts_with("sub_") {
244 let billing_subscription = db
245 .get_billing_subscription_by_stripe_subscription_id(&identifier)
246 .await?;
247
248 if let Some(billing_subscription) = billing_subscription {
249 let billing_customer = db
250 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
251 .await?;
252
253 return Ok(
254 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
255 );
256 } else {
257 return Ok(None);
258 }
259 }
260
261 if identifier.contains('@') {
262 let user = db.get_user_by_email(identifier).await?;
263
264 return Ok(user.map(UserOrId::User));
265 }
266
267 if let Some(user) = db.get_user_by_github_login(identifier).await? {
268 return Ok(Some(UserOrId::User(user)));
269 }
270
271 Ok(None)
272}
273
274#[derive(Deserialize, Debug)]
275struct CreateUserParams {
276 github_user_id: i32,
277 github_login: String,
278 email_address: String,
279 email_confirmation_code: Option<String>,
280 #[serde(default)]
281 admin: bool,
282 #[serde(default)]
283 invite_count: i32,
284}
285
286async fn get_rpc_server_snapshot(
287 Extension(rpc_server): Extension<Arc<rpc::Server>>,
288) -> Result<ErasedJson> {
289 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
290}
291
292#[derive(Deserialize)]
293struct CreateAccessTokenQueryParams {
294 public_key: String,
295 impersonate: Option<String>,
296}
297
298#[derive(Serialize)]
299struct CreateAccessTokenResponse {
300 user_id: UserId,
301 encrypted_access_token: String,
302}
303
304async fn create_access_token(
305 Path(user_id): Path<UserId>,
306 Query(params): Query<CreateAccessTokenQueryParams>,
307 Extension(app): Extension<Arc<AppState>>,
308) -> Result<Json<CreateAccessTokenResponse>> {
309 let user = app
310 .db
311 .get_user_by_id(user_id)
312 .await?
313 .context("user not found")?;
314
315 let mut impersonated_user_id = None;
316 if let Some(impersonate) = params.impersonate {
317 if user.admin {
318 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
319 impersonated_user_id = Some(impersonated_user.id);
320 } else {
321 return Err(Error::http(
322 StatusCode::UNPROCESSABLE_ENTITY,
323 format!("user {impersonate} does not exist"),
324 ));
325 }
326 } else {
327 return Err(Error::http(
328 StatusCode::UNAUTHORIZED,
329 "you do not have permission to impersonate other users".to_string(),
330 ));
331 }
332 }
333
334 let access_token =
335 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
336 let encrypted_access_token =
337 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
338
339 Ok(Json(CreateAccessTokenResponse {
340 user_id: impersonated_user_id.unwrap_or(user_id),
341 encrypted_access_token,
342 }))
343}
344
345#[derive(Serialize)]
346struct RefreshLlmTokensResponse {}
347
348async fn refresh_llm_tokens(
349 Path(user_id): Path<UserId>,
350 Extension(rpc_server): Extension<Arc<rpc::Server>>,
351) -> Result<Json<RefreshLlmTokensResponse>> {
352 rpc_server.refresh_llm_tokens_for_user(user_id).await;
353
354 Ok(Json(RefreshLlmTokensResponse {}))
355}
356
357#[derive(Debug, Serialize, Deserialize)]
358struct UpdatePlanBody {
359 pub plan: cloud_llm_client::Plan,
360 pub subscription_period: SubscriptionPeriod,
361 pub usage: cloud_llm_client::CurrentUsage,
362 pub trial_started_at: Option<DateTime<Utc>>,
363 pub is_usage_based_billing_enabled: bool,
364 pub is_account_too_young: bool,
365 pub has_overdue_invoices: bool,
366}
367
368#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
369struct SubscriptionPeriod {
370 pub started_at: DateTime<Utc>,
371 pub ended_at: DateTime<Utc>,
372}
373
374#[derive(Serialize)]
375struct UpdatePlanResponse {}
376
377async fn update_plan(
378 Path(user_id): Path<UserId>,
379 Extension(rpc_server): Extension<Arc<rpc::Server>>,
380 extract::Json(body): extract::Json<UpdatePlanBody>,
381) -> Result<Json<UpdatePlanResponse>> {
382 let plan = match body.plan {
383 cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
384 cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
385 cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
386 };
387
388 let update_user_plan = proto::UpdateUserPlan {
389 plan: plan.into(),
390 trial_started_at: body
391 .trial_started_at
392 .map(|trial_started_at| trial_started_at.timestamp() as u64),
393 is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
394 usage: Some(proto::SubscriptionUsage {
395 model_requests_usage_amount: body.usage.model_requests.used,
396 model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
397 edit_predictions_usage_amount: body.usage.edit_predictions.used,
398 edit_predictions_usage_limit: Some(usage_limit_to_proto(
399 body.usage.edit_predictions.limit,
400 )),
401 }),
402 subscription_period: Some(proto::SubscriptionPeriod {
403 started_at: body.subscription_period.started_at.timestamp() as u64,
404 ended_at: body.subscription_period.ended_at.timestamp() as u64,
405 }),
406 account_too_young: Some(body.is_account_too_young),
407 has_overdue_invoices: Some(body.has_overdue_invoices),
408 };
409
410 rpc_server
411 .update_plan_for_user(user_id, update_user_plan)
412 .await?;
413
414 Ok(Json(UpdatePlanResponse {}))
415}
416
417fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
418 proto::UsageLimit {
419 variant: Some(match limit {
420 cloud_llm_client::UsageLimit::Limited(limit) => {
421 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
422 limit: limit as u32,
423 })
424 }
425 cloud_llm_client::UsageLimit::Unlimited => {
426 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
427 }
428 }),
429 }
430}