Detailed changes
@@ -117,12 +117,10 @@ jobs:
export ZED_KUBE_NAMESPACE=production
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
- export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
export ZED_KUBE_NAMESPACE=staging
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
- export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
else
echo "cowardly refusing to deploy from an unknown branch"
exit 1
@@ -147,9 +145,3 @@ jobs:
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
-
- export ZED_SERVICE_NAME=llm
- export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
- envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
- kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
- echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
@@ -2942,7 +2942,6 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
- "anthropic",
"anyhow",
"assistant",
"assistant_context_editor",
@@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]
[dependencies]
-anthropic.workspace = true
anyhow.workspace = true
async-stripe.workspace = true
async-tungstenite.workspace = true
@@ -253,7 +253,6 @@ impl Config {
pub enum ServiceMode {
Api,
Collab,
- Llm,
All,
}
@@ -265,10 +264,6 @@ impl ServiceMode {
pub fn is_api(&self) -> bool {
matches!(self, Self::Api | Self::All)
}
-
- pub fn is_llm(&self) -> bool {
- matches!(self, Self::Llm | Self::All)
- }
}
pub struct AppState {
@@ -1,448 +1,10 @@
-mod authorization;
pub mod db;
mod token;
-use crate::api::CloudflareIpCountryHeader;
-use crate::api::events::SnowflakeRow;
-use crate::build_kinesis_client;
-use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
-use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
-use anyhow::{Context as _, anyhow};
-use authorization::authorize_access_to_language_model;
-use axum::routing::get;
-use axum::{
- Extension, Json, Router, TypedHeader,
- body::Body,
- http::{self, HeaderName, HeaderValue, Request, StatusCode},
- middleware::{self, Next},
- response::{IntoResponse, Response},
- routing::post,
-};
-use chrono::{DateTime, Duration, Utc};
-use collections::HashMap;
-use db::TokenUsage;
-use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
-use futures::{Stream, StreamExt as _};
-use reqwest_client::ReqwestClient;
-use rpc::{
- EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
-};
-use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
-use serde_json::json;
-use std::{
- pin::Pin,
- sync::Arc,
- task::{Context, Poll},
-};
-use strum::IntoEnumIterator;
-use tokio::sync::RwLock;
-use util::ResultExt;
+use crate::Cents;
pub use token::*;
-const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
-
-pub struct LlmState {
- pub config: Config,
- pub executor: Executor,
- pub db: Arc<LlmDatabase>,
- pub http_client: ReqwestClient,
- pub kinesis_client: Option<aws_sdk_kinesis::Client>,
- active_user_count_by_model:
- RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
-}
-
-impl LlmState {
- pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
- let database_url = config
- .llm_database_url
- .as_ref()
- .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
- let max_connections = config
- .llm_database_max_connections
- .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
-
- let mut db_options = db::ConnectOptions::new(database_url);
- db_options.max_connections(max_connections);
- let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
- db.initialize().await?;
-
- let db = Arc::new(db);
-
- let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
- let http_client =
- ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
-
- let this = Self {
- executor,
- db,
- http_client,
- kinesis_client: if config.kinesis_access_key.is_some() {
- build_kinesis_client(&config).await.log_err()
- } else {
- None
- },
- active_user_count_by_model: RwLock::new(HashMap::default()),
- config,
- };
-
- Ok(Arc::new(this))
- }
-
- pub async fn get_active_user_count(
- &self,
- provider: LanguageModelProvider,
- model: &str,
- ) -> Result<ActiveUserCount> {
- let now = Utc::now();
-
- {
- let active_user_count_by_model = self.active_user_count_by_model.read().await;
- if let Some((last_updated, count)) =
- active_user_count_by_model.get(&(provider, model.to_string()))
- {
- if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
- return Ok(*count);
- }
- }
- }
-
- let mut cache = self.active_user_count_by_model.write().await;
- let new_count = self.db.get_active_user_count(provider, model, now).await?;
- cache.insert((provider, model.to_string()), (now, new_count));
- Ok(new_count)
- }
-}
-
-pub fn routes() -> Router<(), Body> {
- Router::new()
- .route("/models", get(list_models))
- .route("/completion", post(perform_completion))
- .layer(middleware::from_fn(validate_api_token))
-}
-
-async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
- let token = req
- .headers()
- .get(http::header::AUTHORIZATION)
- .and_then(|header| header.to_str().ok())
- .ok_or_else(|| {
- Error::http(
- StatusCode::BAD_REQUEST,
- "missing authorization header".to_string(),
- )
- })?
- .strip_prefix("Bearer ")
- .ok_or_else(|| {
- Error::http(
- StatusCode::BAD_REQUEST,
- "invalid authorization header".to_string(),
- )
- })?;
-
- let state = req.extensions().get::<Arc<LlmState>>().unwrap();
- match LlmTokenClaims::validate(token, &state.config) {
- Ok(claims) => {
- if state.db.is_access_token_revoked(&claims.jti).await? {
- return Err(Error::http(
- StatusCode::UNAUTHORIZED,
- "unauthorized".to_string(),
- ));
- }
-
- tracing::Span::current()
- .record("user_id", claims.user_id)
- .record("login", claims.github_user_login.clone())
- .record("authn.jti", &claims.jti)
- .record("is_staff", claims.is_staff);
-
- req.extensions_mut().insert(claims);
- Ok::<_, Error>(next.run(req).await.into_response())
- }
- Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
- StatusCode::UNAUTHORIZED,
- "unauthorized".to_string(),
- [(
- HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
- HeaderValue::from_static("true"),
- )]
- .into_iter()
- .collect(),
- )),
- Err(_err) => Err(Error::http(
- StatusCode::UNAUTHORIZED,
- "unauthorized".to_string(),
- )),
- }
-}
-
-async fn list_models(
- Extension(state): Extension<Arc<LlmState>>,
- Extension(claims): Extension<LlmTokenClaims>,
- country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
-) -> Result<Json<ListModelsResponse>> {
- let country_code = country_code_header.map(|header| header.to_string());
-
- let mut accessible_models = Vec::new();
-
- for (provider, model) in state.db.all_models() {
- let authorize_result = authorize_access_to_language_model(
- &state.config,
- &claims,
- country_code.as_deref(),
- provider,
- &model.name,
- );
-
- if authorize_result.is_ok() {
- accessible_models.push(rpc::LanguageModel {
- provider,
- name: model.name,
- });
- }
- }
-
- Ok(Json(ListModelsResponse {
- models: accessible_models,
- }))
-}
-
-async fn perform_completion(
- Extension(state): Extension<Arc<LlmState>>,
- Extension(claims): Extension<LlmTokenClaims>,
- country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
- Json(params): Json<PerformCompletionParams>,
-) -> Result<impl IntoResponse> {
- let model = normalize_model_name(
- state.db.model_names_for_provider(params.provider),
- params.model,
- );
-
- let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
- if !bypass_account_age_check {
- if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
- Err(anyhow!("account too young"))?
- }
- }
-
- authorize_access_to_language_model(
- &state.config,
- &claims,
- country_code_header
- .map(|header| header.to_string())
- .as_deref(),
- params.provider,
- &model,
- )?;
-
- check_usage_limit(&state, params.provider, &model, &claims).await?;
-
- let stream = match params.provider {
- LanguageModelProvider::Anthropic => {
- let api_key = if claims.is_staff {
- state
- .config
- .anthropic_staff_api_key
- .as_ref()
- .context("no Anthropic AI staff API key configured on the server")?
- } else {
- state
- .config
- .anthropic_api_key
- .as_ref()
- .context("no Anthropic AI API key configured on the server")?
- };
-
- let mut request: anthropic::Request =
- serde_json::from_str(params.provider_request.get())?;
-
- // Override the model on the request with the latest version of the model that is
- // known to the server.
- //
- // Right now, we use the version that's defined in `model.id()`, but we will likely
- // want to change this code once a new version of an Anthropic model is released,
- // so that users can use the new version, without having to update Zed.
- request.model = match model.as_str() {
- "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
- "claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
- "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
- "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
- "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
- _ => request.model,
- };
-
- let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
- &state.http_client,
- anthropic::ANTHROPIC_API_URL,
- api_key,
- request,
- )
- .await
- .map_err(|err| match err {
- anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
- Some(anthropic::ApiErrorCode::RateLimitError) => {
- tracing::info!(
- target: "upstream rate limit exceeded",
- user_id = claims.user_id,
- login = claims.github_user_login,
- authn.jti = claims.jti,
- is_staff = claims.is_staff,
- provider = params.provider.to_string(),
- model = model
- );
-
- Error::http(
- StatusCode::TOO_MANY_REQUESTS,
- "Upstream Anthropic rate limit exceeded.".to_string(),
- )
- }
- Some(anthropic::ApiErrorCode::InvalidRequestError) => {
- Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
- }
- Some(anthropic::ApiErrorCode::OverloadedError) => {
- Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
- }
- Some(_) => {
- Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
- }
- None => Error::Internal(anyhow!(err)),
- },
- anthropic::AnthropicError::Other(err) => Error::Internal(err),
- })?;
-
- if let Some(rate_limit_info) = rate_limit_info {
- tracing::info!(
- target: "upstream rate limit",
- is_staff = claims.is_staff,
- provider = params.provider.to_string(),
- model = model,
- tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
- input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
- output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
- requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
- requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
- tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
- input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
- output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
- );
- }
-
- chunks
- .map(move |event| {
- let chunk = event?;
- let (
- input_tokens,
- output_tokens,
- cache_creation_input_tokens,
- cache_read_input_tokens,
- ) = match &chunk {
- anthropic::Event::MessageStart {
- message: anthropic::Response { usage, .. },
- }
- | anthropic::Event::MessageDelta { usage, .. } => (
- usage.input_tokens.unwrap_or(0) as usize,
- usage.output_tokens.unwrap_or(0) as usize,
- usage.cache_creation_input_tokens.unwrap_or(0) as usize,
- usage.cache_read_input_tokens.unwrap_or(0) as usize,
- ),
- _ => (0, 0, 0, 0),
- };
-
- anyhow::Ok(CompletionChunk {
- bytes: serde_json::to_vec(&chunk).unwrap(),
- input_tokens,
- output_tokens,
- cache_creation_input_tokens,
- cache_read_input_tokens,
- })
- })
- .boxed()
- }
- LanguageModelProvider::OpenAi => {
- let api_key = state
- .config
- .openai_api_key
- .as_ref()
- .context("no OpenAI API key configured on the server")?;
- let chunks = open_ai::stream_completion(
- &state.http_client,
- open_ai::OPEN_AI_API_URL,
- api_key,
- serde_json::from_str(params.provider_request.get())?,
- )
- .await?;
-
- chunks
- .map(|event| {
- event.map(|chunk| {
- let input_tokens =
- chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
- let output_tokens =
- chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
- CompletionChunk {
- bytes: serde_json::to_vec(&chunk).unwrap(),
- input_tokens,
- output_tokens,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- }
- })
- })
- .boxed()
- }
- LanguageModelProvider::Google => {
- let api_key = state
- .config
- .google_ai_api_key
- .as_ref()
- .context("no Google AI API key configured on the server")?;
- let chunks = google_ai::stream_generate_content(
- &state.http_client,
- google_ai::API_URL,
- api_key,
- serde_json::from_str(params.provider_request.get())?,
- )
- .await?;
-
- chunks
- .map(|event| {
- event.map(|chunk| {
- // TODO - implement token counting for Google AI
- CompletionChunk {
- bytes: serde_json::to_vec(&chunk).unwrap(),
- input_tokens: 0,
- output_tokens: 0,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- }
- })
- })
- .boxed()
- }
- };
-
- Ok(Response::new(Body::wrap_stream(TokenCountingStream {
- state,
- claims,
- provider: params.provider,
- model,
- tokens: TokenUsage::default(),
- inner_stream: stream,
- })))
-}
-
-fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
- if let Some(known_model_name) = known_models
- .iter()
- .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
- .max_by_key(|known_model_name| known_model_name.len())
- {
- known_model_name.to_string()
- } else {
- name
- }
-}
-
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
@@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
-
-async fn check_usage_limit(
- state: &Arc<LlmState>,
- provider: LanguageModelProvider,
- model_name: &str,
- claims: &LlmTokenClaims,
-) -> Result<()> {
- if claims.is_staff {
- return Ok(());
- }
-
- let user_id = UserId::from_proto(claims.user_id);
- let model = state.db.model(provider, model_name)?;
- let free_tier = claims.free_tier_monthly_spending_limit();
-
- let spending_this_month = state
- .db
- .get_user_spending_for_month(user_id, Utc::now())
- .await?;
- if spending_this_month >= free_tier {
- if !claims.has_llm_subscription {
- return Err(Error::http(
- StatusCode::PAYMENT_REQUIRED,
- "Maximum spending limit reached for this month.".to_string(),
- ));
- }
-
- let monthly_spend = spending_this_month.saturating_sub(free_tier);
- if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
- return Err(Error::Http(
- StatusCode::FORBIDDEN,
- "Maximum spending limit reached for this month.".to_string(),
- [(
- HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
- HeaderValue::from_static("true"),
- )]
- .into_iter()
- .collect(),
- ));
- }
- }
-
- let active_users = state.get_active_user_count(provider, model_name).await?;
-
- let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
- let users_in_recent_days = active_users.users_in_recent_days.max(1);
-
- let per_user_max_requests_per_minute =
- model.max_requests_per_minute as usize / users_in_recent_minutes;
- let per_user_max_tokens_per_minute =
- model.max_tokens_per_minute as usize / users_in_recent_minutes;
- let per_user_max_input_tokens_per_minute =
- model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
- let per_user_max_output_tokens_per_minute =
- model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
- let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
-
- let usage = state
- .db
- .get_usage(user_id, provider, model_name, Utc::now())
- .await?;
-
- let checks = match (provider, model_name) {
- (LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
- (
- usage.requests_this_minute,
- per_user_max_requests_per_minute,
- UsageMeasure::RequestsPerMinute,
- ),
- (
- usage.input_tokens_this_minute,
- per_user_max_tokens_per_minute,
- UsageMeasure::InputTokensPerMinute,
- ),
- (
- usage.output_tokens_this_minute,
- per_user_max_tokens_per_minute,
- UsageMeasure::OutputTokensPerMinute,
- ),
- (
- usage.tokens_this_day,
- per_user_max_tokens_per_day,
- UsageMeasure::TokensPerDay,
- ),
- ],
- _ => vec![
- (
- usage.requests_this_minute,
- per_user_max_requests_per_minute,
- UsageMeasure::RequestsPerMinute,
- ),
- (
- usage.tokens_this_minute,
- per_user_max_tokens_per_minute,
- UsageMeasure::TokensPerMinute,
- ),
- (
- usage.tokens_this_day,
- per_user_max_tokens_per_day,
- UsageMeasure::TokensPerDay,
- ),
- ],
- };
-
- for (used, limit, usage_measure) in checks {
- if used > limit {
- let resource = match usage_measure {
- UsageMeasure::RequestsPerMinute => "requests_per_minute",
- UsageMeasure::TokensPerMinute => "tokens_per_minute",
- UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
- UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
- UsageMeasure::TokensPerDay => "tokens_per_day",
- };
-
- tracing::info!(
- target: "user rate limit",
- user_id = claims.user_id,
- login = claims.github_user_login,
- authn.jti = claims.jti,
- is_staff = claims.is_staff,
- provider = provider.to_string(),
- model = model.name,
- usage_measure = resource,
- requests_this_minute = usage.requests_this_minute,
- tokens_this_minute = usage.tokens_this_minute,
- input_tokens_this_minute = usage.input_tokens_this_minute,
- output_tokens_this_minute = usage.output_tokens_this_minute,
- tokens_this_day = usage.tokens_this_day,
- users_in_recent_minutes = users_in_recent_minutes,
- users_in_recent_days = users_in_recent_days,
- max_requests_per_minute = per_user_max_requests_per_minute,
- max_tokens_per_minute = per_user_max_tokens_per_minute,
- max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
- max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
- max_tokens_per_day = per_user_max_tokens_per_day,
- );
-
- SnowflakeRow::new(
- "Language Model Rate Limited",
- Some(claims.metrics_id),
- claims.is_staff,
- claims.system_id.clone(),
- json!({
- "usage": usage,
- "users_in_recent_minutes": users_in_recent_minutes,
- "users_in_recent_days": users_in_recent_days,
- "max_requests_per_minute": per_user_max_requests_per_minute,
- "max_tokens_per_minute": per_user_max_tokens_per_minute,
- "max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
- "max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
- "max_tokens_per_day": per_user_max_tokens_per_day,
- "plan": match claims.plan {
- Plan::Free => "free".to_string(),
- Plan::ZedPro => "zed_pro".to_string(),
- },
- "model": model.name.clone(),
- "provider": provider.to_string(),
- "usage_measure": resource.to_string(),
- }),
- )
- .write(&state.kinesis_client, &state.config.kinesis_stream)
- .await
- .log_err();
-
- return Err(Error::http(
- StatusCode::TOO_MANY_REQUESTS,
- format!("Rate limit exceeded. Maximum {} reached.", resource),
- ));
- }
- }
-
- Ok(())
-}
-
-struct CompletionChunk {
- bytes: Vec<u8>,
- input_tokens: usize,
- output_tokens: usize,
- cache_creation_input_tokens: usize,
- cache_read_input_tokens: usize,
-}
-
-struct TokenCountingStream<S> {
- state: Arc<LlmState>,
- claims: LlmTokenClaims,
- provider: LanguageModelProvider,
- model: String,
- tokens: TokenUsage,
- inner_stream: S,
-}
-
-impl<S> Stream for TokenCountingStream<S>
-where
- S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
-{
- type Item = Result<Vec<u8>, anyhow::Error>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match Pin::new(&mut self.inner_stream).poll_next(cx) {
- Poll::Ready(Some(Ok(mut chunk))) => {
- chunk.bytes.push(b'\n');
- self.tokens.input += chunk.input_tokens;
- self.tokens.output += chunk.output_tokens;
- self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
- self.tokens.input_cache_read += chunk.cache_read_input_tokens;
- Poll::Ready(Some(Ok(chunk.bytes)))
- }
- Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
- Poll::Ready(None) => Poll::Ready(None),
- Poll::Pending => Poll::Pending,
- }
- }
-}
-
-impl<S> Drop for TokenCountingStream<S> {
- fn drop(&mut self) {
- let state = self.state.clone();
- let claims = self.claims.clone();
- let provider = self.provider;
- let model = std::mem::take(&mut self.model);
- let tokens = self.tokens;
- self.state.executor.spawn_detached(async move {
- let usage = state
- .db
- .record_usage(
- UserId::from_proto(claims.user_id),
- claims.is_staff,
- provider,
- &model,
- tokens,
- claims.has_llm_subscription,
- Cents(claims.max_monthly_spend_in_cents),
- claims.free_tier_monthly_spending_limit(),
- Utc::now(),
- )
- .await
- .log_err();
-
- if let Some(usage) = usage {
- tracing::info!(
- target: "user usage",
- user_id = claims.user_id,
- login = claims.github_user_login,
- authn.jti = claims.jti,
- is_staff = claims.is_staff,
- provider = provider.to_string(),
- model = model,
- requests_this_minute = usage.requests_this_minute,
- tokens_this_minute = usage.tokens_this_minute,
- input_tokens_this_minute = usage.input_tokens_this_minute,
- output_tokens_this_minute = usage.output_tokens_this_minute,
- );
-
- let properties = json!({
- "has_llm_subscription": claims.has_llm_subscription,
- "max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
- "plan": match claims.plan {
- Plan::Free => "free".to_string(),
- Plan::ZedPro => "zed_pro".to_string(),
- },
- "model": model,
- "provider": provider,
- "usage": usage,
- "tokens": tokens
- });
- SnowflakeRow::new(
- "Language Model Used",
- Some(claims.metrics_id),
- claims.is_staff,
- claims.system_id.clone(),
- properties,
- )
- .write(&state.kinesis_client, &state.config.kinesis_stream)
- .await
- .log_err();
- }
- })
- }
-}
-
-pub fn log_usage_periodically(state: Arc<LlmState>) {
- state.executor.clone().spawn_detached(async move {
- loop {
- state
- .executor
- .sleep(std::time::Duration::from_secs(30))
- .await;
-
- for provider in LanguageModelProvider::iter() {
- for model in state.db.model_names_for_provider(provider) {
- if let Some(active_user_count) = state
- .get_active_user_count(provider, &model)
- .await
- .log_err()
- {
- tracing::info!(
- target: "active user counts",
- provider = provider.to_string(),
- model = model,
- users_in_recent_minutes = active_user_count.users_in_recent_minutes,
- users_in_recent_days = active_user_count.users_in_recent_days,
- );
- }
- }
- }
-
- if let Some(usages) = state
- .db
- .get_application_wide_usages_by_model(Utc::now())
- .await
- .log_err()
- {
- for usage in usages {
- tracing::info!(
- target: "computed usage",
- provider = usage.provider.to_string(),
- model = usage.model,
- requests_this_minute = usage.requests_this_minute,
- tokens_this_minute = usage.tokens_this_minute,
- input_tokens_this_minute = usage.input_tokens_this_minute,
- output_tokens_this_minute = usage.output_tokens_this_minute,
- );
- }
- }
- }
- })
-}
@@ -1,330 +0,0 @@
-use reqwest::StatusCode;
-use rpc::LanguageModelProvider;
-
-use crate::llm::LlmTokenClaims;
-use crate::{Config, Error, Result};
-
-pub fn authorize_access_to_language_model(
- config: &Config,
- claims: &LlmTokenClaims,
- country_code: Option<&str>,
- provider: LanguageModelProvider,
- model: &str,
-) -> Result<()> {
- authorize_access_for_country(config, country_code, provider)?;
- authorize_access_to_model(config, claims, provider, model)?;
- Ok(())
-}
-
-fn authorize_access_to_model(
- config: &Config,
- claims: &LlmTokenClaims,
- provider: LanguageModelProvider,
- model: &str,
-) -> Result<()> {
- if claims.is_staff {
- return Ok(());
- }
-
- if provider == LanguageModelProvider::Anthropic {
- if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
- return Ok(());
- }
-
- if claims.has_llm_closed_beta_feature_flag
- && Some(model) == config.llm_closed_beta_model_name.as_deref()
- {
- return Ok(());
- }
- }
-
- Err(Error::http(
- StatusCode::FORBIDDEN,
- format!("access to model {model:?} is not included in your plan"),
- ))
-}
-
-fn authorize_access_for_country(
- config: &Config,
- country_code: Option<&str>,
- provider: LanguageModelProvider,
-) -> Result<()> {
- // In development we won't have the `CF-IPCountry` header, so we can't check
- // the country code.
- //
- // This shouldn't be necessary, as anyone running in development will need to provide
- // their own API credentials in order to use an LLM provider.
- if config.is_development() {
- return Ok(());
- }
-
- // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
- let country_code = match country_code {
- // `XX` - Used for clients without country code data.
- None | Some("XX") => Err(Error::http(
- StatusCode::BAD_REQUEST,
- "no country code".to_string(),
- ))?,
- // `T1` - Used for clients using the Tor network.
- Some("T1") => Err(Error::http(
- StatusCode::FORBIDDEN,
- format!("access to {provider:?} models is not available over Tor"),
- ))?,
- Some(country_code) => country_code,
- };
-
- let is_country_supported_by_provider = match provider {
- LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
- LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
- LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
- };
- if !is_country_supported_by_provider {
- Err(Error::http(
- StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
- format!(
- "access to {provider:?} models is not available in your region ({country_code})"
- ),
- ))?
- }
-
- Ok(())
-}
-
-#[cfg(test)]
-mod tests {
- use axum::response::IntoResponse;
- use pretty_assertions::assert_eq;
- use rpc::proto::Plan;
-
- use super::*;
-
- #[gpui::test]
- async fn test_authorize_access_to_language_model_with_supported_country(
- _cx: &mut gpui::TestAppContext,
- ) {
- let config = Config::test();
-
- let claims = LlmTokenClaims {
- user_id: 99,
- plan: Plan::ZedPro,
- is_staff: true,
- ..Default::default()
- };
-
- let cases = vec![
- (LanguageModelProvider::Anthropic, "US"), // United States
- (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
- (LanguageModelProvider::OpenAi, "US"), // United States
- (LanguageModelProvider::OpenAi, "GB"), // United Kingdom
- (LanguageModelProvider::Google, "US"), // United States
- (LanguageModelProvider::Google, "GB"), // United Kingdom
- ];
-
- for (provider, country_code) in cases {
- authorize_access_to_language_model(
- &config,
- &claims,
- Some(country_code),
- provider,
- "the-model",
- )
- .unwrap_or_else(|_| {
- panic!("expected authorization to return Ok for {provider:?}: {country_code}")
- })
- }
- }
-
- #[gpui::test]
- async fn test_authorize_access_to_language_model_with_unsupported_country(
- _cx: &mut gpui::TestAppContext,
- ) {
- let config = Config::test();
-
- let claims = LlmTokenClaims {
- user_id: 99,
- plan: Plan::ZedPro,
- ..Default::default()
- };
-
- let cases = vec![
- (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
- (LanguageModelProvider::Anthropic, "BY"), // Belarus
- (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
- (LanguageModelProvider::Anthropic, "CN"), // China
- (LanguageModelProvider::Anthropic, "CU"), // Cuba
- (LanguageModelProvider::Anthropic, "ER"), // Eritrea
- (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
- (LanguageModelProvider::Anthropic, "IR"), // Iran
- (LanguageModelProvider::Anthropic, "KP"), // North Korea
- (LanguageModelProvider::Anthropic, "XK"), // Kosovo
- (LanguageModelProvider::Anthropic, "LY"), // Libya
- (LanguageModelProvider::Anthropic, "MM"), // Myanmar
- (LanguageModelProvider::Anthropic, "RU"), // Russia
- (LanguageModelProvider::Anthropic, "SO"), // Somalia
- (LanguageModelProvider::Anthropic, "SS"), // South Sudan
- (LanguageModelProvider::Anthropic, "SD"), // Sudan
- (LanguageModelProvider::Anthropic, "SY"), // Syria
- (LanguageModelProvider::Anthropic, "VE"), // Venezuela
- (LanguageModelProvider::Anthropic, "YE"), // Yemen
- (LanguageModelProvider::OpenAi, "KP"), // North Korea
- (LanguageModelProvider::Google, "KP"), // North Korea
- ];
-
- for (provider, country_code) in cases {
- let error_response = authorize_access_to_language_model(
- &config,
- &claims,
- Some(country_code),
- provider,
- "the-model",
- )
- .expect_err(&format!(
- "expected authorization to return an error for {provider:?}: {country_code}"
- ))
- .into_response();
-
- assert_eq!(
- error_response.status(),
- StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
- );
- let response_body = hyper::body::to_bytes(error_response.into_body())
- .await
- .unwrap()
- .to_vec();
- assert_eq!(
- String::from_utf8(response_body).unwrap(),
- format!(
- "access to {provider:?} models is not available in your region ({country_code})"
- )
- );
- }
- }
-
- #[gpui::test]
- async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
- let config = Config::test();
-
- let claims = LlmTokenClaims {
- user_id: 99,
- plan: Plan::ZedPro,
- ..Default::default()
- };
-
- let cases = vec![
- (LanguageModelProvider::Anthropic, "T1"), // Tor
- (LanguageModelProvider::OpenAi, "T1"), // Tor
- (LanguageModelProvider::Google, "T1"), // Tor
- ];
-
- for (provider, country_code) in cases {
- let error_response = authorize_access_to_language_model(
- &config,
- &claims,
- Some(country_code),
- provider,
- "the-model",
- )
- .expect_err(&format!(
- "expected authorization to return an error for {provider:?}: {country_code}"
- ))
- .into_response();
-
- assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
- let response_body = hyper::body::to_bytes(error_response.into_body())
- .await
- .unwrap()
- .to_vec();
- assert_eq!(
- String::from_utf8(response_body).unwrap(),
- format!("access to {provider:?} models is not available over Tor")
- );
- }
- }
-
- #[gpui::test]
- async fn test_authorize_access_to_language_model_based_on_plan() {
- let config = Config::test();
-
- let test_cases = vec![
- // Pro plan should have access to claude-3.5-sonnet
- (
- Plan::ZedPro,
- LanguageModelProvider::Anthropic,
- "claude-3-5-sonnet",
- true,
- ),
- // Free plan should have access to claude-3.5-sonnet
- (
- Plan::Free,
- LanguageModelProvider::Anthropic,
- "claude-3-5-sonnet",
- true,
- ),
- // Pro plan should NOT have access to other Anthropic models
- (
- Plan::ZedPro,
- LanguageModelProvider::Anthropic,
- "claude-3-opus",
- false,
- ),
- ];
-
- for (plan, provider, model, expected_access) in test_cases {
- let claims = LlmTokenClaims {
- plan,
- ..Default::default()
- };
-
- let result =
- authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
-
- if expected_access {
- assert!(
- result.is_ok(),
- "Expected access to be granted for plan {:?}, provider {:?}, model {}",
- plan,
- provider,
- model
- );
- } else {
- let error = result.expect_err(&format!(
- "Expected access to be denied for plan {:?}, provider {:?}, model {}",
- plan, provider, model
- ));
- let response = error.into_response();
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
- }
- }
- }
-
- #[gpui::test]
- async fn test_authorize_access_to_language_model_for_staff() {
- let config = Config::test();
-
- let claims = LlmTokenClaims {
- is_staff: true,
- ..Default::default()
- };
-
- // Staff should have access to all models
- let test_cases = vec![
- (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
- (LanguageModelProvider::Anthropic, "claude-2"),
- (LanguageModelProvider::Anthropic, "claude-123-agi"),
- (LanguageModelProvider::OpenAi, "gpt-4"),
- (LanguageModelProvider::Google, "gemini-pro"),
- ];
-
- for (provider, model) in test_cases {
- let result =
- authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
-
- assert!(
- result.is_ok(),
- "Expected staff to have access to provider {:?}, model {}",
- provider,
- model
- );
- }
- }
-}
@@ -20,7 +20,6 @@ use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
-pub use queries::usages::{ActiveUserCount, TokenUsage};
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{
@@ -2,5 +2,4 @@ use super::*;
pub mod billing_events;
pub mod providers;
-pub mod revoked_access_tokens;
pub mod usages;
@@ -1,15 +0,0 @@
-use super::*;
-
-impl LlmDatabase {
- /// Returns whether the access token with the given `jti` has been revoked.
- pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
- self.transaction(|tx| async move {
- Ok(revoked_access_token::Entity::find()
- .filter(revoked_access_token::Column::Jti.eq(jti))
- .one(&*tx)
- .await?
- .is_some())
- })
- .await
- }
-}
@@ -1,56 +1,12 @@
use crate::db::UserId;
use crate::llm::Cents;
-use chrono::{Datelike, Duration};
+use chrono::Datelike;
use futures::StreamExt as _;
-use rpc::LanguageModelProvider;
-use sea_orm::QuerySelect;
-use std::{iter, str::FromStr};
+use std::str::FromStr;
use strum::IntoEnumIterator as _;
use super::*;
-#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
-pub struct TokenUsage {
- pub input: usize,
- pub input_cache_creation: usize,
- pub input_cache_read: usize,
- pub output: usize,
-}
-
-impl TokenUsage {
- pub fn total(&self) -> usize {
- self.input + self.input_cache_creation + self.input_cache_read + self.output
- }
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
-pub struct Usage {
- pub requests_this_minute: usize,
- pub tokens_this_minute: usize,
- pub input_tokens_this_minute: usize,
- pub output_tokens_this_minute: usize,
- pub tokens_this_day: usize,
- pub tokens_this_month: TokenUsage,
- pub spending_this_month: Cents,
- pub lifetime_spending: Cents,
-}
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct ApplicationWideUsage {
- pub provider: LanguageModelProvider,
- pub model: String,
- pub requests_this_minute: usize,
- pub tokens_this_minute: usize,
- pub input_tokens_this_minute: usize,
- pub output_tokens_this_minute: usize,
-}
-
-#[derive(Clone, Copy, Debug, Default)]
-pub struct ActiveUserCount {
- pub users_in_recent_minutes: usize,
- pub users_in_recent_days: usize,
-}
-
impl LlmDatabase {
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
@@ -90,100 +46,6 @@ impl LlmDatabase {
Ok(())
}
- pub async fn get_application_wide_usages_by_model(
- &self,
- now: DateTimeUtc,
- ) -> Result<Vec<ApplicationWideUsage>> {
- self.transaction(|tx| async move {
- let past_minute = now - Duration::minutes(1);
- let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
- let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
- let input_tokens_per_minute =
- self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
- let output_tokens_per_minute =
- self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
-
- let mut results = Vec::new();
- for ((provider, model_name), model) in self.models.iter() {
- let mut usages = usage::Entity::find()
- .filter(
- usage::Column::Timestamp
- .gte(past_minute.naive_utc())
- .and(usage::Column::IsStaff.eq(false))
- .and(usage::Column::ModelId.eq(model.id))
- .and(
- usage::Column::MeasureId
- .eq(requests_per_minute)
- .or(usage::Column::MeasureId.eq(tokens_per_minute)),
- ),
- )
- .stream(&*tx)
- .await?;
-
- let mut requests_this_minute = 0;
- let mut tokens_this_minute = 0;
- let mut input_tokens_this_minute = 0;
- let mut output_tokens_this_minute = 0;
- while let Some(usage) = usages.next().await {
- let usage = usage?;
- if usage.measure_id == requests_per_minute {
- requests_this_minute += Self::get_live_buckets(
- &usage,
- now.naive_utc(),
- UsageMeasure::RequestsPerMinute,
- )
- .0
- .iter()
- .copied()
- .sum::<i64>() as usize;
- } else if usage.measure_id == tokens_per_minute {
- tokens_this_minute += Self::get_live_buckets(
- &usage,
- now.naive_utc(),
- UsageMeasure::TokensPerMinute,
- )
- .0
- .iter()
- .copied()
- .sum::<i64>() as usize;
- } else if usage.measure_id == input_tokens_per_minute {
- input_tokens_this_minute += Self::get_live_buckets(
- &usage,
- now.naive_utc(),
- UsageMeasure::InputTokensPerMinute,
- )
- .0
- .iter()
- .copied()
- .sum::<i64>() as usize;
- } else if usage.measure_id == output_tokens_per_minute {
- output_tokens_this_minute += Self::get_live_buckets(
- &usage,
- now.naive_utc(),
- UsageMeasure::OutputTokensPerMinute,
- )
- .0
- .iter()
- .copied()
- .sum::<i64>() as usize;
- }
- }
-
- results.push(ApplicationWideUsage {
- provider: *provider,
- model: model_name.clone(),
- requests_this_minute,
- tokens_this_minute,
- input_tokens_this_minute,
- output_tokens_this_minute,
- })
- }
-
- Ok(results)
- })
- .await
- }
-
pub async fn get_user_spending_for_month(
&self,
user_id: UserId,
@@ -223,499 +85,6 @@ impl LlmDatabase {
})
.await
}
-
- pub async fn get_usage(
- &self,
- user_id: UserId,
- provider: LanguageModelProvider,
- model_name: &str,
- now: DateTimeUtc,
- ) -> Result<Usage> {
- self.transaction(|tx| async move {
- let model = self
- .models
- .get(&(provider, model_name.to_string()))
- .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
-
- let usages = usage::Entity::find()
- .filter(
- usage::Column::UserId
- .eq(user_id)
- .and(usage::Column::ModelId.eq(model.id)),
- )
- .all(&*tx)
- .await?;
-
- let month = now.date_naive().month() as i32;
- let year = now.date_naive().year();
- let monthly_usage = monthly_usage::Entity::find()
- .filter(
- monthly_usage::Column::UserId
- .eq(user_id)
- .and(monthly_usage::Column::ModelId.eq(model.id))
- .and(monthly_usage::Column::Month.eq(month))
- .and(monthly_usage::Column::Year.eq(year)),
- )
- .one(&*tx)
- .await?;
- let lifetime_usage = lifetime_usage::Entity::find()
- .filter(
- lifetime_usage::Column::UserId
- .eq(user_id)
- .and(lifetime_usage::Column::ModelId.eq(model.id)),
- )
- .one(&*tx)
- .await?;
-
- let requests_this_minute =
- self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
- let tokens_this_minute =
- self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
- let input_tokens_this_minute =
- self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
- let output_tokens_this_minute =
- self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
- let tokens_this_day =
- self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
- let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
- calculate_spending(
- model,
- monthly_usage.input_tokens as usize,
- monthly_usage.cache_creation_input_tokens as usize,
- monthly_usage.cache_read_input_tokens as usize,
- monthly_usage.output_tokens as usize,
- )
- } else {
- Cents::ZERO
- };
- let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
- calculate_spending(
- model,
- lifetime_usage.input_tokens as usize,
- lifetime_usage.cache_creation_input_tokens as usize,
- lifetime_usage.cache_read_input_tokens as usize,
- lifetime_usage.output_tokens as usize,
- )
- } else {
- Cents::ZERO
- };
-
- Ok(Usage {
- requests_this_minute,
- tokens_this_minute,
- input_tokens_this_minute,
- output_tokens_this_minute,
- tokens_this_day,
- tokens_this_month: TokenUsage {
- input: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.input_tokens as usize),
- input_cache_creation: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
- input_cache_read: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.cache_read_input_tokens as usize),
- output: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.output_tokens as usize),
- },
- spending_this_month,
- lifetime_spending,
- })
- })
- .await
- }
-
- pub async fn record_usage(
- &self,
- user_id: UserId,
- is_staff: bool,
- provider: LanguageModelProvider,
- model_name: &str,
- tokens: TokenUsage,
- has_llm_subscription: bool,
- max_monthly_spend: Cents,
- free_tier_monthly_spending_limit: Cents,
- now: DateTimeUtc,
- ) -> Result<Usage> {
- self.transaction(|tx| async move {
- let model = self.model(provider, model_name)?;
-
- let usages = usage::Entity::find()
- .filter(
- usage::Column::UserId
- .eq(user_id)
- .and(usage::Column::ModelId.eq(model.id)),
- )
- .all(&*tx)
- .await?;
-
- let requests_this_minute = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::RequestsPerMinute,
- now,
- 1,
- &tx,
- )
- .await?;
- let tokens_this_minute = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::TokensPerMinute,
- now,
- tokens.total(),
- &tx,
- )
- .await?;
- let input_tokens_this_minute = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::InputTokensPerMinute,
- now,
- // Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
- tokens.input + tokens.input_cache_creation,
- &tx,
- )
- .await?;
- let output_tokens_this_minute = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::OutputTokensPerMinute,
- now,
- tokens.output,
- &tx,
- )
- .await?;
- let tokens_this_day = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::TokensPerDay,
- now,
- tokens.total(),
- &tx,
- )
- .await?;
-
- let month = now.date_naive().month() as i32;
- let year = now.date_naive().year();
-
- // Update monthly usage
- let monthly_usage = monthly_usage::Entity::find()
- .filter(
- monthly_usage::Column::UserId
- .eq(user_id)
- .and(monthly_usage::Column::ModelId.eq(model.id))
- .and(monthly_usage::Column::Month.eq(month))
- .and(monthly_usage::Column::Year.eq(year)),
- )
- .one(&*tx)
- .await?;
-
- let monthly_usage = match monthly_usage {
- Some(usage) => {
- monthly_usage::Entity::update(monthly_usage::ActiveModel {
- id: ActiveValue::unchanged(usage.id),
- input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
- cache_creation_input_tokens: ActiveValue::set(
- usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
- ),
- cache_read_input_tokens: ActiveValue::set(
- usage.cache_read_input_tokens + tokens.input_cache_read as i64,
- ),
- output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
- ..Default::default()
- })
- .exec(&*tx)
- .await?
- }
- None => {
- monthly_usage::ActiveModel {
- user_id: ActiveValue::set(user_id),
- model_id: ActiveValue::set(model.id),
- month: ActiveValue::set(month),
- year: ActiveValue::set(year),
- input_tokens: ActiveValue::set(tokens.input as i64),
- cache_creation_input_tokens: ActiveValue::set(
- tokens.input_cache_creation as i64,
- ),
- cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
- output_tokens: ActiveValue::set(tokens.output as i64),
- ..Default::default()
- }
- .insert(&*tx)
- .await?
- }
- };
-
- let spending_this_month = calculate_spending(
- model,
- monthly_usage.input_tokens as usize,
- monthly_usage.cache_creation_input_tokens as usize,
- monthly_usage.cache_read_input_tokens as usize,
- monthly_usage.output_tokens as usize,
- );
-
- if !is_staff
- && spending_this_month > free_tier_monthly_spending_limit
- && has_llm_subscription
- && (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
- {
- billing_event::ActiveModel {
- id: ActiveValue::not_set(),
- idempotency_key: ActiveValue::not_set(),
- user_id: ActiveValue::set(user_id),
- model_id: ActiveValue::set(model.id),
- input_tokens: ActiveValue::set(tokens.input as i64),
- input_cache_creation_tokens: ActiveValue::set(
- tokens.input_cache_creation as i64,
- ),
- input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
- output_tokens: ActiveValue::set(tokens.output as i64),
- }
- .insert(&*tx)
- .await?;
- }
-
- // Update lifetime usage
- let lifetime_usage = lifetime_usage::Entity::find()
- .filter(
- lifetime_usage::Column::UserId
- .eq(user_id)
- .and(lifetime_usage::Column::ModelId.eq(model.id)),
- )
- .one(&*tx)
- .await?;
-
- let lifetime_usage = match lifetime_usage {
- Some(usage) => {
- lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
- id: ActiveValue::unchanged(usage.id),
- input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
- cache_creation_input_tokens: ActiveValue::set(
- usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
- ),
- cache_read_input_tokens: ActiveValue::set(
- usage.cache_read_input_tokens + tokens.input_cache_read as i64,
- ),
- output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
- ..Default::default()
- })
- .exec(&*tx)
- .await?
- }
- None => {
- lifetime_usage::ActiveModel {
- user_id: ActiveValue::set(user_id),
- model_id: ActiveValue::set(model.id),
- input_tokens: ActiveValue::set(tokens.input as i64),
- cache_creation_input_tokens: ActiveValue::set(
- tokens.input_cache_creation as i64,
- ),
- cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
- output_tokens: ActiveValue::set(tokens.output as i64),
- ..Default::default()
- }
- .insert(&*tx)
- .await?
- }
- };
-
- let lifetime_spending = calculate_spending(
- model,
- lifetime_usage.input_tokens as usize,
- lifetime_usage.cache_creation_input_tokens as usize,
- lifetime_usage.cache_read_input_tokens as usize,
- lifetime_usage.output_tokens as usize,
- );
-
- Ok(Usage {
- requests_this_minute,
- tokens_this_minute,
- input_tokens_this_minute,
- output_tokens_this_minute,
- tokens_this_day,
- tokens_this_month: TokenUsage {
- input: monthly_usage.input_tokens as usize,
- input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
- input_cache_read: monthly_usage.cache_read_input_tokens as usize,
- output: monthly_usage.output_tokens as usize,
- },
- spending_this_month,
- lifetime_spending,
- })
- })
- .await
- }
-
- /// Returns the active user count for the specified model.
- pub async fn get_active_user_count(
- &self,
- provider: LanguageModelProvider,
- model_name: &str,
- now: DateTimeUtc,
- ) -> Result<ActiveUserCount> {
- self.transaction(|tx| async move {
- let minute_since = now - Duration::minutes(5);
- let day_since = now - Duration::days(5);
-
- let model = self
- .models
- .get(&(provider, model_name.to_string()))
- .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
-
- let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
-
- let users_in_recent_minutes = usage::Entity::find()
- .filter(
- usage::Column::ModelId
- .eq(model.id)
- .and(usage::Column::MeasureId.eq(tokens_per_minute))
- .and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
- .and(usage::Column::IsStaff.eq(false)),
- )
- .select_only()
- .column(usage::Column::UserId)
- .group_by(usage::Column::UserId)
- .count(&*tx)
- .await? as usize;
-
- let users_in_recent_days = usage::Entity::find()
- .filter(
- usage::Column::ModelId
- .eq(model.id)
- .and(usage::Column::MeasureId.eq(tokens_per_minute))
- .and(usage::Column::Timestamp.gte(day_since.naive_utc()))
- .and(usage::Column::IsStaff.eq(false)),
- )
- .select_only()
- .column(usage::Column::UserId)
- .group_by(usage::Column::UserId)
- .count(&*tx)
- .await? as usize;
-
- Ok(ActiveUserCount {
- users_in_recent_minutes,
- users_in_recent_days,
- })
- })
- .await
- }
-
- async fn update_usage_for_measure(
- &self,
- user_id: UserId,
- is_staff: bool,
- model_id: ModelId,
- usages: &[usage::Model],
- usage_measure: UsageMeasure,
- now: DateTimeUtc,
- usage_to_add: usize,
- tx: &DatabaseTransaction,
- ) -> Result<usize> {
- let now = now.naive_utc();
- let measure_id = *self
- .usage_measure_ids
- .get(&usage_measure)
- .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
-
- let mut id = None;
- let mut timestamp = now;
- let mut buckets = vec![0_i64];
-
- if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
- id = Some(old_usage.id);
- let (live_buckets, buckets_since) =
- Self::get_live_buckets(old_usage, now, usage_measure);
- if !live_buckets.is_empty() {
- buckets.clear();
- buckets.extend_from_slice(live_buckets);
- buckets.extend(iter::repeat(0).take(buckets_since));
- timestamp =
- old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
- }
- }
-
- *buckets.last_mut().unwrap() += usage_to_add as i64;
- let total_usage = buckets.iter().sum::<i64>() as usize;
-
- let mut model = usage::ActiveModel {
- user_id: ActiveValue::set(user_id),
- is_staff: ActiveValue::set(is_staff),
- model_id: ActiveValue::set(model_id),
- measure_id: ActiveValue::set(measure_id),
- timestamp: ActiveValue::set(timestamp),
- buckets: ActiveValue::set(buckets),
- ..Default::default()
- };
-
- if let Some(id) = id {
- model.id = ActiveValue::unchanged(id);
- model.update(tx).await?;
- } else {
- usage::Entity::insert(model)
- .exec_without_returning(tx)
- .await?;
- }
-
- Ok(total_usage)
- }
-
- fn get_usage_for_measure(
- &self,
- usages: &[usage::Model],
- now: DateTimeUtc,
- usage_measure: UsageMeasure,
- ) -> Result<usize> {
- let now = now.naive_utc();
- let measure_id = *self
- .usage_measure_ids
- .get(&usage_measure)
- .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
- let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
- return Ok(0);
- };
-
- let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
- Ok(live_buckets.iter().sum::<i64>() as _)
- }
-
- fn get_live_buckets(
- usage: &usage::Model,
- now: chrono::NaiveDateTime,
- measure: UsageMeasure,
- ) -> (&[i64], usize) {
- let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
- let buckets_since_usage =
- seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
- let buckets_since_usage = buckets_since_usage.ceil() as usize;
- let mut live_buckets = &[] as &[i64];
- if buckets_since_usage < measure.bucket_count() {
- let expired_bucket_count =
- (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
- live_buckets = &usage.buckets[expired_bucket_count..];
- while live_buckets.first() == Some(&0) {
- live_buckets = &live_buckets[1..];
- }
- }
- (live_buckets, buckets_since_usage)
- }
}
fn calculate_spending(
@@ -741,32 +110,3 @@ fn calculate_spending(
+ output_token_cost;
Cents::new(spending as u32)
}
-
-const MINUTE_BUCKET_COUNT: usize = 12;
-const DAY_BUCKET_COUNT: usize = 48;
-
-impl UsageMeasure {
- fn bucket_count(&self) -> usize {
- match self {
- UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
- UsageMeasure::TokensPerMinute
- | UsageMeasure::InputTokensPerMinute
- | UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
- UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
- }
- }
-
- fn total_duration(&self) -> Duration {
- match self {
- UsageMeasure::RequestsPerMinute => Duration::minutes(1),
- UsageMeasure::TokensPerMinute
- | UsageMeasure::InputTokensPerMinute
- | UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
- UsageMeasure::TokensPerDay => Duration::hours(24),
- }
- }
-
- fn bucket_duration(&self) -> Duration {
- self.total_duration() / self.bucket_count() as i32
- }
-}
@@ -1,8 +1,6 @@
pub mod billing_event;
-pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;
pub mod provider;
-pub mod revoked_access_token;
pub mod usage;
pub mod usage_measure;
@@ -1,20 +0,0 @@
-use crate::{db::UserId, llm::db::ModelId};
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "lifetime_usages")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: i32,
- pub user_id: UserId,
- pub model_id: ModelId,
- pub input_tokens: i64,
- pub cache_creation_input_tokens: i64,
- pub cache_read_input_tokens: i64,
- pub output_tokens: i64,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,19 +0,0 @@
-use chrono::NaiveDateTime;
-use sea_orm::entity::prelude::*;
-
-use crate::llm::db::RevokedAccessTokenId;
-
-/// A revoked access token.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "revoked_access_tokens")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: RevokedAccessTokenId,
- pub jti: String,
- pub revoked_at: NaiveDateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,6 +1,4 @@
-mod billing_tests;
mod provider_tests;
-mod usage_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
@@ -1,152 +0,0 @@
-use crate::{
- Cents,
- db::UserId,
- llm::{
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
- },
- test_llm_db,
-};
-use chrono::{DateTime, Utc};
-use pretty_assertions::assert_eq;
-use rpc::LanguageModelProvider;
-
-test_llm_db!(
- test_billing_limit_exceeded,
- test_billing_limit_exceeded_postgres
-);
-
-async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
- let provider = LanguageModelProvider::Anthropic;
- let model = "fake-claude-limerick";
- const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
- const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
-
- // Initialize the database and insert the model
- db.initialize().await.unwrap();
- db.insert_models(&[ModelParams {
- provider,
- name: model.to_string(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 10_000,
- max_tokens_per_day: 50_000,
- price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
- price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
- }])
- .await
- .unwrap();
-
- // Set a fixed datetime for consistent testing
- let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
- .unwrap()
- .with_timezone(&Utc);
-
- let user_id = UserId::from_proto(123);
-
- let max_monthly_spend = Cents::from_dollars(11);
-
- // Record usage that brings us close to the limit but doesn't exceed it
- // Let's say we use $10.50 worth of tokens
- let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
- let usage = TokenUsage {
- input: tokens_to_use,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- };
-
- // Verify that before we record any usage, there are 0 billing events
- let billing_events = db.get_billing_events().await.unwrap();
- assert_eq!(billing_events.len(), 0);
-
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- usage,
- true,
- max_monthly_spend,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- // Verify the recorded usage and spending
- let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- // Verify that we exceeded the free tier usage
- assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
- assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
-
- // Verify that there is one `billing_event` record
- let billing_events = db.get_billing_events().await.unwrap();
- assert_eq!(billing_events.len(), 1);
-
- let (billing_event, _model) = &billing_events[0];
- assert_eq!(billing_event.user_id, user_id);
- assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
- assert_eq!(billing_event.input_cache_creation_tokens, 0);
- assert_eq!(billing_event.input_cache_read_tokens, 0);
- assert_eq!(billing_event.output_tokens, 0);
-
- // Record usage that puts us at $20.50
- let usage_2 = TokenUsage {
- input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- };
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- usage_2,
- true,
- max_monthly_spend,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- // Verify the updated usage and spending
- let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
-
- // Verify that there are now two billing events
- let billing_events = db.get_billing_events().await.unwrap();
- assert_eq!(billing_events.len(), 2);
-
- let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
- let usage_exceeding = TokenUsage {
- input: tokens_to_exceed,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- };
-
- // This should still create a billing event as it's the first request that exceeds the limit
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- usage_exceeding,
- true,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- max_monthly_spend,
- now,
- )
- .await
- .unwrap();
- // Verify the updated usage and spending
- let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
-
- // Verify that we never exceed the user max spending for the user
- // and avoid charging them.
- let billing_events = db.get_billing_events().await.unwrap();
- assert_eq!(billing_events.len(), 2);
-}
@@ -1,306 +0,0 @@
-use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
-use crate::{
- Cents,
- db::UserId,
- llm::db::{
- LlmDatabase, TokenUsage,
- queries::{providers::ModelParams, usages::Usage},
- },
- test_llm_db,
-};
-use chrono::{DateTime, Duration, Utc};
-use pretty_assertions::assert_eq;
-use rpc::LanguageModelProvider;
-
-test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
-
-async fn test_tracking_usage(db: &mut LlmDatabase) {
- let provider = LanguageModelProvider::Anthropic;
- let model = "claude-3-5-sonnet";
-
- db.initialize().await.unwrap();
- db.insert_models(&[ModelParams {
- provider,
- name: model.to_string(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 10_000,
- max_tokens_per_day: 50_000,
- price_per_million_input_tokens: 50,
- price_per_million_output_tokens: 50,
- }])
- .await
- .unwrap();
-
- // We're using a fixed datetime to prevent flakiness based on the clock.
- let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
- .unwrap()
- .with_timezone(&Utc);
- let user_id = UserId::from_proto(123);
-
- let now = t0;
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 1000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let now = t0 + Duration::seconds(10);
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 2000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 2,
- tokens_this_minute: 3000,
- input_tokens_this_minute: 3000,
- output_tokens_this_minute: 0,
- tokens_this_day: 3000,
- tokens_this_month: TokenUsage {
- input: 3000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- let now = t0 + Duration::seconds(60);
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 1,
- tokens_this_minute: 2000,
- input_tokens_this_minute: 2000,
- output_tokens_this_minute: 0,
- tokens_this_day: 3000,
- tokens_this_month: TokenUsage {
- input: 3000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- let now = t0 + Duration::seconds(60);
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 3000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 2,
- tokens_this_minute: 5000,
- input_tokens_this_minute: 5000,
- output_tokens_this_minute: 0,
- tokens_this_day: 6000,
- tokens_this_month: TokenUsage {
- input: 6000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- let t1 = t0 + Duration::hours(24);
- let now = t1;
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 0,
- tokens_this_minute: 0,
- input_tokens_this_minute: 0,
- output_tokens_this_minute: 0,
- tokens_this_day: 5000,
- tokens_this_month: TokenUsage {
- input: 6000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 4000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 1,
- tokens_this_minute: 4000,
- input_tokens_this_minute: 4000,
- output_tokens_this_minute: 0,
- tokens_this_day: 9000,
- tokens_this_month: TokenUsage {
- input: 10000,
- input_cache_creation: 0,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- // We're using a fixed datetime to prevent flakiness based on the clock.
- let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
- .unwrap()
- .with_timezone(&Utc);
-
- // Test cache creation input tokens
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 1000,
- input_cache_creation: 500,
- input_cache_read: 0,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 1,
- tokens_this_minute: 1500,
- input_tokens_this_minute: 1500,
- output_tokens_this_minute: 0,
- tokens_this_day: 1500,
- tokens_this_month: TokenUsage {
- input: 1000,
- input_cache_creation: 500,
- input_cache_read: 0,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-
- // Test cache read input tokens
- db.record_usage(
- user_id,
- false,
- provider,
- model,
- TokenUsage {
- input: 1000,
- input_cache_creation: 0,
- input_cache_read: 300,
- output: 0,
- },
- false,
- Cents::ZERO,
- FREE_TIER_MONTHLY_SPENDING_LIMIT,
- now,
- )
- .await
- .unwrap();
-
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 2,
- tokens_this_minute: 2800,
- input_tokens_this_minute: 2500,
- output_tokens_this_minute: 0,
- tokens_this_day: 2800,
- tokens_this_month: TokenUsage {
- input: 2000,
- input_cache_creation: 500,
- input_cache_read: 300,
- output: 0,
- },
- spending_this_month: Cents::ZERO,
- lifetime_spending: Cents::ZERO,
- }
- );
-}
@@ -9,14 +9,14 @@ use axum::{
use collab::api::CloudflareIpCountryHeader;
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
-use collab::llm::{db::LlmDatabase, log_usage_periodically};
+use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
env, executor::Executor, rpc::ResultExt,
};
-use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
+use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
use std::{
env::args,
@@ -74,11 +74,10 @@ async fn main() -> Result<()> {
let mode = match args.next().as_deref() {
Some("collab") => ServiceMode::Collab,
Some("api") => ServiceMode::Api,
- Some("llm") => ServiceMode::Llm,
Some("all") => ServiceMode::All,
_ => {
return Err(anyhow!(
- "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+ "usage: collab <version | migrate | seed | serve <api|collab|all>>"
))?;
}
};
@@ -97,20 +96,9 @@ async fn main() -> Result<()> {
let mut on_shutdown = None;
- if mode.is_llm() {
- setup_llm_database(&config).await?;
-
- let state = LlmState::new(config.clone(), Executor::Production).await?;
-
- log_usage_periodically(state.clone());
-
- app = app
- .merge(collab::llm::routes())
- .layer(Extension(state.clone()));
- }
-
if mode.is_collab() || mode.is_api() {
setup_app_database(&config).await?;
+ setup_llm_database(&config).await?;
let state = AppState::new(config, Executor::Production).await?;
@@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
}
-async fn handle_liveness_probe(
- app_state: Option<Extension<Arc<AppState>>>,
- llm_state: Option<Extension<Arc<LlmState>>>,
-) -> Result<String> {
+async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
if let Some(state) = app_state {
state.db.get_all_users(0, 1).await?;
}
- if let Some(llm_state) = llm_state {
- llm_state.db.list_providers().await?;
- }
-
Ok("ok".to_string())
}