collab: Remove LLM service (#28728)

Marshall Bowers created

This PR removes the LLM service from collab, as it has been moved to
Cloudflare.

Release Notes:

- N/A

Change summary

.github/workflows/deploy_collab.yml                       |   8 
Cargo.lock                                                |   1 
crates/collab/Cargo.toml                                  |   1 
crates/collab/src/lib.rs                                  |   5 
crates/collab/src/llm.rs                                  | 767 --------
crates/collab/src/llm/authorization.rs                    | 330 ---
crates/collab/src/llm/db.rs                               |   1 
crates/collab/src/llm/db/queries.rs                       |   1 
crates/collab/src/llm/db/queries/revoked_access_tokens.rs |  15 
crates/collab/src/llm/db/queries/usages.rs                | 664 -------
crates/collab/src/llm/db/tables.rs                        |   2 
crates/collab/src/llm/db/tables/lifetime_usage.rs         |  20 
crates/collab/src/llm/db/tables/revoked_access_token.rs   |  19 
crates/collab/src/llm/db/tests.rs                         |   2 
crates/collab/src/llm/db/tests/billing_tests.rs           | 152 -
crates/collab/src/llm/db/tests/usage_tests.rs             | 306 ---
crates/collab/src/main.rs                                 |  29 
17 files changed, 8 insertions(+), 2,315 deletions(-)

Detailed changes

.github/workflows/deploy_collab.yml 🔗

@@ -117,12 +117,10 @@ jobs:
             export ZED_KUBE_NAMESPACE=production
             export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
             export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
-            export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
           elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
             export ZED_KUBE_NAMESPACE=staging
             export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
             export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
-            export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
           else
             echo "cowardly refusing to deploy from an unknown branch"
             exit 1
@@ -147,9 +145,3 @@ jobs:
           envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
           kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
           echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
-
-          export ZED_SERVICE_NAME=llm
-          export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
-          envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
-          kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
-          echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

Cargo.lock 🔗

@@ -2942,7 +2942,6 @@ dependencies = [
 name = "collab"
 version = "0.44.0"
 dependencies = [
- "anthropic",
  "anyhow",
  "assistant",
  "assistant_context_editor",

crates/collab/Cargo.toml 🔗

@@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
 test-support = ["sqlite"]
 
 [dependencies]
-anthropic.workspace = true
 anyhow.workspace = true
 async-stripe.workspace = true
 async-tungstenite.workspace = true

crates/collab/src/lib.rs 🔗

@@ -253,7 +253,6 @@ impl Config {
 pub enum ServiceMode {
     Api,
     Collab,
-    Llm,
     All,
 }
 
@@ -265,10 +264,6 @@ impl ServiceMode {
     pub fn is_api(&self) -> bool {
         matches!(self, Self::Api | Self::All)
     }
-
-    pub fn is_llm(&self) -> bool {
-        matches!(self, Self::Llm | Self::All)
-    }
 }
 
 pub struct AppState {

crates/collab/src/llm.rs 🔗

@@ -1,448 +1,10 @@
-mod authorization;
 pub mod db;
 mod token;
 
-use crate::api::CloudflareIpCountryHeader;
-use crate::api::events::SnowflakeRow;
-use crate::build_kinesis_client;
-use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
-use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
-use anyhow::{Context as _, anyhow};
-use authorization::authorize_access_to_language_model;
-use axum::routing::get;
-use axum::{
-    Extension, Json, Router, TypedHeader,
-    body::Body,
-    http::{self, HeaderName, HeaderValue, Request, StatusCode},
-    middleware::{self, Next},
-    response::{IntoResponse, Response},
-    routing::post,
-};
-use chrono::{DateTime, Duration, Utc};
-use collections::HashMap;
-use db::TokenUsage;
-use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
-use futures::{Stream, StreamExt as _};
-use reqwest_client::ReqwestClient;
-use rpc::{
-    EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
-};
-use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
-use serde_json::json;
-use std::{
-    pin::Pin,
-    sync::Arc,
-    task::{Context, Poll},
-};
-use strum::IntoEnumIterator;
-use tokio::sync::RwLock;
-use util::ResultExt;
+use crate::Cents;
 
 pub use token::*;
 
-const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
-
-pub struct LlmState {
-    pub config: Config,
-    pub executor: Executor,
-    pub db: Arc<LlmDatabase>,
-    pub http_client: ReqwestClient,
-    pub kinesis_client: Option<aws_sdk_kinesis::Client>,
-    active_user_count_by_model:
-        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
-}
-
-impl LlmState {
-    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
-        let database_url = config
-            .llm_database_url
-            .as_ref()
-            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
-        let max_connections = config
-            .llm_database_max_connections
-            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
-
-        let mut db_options = db::ConnectOptions::new(database_url);
-        db_options.max_connections(max_connections);
-        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
-        db.initialize().await?;
-
-        let db = Arc::new(db);
-
-        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
-        let http_client =
-            ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
-
-        let this = Self {
-            executor,
-            db,
-            http_client,
-            kinesis_client: if config.kinesis_access_key.is_some() {
-                build_kinesis_client(&config).await.log_err()
-            } else {
-                None
-            },
-            active_user_count_by_model: RwLock::new(HashMap::default()),
-            config,
-        };
-
-        Ok(Arc::new(this))
-    }
-
-    pub async fn get_active_user_count(
-        &self,
-        provider: LanguageModelProvider,
-        model: &str,
-    ) -> Result<ActiveUserCount> {
-        let now = Utc::now();
-
-        {
-            let active_user_count_by_model = self.active_user_count_by_model.read().await;
-            if let Some((last_updated, count)) =
-                active_user_count_by_model.get(&(provider, model.to_string()))
-            {
-                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
-                    return Ok(*count);
-                }
-            }
-        }
-
-        let mut cache = self.active_user_count_by_model.write().await;
-        let new_count = self.db.get_active_user_count(provider, model, now).await?;
-        cache.insert((provider, model.to_string()), (now, new_count));
-        Ok(new_count)
-    }
-}
-
-pub fn routes() -> Router<(), Body> {
-    Router::new()
-        .route("/models", get(list_models))
-        .route("/completion", post(perform_completion))
-        .layer(middleware::from_fn(validate_api_token))
-}
-
-async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
-    let token = req
-        .headers()
-        .get(http::header::AUTHORIZATION)
-        .and_then(|header| header.to_str().ok())
-        .ok_or_else(|| {
-            Error::http(
-                StatusCode::BAD_REQUEST,
-                "missing authorization header".to_string(),
-            )
-        })?
-        .strip_prefix("Bearer ")
-        .ok_or_else(|| {
-            Error::http(
-                StatusCode::BAD_REQUEST,
-                "invalid authorization header".to_string(),
-            )
-        })?;
-
-    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
-    match LlmTokenClaims::validate(token, &state.config) {
-        Ok(claims) => {
-            if state.db.is_access_token_revoked(&claims.jti).await? {
-                return Err(Error::http(
-                    StatusCode::UNAUTHORIZED,
-                    "unauthorized".to_string(),
-                ));
-            }
-
-            tracing::Span::current()
-                .record("user_id", claims.user_id)
-                .record("login", claims.github_user_login.clone())
-                .record("authn.jti", &claims.jti)
-                .record("is_staff", claims.is_staff);
-
-            req.extensions_mut().insert(claims);
-            Ok::<_, Error>(next.run(req).await.into_response())
-        }
-        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
-            StatusCode::UNAUTHORIZED,
-            "unauthorized".to_string(),
-            [(
-                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
-                HeaderValue::from_static("true"),
-            )]
-            .into_iter()
-            .collect(),
-        )),
-        Err(_err) => Err(Error::http(
-            StatusCode::UNAUTHORIZED,
-            "unauthorized".to_string(),
-        )),
-    }
-}
-
-async fn list_models(
-    Extension(state): Extension<Arc<LlmState>>,
-    Extension(claims): Extension<LlmTokenClaims>,
-    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
-) -> Result<Json<ListModelsResponse>> {
-    let country_code = country_code_header.map(|header| header.to_string());
-
-    let mut accessible_models = Vec::new();
-
-    for (provider, model) in state.db.all_models() {
-        let authorize_result = authorize_access_to_language_model(
-            &state.config,
-            &claims,
-            country_code.as_deref(),
-            provider,
-            &model.name,
-        );
-
-        if authorize_result.is_ok() {
-            accessible_models.push(rpc::LanguageModel {
-                provider,
-                name: model.name,
-            });
-        }
-    }
-
-    Ok(Json(ListModelsResponse {
-        models: accessible_models,
-    }))
-}
-
-async fn perform_completion(
-    Extension(state): Extension<Arc<LlmState>>,
-    Extension(claims): Extension<LlmTokenClaims>,
-    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
-    Json(params): Json<PerformCompletionParams>,
-) -> Result<impl IntoResponse> {
-    let model = normalize_model_name(
-        state.db.model_names_for_provider(params.provider),
-        params.model,
-    );
-
-    let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
-    if !bypass_account_age_check {
-        if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
-            Err(anyhow!("account too young"))?
-        }
-    }
-
-    authorize_access_to_language_model(
-        &state.config,
-        &claims,
-        country_code_header
-            .map(|header| header.to_string())
-            .as_deref(),
-        params.provider,
-        &model,
-    )?;
-
-    check_usage_limit(&state, params.provider, &model, &claims).await?;
-
-    let stream = match params.provider {
-        LanguageModelProvider::Anthropic => {
-            let api_key = if claims.is_staff {
-                state
-                    .config
-                    .anthropic_staff_api_key
-                    .as_ref()
-                    .context("no Anthropic AI staff API key configured on the server")?
-            } else {
-                state
-                    .config
-                    .anthropic_api_key
-                    .as_ref()
-                    .context("no Anthropic AI API key configured on the server")?
-            };
-
-            let mut request: anthropic::Request =
-                serde_json::from_str(params.provider_request.get())?;
-
-            // Override the model on the request with the latest version of the model that is
-            // known to the server.
-            //
-            // Right now, we use the version that's defined in `model.id()`, but we will likely
-            // want to change this code once a new version of an Anthropic model is released,
-            // so that users can use the new version, without having to update Zed.
-            request.model = match model.as_str() {
-                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
-                "claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
-                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
-                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
-                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
-                _ => request.model,
-            };
-
-            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
-                &state.http_client,
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                request,
-            )
-            .await
-            .map_err(|err| match err {
-                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
-                    Some(anthropic::ApiErrorCode::RateLimitError) => {
-                        tracing::info!(
-                            target: "upstream rate limit exceeded",
-                            user_id = claims.user_id,
-                            login = claims.github_user_login,
-                            authn.jti = claims.jti,
-                            is_staff = claims.is_staff,
-                            provider = params.provider.to_string(),
-                            model = model
-                        );
-
-                        Error::http(
-                            StatusCode::TOO_MANY_REQUESTS,
-                            "Upstream Anthropic rate limit exceeded.".to_string(),
-                        )
-                    }
-                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
-                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
-                    }
-                    Some(anthropic::ApiErrorCode::OverloadedError) => {
-                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
-                    }
-                    Some(_) => {
-                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
-                    }
-                    None => Error::Internal(anyhow!(err)),
-                },
-                anthropic::AnthropicError::Other(err) => Error::Internal(err),
-            })?;
-
-            if let Some(rate_limit_info) = rate_limit_info {
-                tracing::info!(
-                    target: "upstream rate limit",
-                    is_staff = claims.is_staff,
-                    provider = params.provider.to_string(),
-                    model = model,
-                    tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
-                    input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
-                    output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
-                    requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
-                    requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
-                    tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
-                    input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
-                    output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
-                );
-            }
-
-            chunks
-                .map(move |event| {
-                    let chunk = event?;
-                    let (
-                        input_tokens,
-                        output_tokens,
-                        cache_creation_input_tokens,
-                        cache_read_input_tokens,
-                    ) = match &chunk {
-                        anthropic::Event::MessageStart {
-                            message: anthropic::Response { usage, .. },
-                        }
-                        | anthropic::Event::MessageDelta { usage, .. } => (
-                            usage.input_tokens.unwrap_or(0) as usize,
-                            usage.output_tokens.unwrap_or(0) as usize,
-                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
-                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
-                        ),
-                        _ => (0, 0, 0, 0),
-                    };
-
-                    anyhow::Ok(CompletionChunk {
-                        bytes: serde_json::to_vec(&chunk).unwrap(),
-                        input_tokens,
-                        output_tokens,
-                        cache_creation_input_tokens,
-                        cache_read_input_tokens,
-                    })
-                })
-                .boxed()
-        }
-        LanguageModelProvider::OpenAi => {
-            let api_key = state
-                .config
-                .openai_api_key
-                .as_ref()
-                .context("no OpenAI API key configured on the server")?;
-            let chunks = open_ai::stream_completion(
-                &state.http_client,
-                open_ai::OPEN_AI_API_URL,
-                api_key,
-                serde_json::from_str(params.provider_request.get())?,
-            )
-            .await?;
-
-            chunks
-                .map(|event| {
-                    event.map(|chunk| {
-                        let input_tokens =
-                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
-                        let output_tokens =
-                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
-                        CompletionChunk {
-                            bytes: serde_json::to_vec(&chunk).unwrap(),
-                            input_tokens,
-                            output_tokens,
-                            cache_creation_input_tokens: 0,
-                            cache_read_input_tokens: 0,
-                        }
-                    })
-                })
-                .boxed()
-        }
-        LanguageModelProvider::Google => {
-            let api_key = state
-                .config
-                .google_ai_api_key
-                .as_ref()
-                .context("no Google AI API key configured on the server")?;
-            let chunks = google_ai::stream_generate_content(
-                &state.http_client,
-                google_ai::API_URL,
-                api_key,
-                serde_json::from_str(params.provider_request.get())?,
-            )
-            .await?;
-
-            chunks
-                .map(|event| {
-                    event.map(|chunk| {
-                        // TODO - implement token counting for Google AI
-                        CompletionChunk {
-                            bytes: serde_json::to_vec(&chunk).unwrap(),
-                            input_tokens: 0,
-                            output_tokens: 0,
-                            cache_creation_input_tokens: 0,
-                            cache_read_input_tokens: 0,
-                        }
-                    })
-                })
-                .boxed()
-        }
-    };
-
-    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
-        state,
-        claims,
-        provider: params.provider,
-        model,
-        tokens: TokenUsage::default(),
-        inner_stream: stream,
-    })))
-}
-
-fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
-    if let Some(known_model_name) = known_models
-        .iter()
-        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
-        .max_by_key(|known_model_name| known_model_name.len())
-    {
-        known_model_name.to_string()
-    } else {
-        name
-    }
-}
-
 /// The maximum monthly spending an individual user can reach on the free tier
 /// before they have to pay.
 pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
@@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
 ///
 /// Used to prevent surprise bills.
 pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
-
-async fn check_usage_limit(
-    state: &Arc<LlmState>,
-    provider: LanguageModelProvider,
-    model_name: &str,
-    claims: &LlmTokenClaims,
-) -> Result<()> {
-    if claims.is_staff {
-        return Ok(());
-    }
-
-    let user_id = UserId::from_proto(claims.user_id);
-    let model = state.db.model(provider, model_name)?;
-    let free_tier = claims.free_tier_monthly_spending_limit();
-
-    let spending_this_month = state
-        .db
-        .get_user_spending_for_month(user_id, Utc::now())
-        .await?;
-    if spending_this_month >= free_tier {
-        if !claims.has_llm_subscription {
-            return Err(Error::http(
-                StatusCode::PAYMENT_REQUIRED,
-                "Maximum spending limit reached for this month.".to_string(),
-            ));
-        }
-
-        let monthly_spend = spending_this_month.saturating_sub(free_tier);
-        if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
-            return Err(Error::Http(
-                StatusCode::FORBIDDEN,
-                "Maximum spending limit reached for this month.".to_string(),
-                [(
-                    HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
-                    HeaderValue::from_static("true"),
-                )]
-                .into_iter()
-                .collect(),
-            ));
-        }
-    }
-
-    let active_users = state.get_active_user_count(provider, model_name).await?;
-
-    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
-    let users_in_recent_days = active_users.users_in_recent_days.max(1);
-
-    let per_user_max_requests_per_minute =
-        model.max_requests_per_minute as usize / users_in_recent_minutes;
-    let per_user_max_tokens_per_minute =
-        model.max_tokens_per_minute as usize / users_in_recent_minutes;
-    let per_user_max_input_tokens_per_minute =
-        model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
-    let per_user_max_output_tokens_per_minute =
-        model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
-    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
-
-    let usage = state
-        .db
-        .get_usage(user_id, provider, model_name, Utc::now())
-        .await?;
-
-    let checks = match (provider, model_name) {
-        (LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
-            (
-                usage.requests_this_minute,
-                per_user_max_requests_per_minute,
-                UsageMeasure::RequestsPerMinute,
-            ),
-            (
-                usage.input_tokens_this_minute,
-                per_user_max_tokens_per_minute,
-                UsageMeasure::InputTokensPerMinute,
-            ),
-            (
-                usage.output_tokens_this_minute,
-                per_user_max_tokens_per_minute,
-                UsageMeasure::OutputTokensPerMinute,
-            ),
-            (
-                usage.tokens_this_day,
-                per_user_max_tokens_per_day,
-                UsageMeasure::TokensPerDay,
-            ),
-        ],
-        _ => vec![
-            (
-                usage.requests_this_minute,
-                per_user_max_requests_per_minute,
-                UsageMeasure::RequestsPerMinute,
-            ),
-            (
-                usage.tokens_this_minute,
-                per_user_max_tokens_per_minute,
-                UsageMeasure::TokensPerMinute,
-            ),
-            (
-                usage.tokens_this_day,
-                per_user_max_tokens_per_day,
-                UsageMeasure::TokensPerDay,
-            ),
-        ],
-    };
-
-    for (used, limit, usage_measure) in checks {
-        if used > limit {
-            let resource = match usage_measure {
-                UsageMeasure::RequestsPerMinute => "requests_per_minute",
-                UsageMeasure::TokensPerMinute => "tokens_per_minute",
-                UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
-                UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
-                UsageMeasure::TokensPerDay => "tokens_per_day",
-            };
-
-            tracing::info!(
-                target: "user rate limit",
-                user_id = claims.user_id,
-                login = claims.github_user_login,
-                authn.jti = claims.jti,
-                is_staff = claims.is_staff,
-                provider = provider.to_string(),
-                model = model.name,
-                usage_measure = resource,
-                requests_this_minute = usage.requests_this_minute,
-                tokens_this_minute = usage.tokens_this_minute,
-                input_tokens_this_minute = usage.input_tokens_this_minute,
-                output_tokens_this_minute = usage.output_tokens_this_minute,
-                tokens_this_day = usage.tokens_this_day,
-                users_in_recent_minutes = users_in_recent_minutes,
-                users_in_recent_days = users_in_recent_days,
-                max_requests_per_minute = per_user_max_requests_per_minute,
-                max_tokens_per_minute = per_user_max_tokens_per_minute,
-                max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
-                max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
-                max_tokens_per_day = per_user_max_tokens_per_day,
-            );
-
-            SnowflakeRow::new(
-                "Language Model Rate Limited",
-                Some(claims.metrics_id),
-                claims.is_staff,
-                claims.system_id.clone(),
-                json!({
-                    "usage": usage,
-                    "users_in_recent_minutes": users_in_recent_minutes,
-                    "users_in_recent_days": users_in_recent_days,
-                    "max_requests_per_minute": per_user_max_requests_per_minute,
-                    "max_tokens_per_minute": per_user_max_tokens_per_minute,
-                    "max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
-                    "max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
-                    "max_tokens_per_day": per_user_max_tokens_per_day,
-                    "plan": match claims.plan {
-                        Plan::Free => "free".to_string(),
-                        Plan::ZedPro => "zed_pro".to_string(),
-                    },
-                    "model": model.name.clone(),
-                    "provider": provider.to_string(),
-                    "usage_measure": resource.to_string(),
-                }),
-            )
-            .write(&state.kinesis_client, &state.config.kinesis_stream)
-            .await
-            .log_err();
-
-            return Err(Error::http(
-                StatusCode::TOO_MANY_REQUESTS,
-                format!("Rate limit exceeded. Maximum {} reached.", resource),
-            ));
-        }
-    }
-
-    Ok(())
-}
-
-struct CompletionChunk {
-    bytes: Vec<u8>,
-    input_tokens: usize,
-    output_tokens: usize,
-    cache_creation_input_tokens: usize,
-    cache_read_input_tokens: usize,
-}
-
-struct TokenCountingStream<S> {
-    state: Arc<LlmState>,
-    claims: LlmTokenClaims,
-    provider: LanguageModelProvider,
-    model: String,
-    tokens: TokenUsage,
-    inner_stream: S,
-}
-
-impl<S> Stream for TokenCountingStream<S>
-where
-    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
-{
-    type Item = Result<Vec<u8>, anyhow::Error>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match Pin::new(&mut self.inner_stream).poll_next(cx) {
-            Poll::Ready(Some(Ok(mut chunk))) => {
-                chunk.bytes.push(b'\n');
-                self.tokens.input += chunk.input_tokens;
-                self.tokens.output += chunk.output_tokens;
-                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
-                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
-                Poll::Ready(Some(Ok(chunk.bytes)))
-            }
-            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
-        }
-    }
-}
-
-impl<S> Drop for TokenCountingStream<S> {
-    fn drop(&mut self) {
-        let state = self.state.clone();
-        let claims = self.claims.clone();
-        let provider = self.provider;
-        let model = std::mem::take(&mut self.model);
-        let tokens = self.tokens;
-        self.state.executor.spawn_detached(async move {
-            let usage = state
-                .db
-                .record_usage(
-                    UserId::from_proto(claims.user_id),
-                    claims.is_staff,
-                    provider,
-                    &model,
-                    tokens,
-                    claims.has_llm_subscription,
-                    Cents(claims.max_monthly_spend_in_cents),
-                    claims.free_tier_monthly_spending_limit(),
-                    Utc::now(),
-                )
-                .await
-                .log_err();
-
-            if let Some(usage) = usage {
-                tracing::info!(
-                    target: "user usage",
-                    user_id = claims.user_id,
-                    login = claims.github_user_login,
-                    authn.jti = claims.jti,
-                    is_staff = claims.is_staff,
-                    provider = provider.to_string(),
-                    model = model,
-                    requests_this_minute = usage.requests_this_minute,
-                    tokens_this_minute = usage.tokens_this_minute,
-                    input_tokens_this_minute = usage.input_tokens_this_minute,
-                    output_tokens_this_minute = usage.output_tokens_this_minute,
-                );
-
-                let properties = json!({
-                    "has_llm_subscription": claims.has_llm_subscription,
-                    "max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
-                    "plan": match claims.plan {
-                        Plan::Free => "free".to_string(),
-                        Plan::ZedPro => "zed_pro".to_string(),
-                    },
-                    "model": model,
-                    "provider": provider,
-                    "usage": usage,
-                    "tokens": tokens
-                });
-                SnowflakeRow::new(
-                    "Language Model Used",
-                    Some(claims.metrics_id),
-                    claims.is_staff,
-                    claims.system_id.clone(),
-                    properties,
-                )
-                .write(&state.kinesis_client, &state.config.kinesis_stream)
-                .await
-                .log_err();
-            }
-        })
-    }
-}
-
-pub fn log_usage_periodically(state: Arc<LlmState>) {
-    state.executor.clone().spawn_detached(async move {
-        loop {
-            state
-                .executor
-                .sleep(std::time::Duration::from_secs(30))
-                .await;
-
-            for provider in LanguageModelProvider::iter() {
-                for model in state.db.model_names_for_provider(provider) {
-                    if let Some(active_user_count) = state
-                        .get_active_user_count(provider, &model)
-                        .await
-                        .log_err()
-                    {
-                        tracing::info!(
-                            target: "active user counts",
-                            provider = provider.to_string(),
-                            model = model,
-                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
-                            users_in_recent_days = active_user_count.users_in_recent_days,
-                        );
-                    }
-                }
-            }
-
-            if let Some(usages) = state
-                .db
-                .get_application_wide_usages_by_model(Utc::now())
-                .await
-                .log_err()
-            {
-                for usage in usages {
-                    tracing::info!(
-                        target: "computed usage",
-                        provider = usage.provider.to_string(),
-                        model = usage.model,
-                        requests_this_minute = usage.requests_this_minute,
-                        tokens_this_minute = usage.tokens_this_minute,
-                        input_tokens_this_minute = usage.input_tokens_this_minute,
-                        output_tokens_this_minute = usage.output_tokens_this_minute,
-                    );
-                }
-            }
-        }
-    })
-}

crates/collab/src/llm/authorization.rs 🔗

@@ -1,330 +0,0 @@
-use reqwest::StatusCode;
-use rpc::LanguageModelProvider;
-
-use crate::llm::LlmTokenClaims;
-use crate::{Config, Error, Result};
-
-pub fn authorize_access_to_language_model(
-    config: &Config,
-    claims: &LlmTokenClaims,
-    country_code: Option<&str>,
-    provider: LanguageModelProvider,
-    model: &str,
-) -> Result<()> {
-    authorize_access_for_country(config, country_code, provider)?;
-    authorize_access_to_model(config, claims, provider, model)?;
-    Ok(())
-}
-
-fn authorize_access_to_model(
-    config: &Config,
-    claims: &LlmTokenClaims,
-    provider: LanguageModelProvider,
-    model: &str,
-) -> Result<()> {
-    if claims.is_staff {
-        return Ok(());
-    }
-
-    if provider == LanguageModelProvider::Anthropic {
-        if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
-            return Ok(());
-        }
-
-        if claims.has_llm_closed_beta_feature_flag
-            && Some(model) == config.llm_closed_beta_model_name.as_deref()
-        {
-            return Ok(());
-        }
-    }
-
-    Err(Error::http(
-        StatusCode::FORBIDDEN,
-        format!("access to model {model:?} is not included in your plan"),
-    ))
-}
-
-fn authorize_access_for_country(
-    config: &Config,
-    country_code: Option<&str>,
-    provider: LanguageModelProvider,
-) -> Result<()> {
-    // In development we won't have the `CF-IPCountry` header, so we can't check
-    // the country code.
-    //
-    // This shouldn't be necessary, as anyone running in development will need to provide
-    // their own API credentials in order to use an LLM provider.
-    if config.is_development() {
-        return Ok(());
-    }
-
-    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
-    let country_code = match country_code {
-        // `XX` - Used for clients without country code data.
-        None | Some("XX") => Err(Error::http(
-            StatusCode::BAD_REQUEST,
-            "no country code".to_string(),
-        ))?,
-        // `T1` - Used for clients using the Tor network.
-        Some("T1") => Err(Error::http(
-            StatusCode::FORBIDDEN,
-            format!("access to {provider:?} models is not available over Tor"),
-        ))?,
-        Some(country_code) => country_code,
-    };
-
-    let is_country_supported_by_provider = match provider {
-        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
-        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
-        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
-    };
-    if !is_country_supported_by_provider {
-        Err(Error::http(
-            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
-            format!(
-                "access to {provider:?} models is not available in your region ({country_code})"
-            ),
-        ))?
-    }
-
-    Ok(())
-}
-
-#[cfg(test)]
-mod tests {
-    use axum::response::IntoResponse;
-    use pretty_assertions::assert_eq;
-    use rpc::proto::Plan;
-
-    use super::*;
-
-    #[gpui::test]
-    async fn test_authorize_access_to_language_model_with_supported_country(
-        _cx: &mut gpui::TestAppContext,
-    ) {
-        let config = Config::test();
-
-        let claims = LlmTokenClaims {
-            user_id: 99,
-            plan: Plan::ZedPro,
-            is_staff: true,
-            ..Default::default()
-        };
-
-        let cases = vec![
-            (LanguageModelProvider::Anthropic, "US"), // United States
-            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
-            (LanguageModelProvider::OpenAi, "US"),    // United States
-            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
-            (LanguageModelProvider::Google, "US"),    // United States
-            (LanguageModelProvider::Google, "GB"),    // United Kingdom
-        ];
-
-        for (provider, country_code) in cases {
-            authorize_access_to_language_model(
-                &config,
-                &claims,
-                Some(country_code),
-                provider,
-                "the-model",
-            )
-            .unwrap_or_else(|_| {
-                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
-            })
-        }
-    }
-
-    #[gpui::test]
-    async fn test_authorize_access_to_language_model_with_unsupported_country(
-        _cx: &mut gpui::TestAppContext,
-    ) {
-        let config = Config::test();
-
-        let claims = LlmTokenClaims {
-            user_id: 99,
-            plan: Plan::ZedPro,
-            ..Default::default()
-        };
-
-        let cases = vec![
-            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
-            (LanguageModelProvider::Anthropic, "BY"), // Belarus
-            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
-            (LanguageModelProvider::Anthropic, "CN"), // China
-            (LanguageModelProvider::Anthropic, "CU"), // Cuba
-            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
-            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
-            (LanguageModelProvider::Anthropic, "IR"), // Iran
-            (LanguageModelProvider::Anthropic, "KP"), // North Korea
-            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
-            (LanguageModelProvider::Anthropic, "LY"), // Libya
-            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
-            (LanguageModelProvider::Anthropic, "RU"), // Russia
-            (LanguageModelProvider::Anthropic, "SO"), // Somalia
-            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
-            (LanguageModelProvider::Anthropic, "SD"), // Sudan
-            (LanguageModelProvider::Anthropic, "SY"), // Syria
-            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
-            (LanguageModelProvider::Anthropic, "YE"), // Yemen
-            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
-            (LanguageModelProvider::Google, "KP"),    // North Korea
-        ];
-
-        for (provider, country_code) in cases {
-            let error_response = authorize_access_to_language_model(
-                &config,
-                &claims,
-                Some(country_code),
-                provider,
-                "the-model",
-            )
-            .expect_err(&format!(
-                "expected authorization to return an error for {provider:?}: {country_code}"
-            ))
-            .into_response();
-
-            assert_eq!(
-                error_response.status(),
-                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
-            );
-            let response_body = hyper::body::to_bytes(error_response.into_body())
-                .await
-                .unwrap()
-                .to_vec();
-            assert_eq!(
-                String::from_utf8(response_body).unwrap(),
-                format!(
-                    "access to {provider:?} models is not available in your region ({country_code})"
-                )
-            );
-        }
-    }
-
-    #[gpui::test]
-    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
-        let config = Config::test();
-
-        let claims = LlmTokenClaims {
-            user_id: 99,
-            plan: Plan::ZedPro,
-            ..Default::default()
-        };
-
-        let cases = vec![
-            (LanguageModelProvider::Anthropic, "T1"), // Tor
-            (LanguageModelProvider::OpenAi, "T1"),    // Tor
-            (LanguageModelProvider::Google, "T1"),    // Tor
-        ];
-
-        for (provider, country_code) in cases {
-            let error_response = authorize_access_to_language_model(
-                &config,
-                &claims,
-                Some(country_code),
-                provider,
-                "the-model",
-            )
-            .expect_err(&format!(
-                "expected authorization to return an error for {provider:?}: {country_code}"
-            ))
-            .into_response();
-
-            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
-            let response_body = hyper::body::to_bytes(error_response.into_body())
-                .await
-                .unwrap()
-                .to_vec();
-            assert_eq!(
-                String::from_utf8(response_body).unwrap(),
-                format!("access to {provider:?} models is not available over Tor")
-            );
-        }
-    }
-
-    #[gpui::test]
-    async fn test_authorize_access_to_language_model_based_on_plan() {
-        let config = Config::test();
-
-        let test_cases = vec![
-            // Pro plan should have access to claude-3.5-sonnet
-            (
-                Plan::ZedPro,
-                LanguageModelProvider::Anthropic,
-                "claude-3-5-sonnet",
-                true,
-            ),
-            // Free plan should have access to claude-3.5-sonnet
-            (
-                Plan::Free,
-                LanguageModelProvider::Anthropic,
-                "claude-3-5-sonnet",
-                true,
-            ),
-            // Pro plan should NOT have access to other Anthropic models
-            (
-                Plan::ZedPro,
-                LanguageModelProvider::Anthropic,
-                "claude-3-opus",
-                false,
-            ),
-        ];
-
-        for (plan, provider, model, expected_access) in test_cases {
-            let claims = LlmTokenClaims {
-                plan,
-                ..Default::default()
-            };
-
-            let result =
-                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
-
-            if expected_access {
-                assert!(
-                    result.is_ok(),
-                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
-                    plan,
-                    provider,
-                    model
-                );
-            } else {
-                let error = result.expect_err(&format!(
-                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
-                    plan, provider, model
-                ));
-                let response = error.into_response();
-                assert_eq!(response.status(), StatusCode::FORBIDDEN);
-            }
-        }
-    }
-
-    #[gpui::test]
-    async fn test_authorize_access_to_language_model_for_staff() {
-        let config = Config::test();
-
-        let claims = LlmTokenClaims {
-            is_staff: true,
-            ..Default::default()
-        };
-
-        // Staff should have access to all models
-        let test_cases = vec![
-            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
-            (LanguageModelProvider::Anthropic, "claude-2"),
-            (LanguageModelProvider::Anthropic, "claude-123-agi"),
-            (LanguageModelProvider::OpenAi, "gpt-4"),
-            (LanguageModelProvider::Google, "gemini-pro"),
-        ];
-
-        for (provider, model) in test_cases {
-            let result =
-                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
-
-            assert!(
-                result.is_ok(),
-                "Expected staff to have access to provider {:?}, model {}",
-                provider,
-                model
-            );
-        }
-    }
-}

crates/collab/src/llm/db.rs 🔗

@@ -20,7 +20,6 @@ use std::future::Future;
 use std::sync::Arc;
 
 use anyhow::anyhow;
-pub use queries::usages::{ActiveUserCount, TokenUsage};
 pub use sea_orm::ConnectOptions;
 use sea_orm::prelude::*;
 use sea_orm::{

crates/collab/src/llm/db/queries/revoked_access_tokens.rs 🔗

@@ -1,15 +0,0 @@
-use super::*;
-
-impl LlmDatabase {
-    /// Returns whether the access token with the given `jti` has been revoked.
-    pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
-        self.transaction(|tx| async move {
-            Ok(revoked_access_token::Entity::find()
-                .filter(revoked_access_token::Column::Jti.eq(jti))
-                .one(&*tx)
-                .await?
-                .is_some())
-        })
-        .await
-    }
-}

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -1,56 +1,12 @@
 use crate::db::UserId;
 use crate::llm::Cents;
-use chrono::{Datelike, Duration};
+use chrono::Datelike;
 use futures::StreamExt as _;
-use rpc::LanguageModelProvider;
-use sea_orm::QuerySelect;
-use std::{iter, str::FromStr};
+use std::str::FromStr;
 use strum::IntoEnumIterator as _;
 
 use super::*;
 
-#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
-pub struct TokenUsage {
-    pub input: usize,
-    pub input_cache_creation: usize,
-    pub input_cache_read: usize,
-    pub output: usize,
-}
-
-impl TokenUsage {
-    pub fn total(&self) -> usize {
-        self.input + self.input_cache_creation + self.input_cache_read + self.output
-    }
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
-pub struct Usage {
-    pub requests_this_minute: usize,
-    pub tokens_this_minute: usize,
-    pub input_tokens_this_minute: usize,
-    pub output_tokens_this_minute: usize,
-    pub tokens_this_day: usize,
-    pub tokens_this_month: TokenUsage,
-    pub spending_this_month: Cents,
-    pub lifetime_spending: Cents,
-}
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct ApplicationWideUsage {
-    pub provider: LanguageModelProvider,
-    pub model: String,
-    pub requests_this_minute: usize,
-    pub tokens_this_minute: usize,
-    pub input_tokens_this_minute: usize,
-    pub output_tokens_this_minute: usize,
-}
-
-#[derive(Clone, Copy, Debug, Default)]
-pub struct ActiveUserCount {
-    pub users_in_recent_minutes: usize,
-    pub users_in_recent_days: usize,
-}
-
 impl LlmDatabase {
     pub async fn initialize_usage_measures(&mut self) -> Result<()> {
         let all_measures = self
@@ -90,100 +46,6 @@ impl LlmDatabase {
         Ok(())
     }
 
-    pub async fn get_application_wide_usages_by_model(
-        &self,
-        now: DateTimeUtc,
-    ) -> Result<Vec<ApplicationWideUsage>> {
-        self.transaction(|tx| async move {
-            let past_minute = now - Duration::minutes(1);
-            let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
-            let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
-            let input_tokens_per_minute =
-                self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
-            let output_tokens_per_minute =
-                self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
-
-            let mut results = Vec::new();
-            for ((provider, model_name), model) in self.models.iter() {
-                let mut usages = usage::Entity::find()
-                    .filter(
-                        usage::Column::Timestamp
-                            .gte(past_minute.naive_utc())
-                            .and(usage::Column::IsStaff.eq(false))
-                            .and(usage::Column::ModelId.eq(model.id))
-                            .and(
-                                usage::Column::MeasureId
-                                    .eq(requests_per_minute)
-                                    .or(usage::Column::MeasureId.eq(tokens_per_minute)),
-                            ),
-                    )
-                    .stream(&*tx)
-                    .await?;
-
-                let mut requests_this_minute = 0;
-                let mut tokens_this_minute = 0;
-                let mut input_tokens_this_minute = 0;
-                let mut output_tokens_this_minute = 0;
-                while let Some(usage) = usages.next().await {
-                    let usage = usage?;
-                    if usage.measure_id == requests_per_minute {
-                        requests_this_minute += Self::get_live_buckets(
-                            &usage,
-                            now.naive_utc(),
-                            UsageMeasure::RequestsPerMinute,
-                        )
-                        .0
-                        .iter()
-                        .copied()
-                        .sum::<i64>() as usize;
-                    } else if usage.measure_id == tokens_per_minute {
-                        tokens_this_minute += Self::get_live_buckets(
-                            &usage,
-                            now.naive_utc(),
-                            UsageMeasure::TokensPerMinute,
-                        )
-                        .0
-                        .iter()
-                        .copied()
-                        .sum::<i64>() as usize;
-                    } else if usage.measure_id == input_tokens_per_minute {
-                        input_tokens_this_minute += Self::get_live_buckets(
-                            &usage,
-                            now.naive_utc(),
-                            UsageMeasure::InputTokensPerMinute,
-                        )
-                        .0
-                        .iter()
-                        .copied()
-                        .sum::<i64>() as usize;
-                    } else if usage.measure_id == output_tokens_per_minute {
-                        output_tokens_this_minute += Self::get_live_buckets(
-                            &usage,
-                            now.naive_utc(),
-                            UsageMeasure::OutputTokensPerMinute,
-                        )
-                        .0
-                        .iter()
-                        .copied()
-                        .sum::<i64>() as usize;
-                    }
-                }
-
-                results.push(ApplicationWideUsage {
-                    provider: *provider,
-                    model: model_name.clone(),
-                    requests_this_minute,
-                    tokens_this_minute,
-                    input_tokens_this_minute,
-                    output_tokens_this_minute,
-                })
-            }
-
-            Ok(results)
-        })
-        .await
-    }
-
     pub async fn get_user_spending_for_month(
         &self,
         user_id: UserId,
@@ -223,499 +85,6 @@ impl LlmDatabase {
         })
         .await
     }
-
-    pub async fn get_usage(
-        &self,
-        user_id: UserId,
-        provider: LanguageModelProvider,
-        model_name: &str,
-        now: DateTimeUtc,
-    ) -> Result<Usage> {
-        self.transaction(|tx| async move {
-            let model = self
-                .models
-                .get(&(provider, model_name.to_string()))
-                .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
-
-            let usages = usage::Entity::find()
-                .filter(
-                    usage::Column::UserId
-                        .eq(user_id)
-                        .and(usage::Column::ModelId.eq(model.id)),
-                )
-                .all(&*tx)
-                .await?;
-
-            let month = now.date_naive().month() as i32;
-            let year = now.date_naive().year();
-            let monthly_usage = monthly_usage::Entity::find()
-                .filter(
-                    monthly_usage::Column::UserId
-                        .eq(user_id)
-                        .and(monthly_usage::Column::ModelId.eq(model.id))
-                        .and(monthly_usage::Column::Month.eq(month))
-                        .and(monthly_usage::Column::Year.eq(year)),
-                )
-                .one(&*tx)
-                .await?;
-            let lifetime_usage = lifetime_usage::Entity::find()
-                .filter(
-                    lifetime_usage::Column::UserId
-                        .eq(user_id)
-                        .and(lifetime_usage::Column::ModelId.eq(model.id)),
-                )
-                .one(&*tx)
-                .await?;
-
-            let requests_this_minute =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
-            let tokens_this_minute =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
-            let input_tokens_this_minute =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
-            let output_tokens_this_minute =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
-            let tokens_this_day =
-                self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
-            let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
-                calculate_spending(
-                    model,
-                    monthly_usage.input_tokens as usize,
-                    monthly_usage.cache_creation_input_tokens as usize,
-                    monthly_usage.cache_read_input_tokens as usize,
-                    monthly_usage.output_tokens as usize,
-                )
-            } else {
-                Cents::ZERO
-            };
-            let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
-                calculate_spending(
-                    model,
-                    lifetime_usage.input_tokens as usize,
-                    lifetime_usage.cache_creation_input_tokens as usize,
-                    lifetime_usage.cache_read_input_tokens as usize,
-                    lifetime_usage.output_tokens as usize,
-                )
-            } else {
-                Cents::ZERO
-            };
-
-            Ok(Usage {
-                requests_this_minute,
-                tokens_this_minute,
-                input_tokens_this_minute,
-                output_tokens_this_minute,
-                tokens_this_day,
-                tokens_this_month: TokenUsage {
-                    input: monthly_usage
-                        .as_ref()
-                        .map_or(0, |usage| usage.input_tokens as usize),
-                    input_cache_creation: monthly_usage
-                        .as_ref()
-                        .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
-                    input_cache_read: monthly_usage
-                        .as_ref()
-                        .map_or(0, |usage| usage.cache_read_input_tokens as usize),
-                    output: monthly_usage
-                        .as_ref()
-                        .map_or(0, |usage| usage.output_tokens as usize),
-                },
-                spending_this_month,
-                lifetime_spending,
-            })
-        })
-        .await
-    }
-
-    pub async fn record_usage(
-        &self,
-        user_id: UserId,
-        is_staff: bool,
-        provider: LanguageModelProvider,
-        model_name: &str,
-        tokens: TokenUsage,
-        has_llm_subscription: bool,
-        max_monthly_spend: Cents,
-        free_tier_monthly_spending_limit: Cents,
-        now: DateTimeUtc,
-    ) -> Result<Usage> {
-        self.transaction(|tx| async move {
-            let model = self.model(provider, model_name)?;
-
-            let usages = usage::Entity::find()
-                .filter(
-                    usage::Column::UserId
-                        .eq(user_id)
-                        .and(usage::Column::ModelId.eq(model.id)),
-                )
-                .all(&*tx)
-                .await?;
-
-            let requests_this_minute = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::RequestsPerMinute,
-                    now,
-                    1,
-                    &tx,
-                )
-                .await?;
-            let tokens_this_minute = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::TokensPerMinute,
-                    now,
-                    tokens.total(),
-                    &tx,
-                )
-                .await?;
-            let input_tokens_this_minute = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::InputTokensPerMinute,
-                    now,
-                    // Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
-                    tokens.input + tokens.input_cache_creation,
-                    &tx,
-                )
-                .await?;
-            let output_tokens_this_minute = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::OutputTokensPerMinute,
-                    now,
-                    tokens.output,
-                    &tx,
-                )
-                .await?;
-            let tokens_this_day = self
-                .update_usage_for_measure(
-                    user_id,
-                    is_staff,
-                    model.id,
-                    &usages,
-                    UsageMeasure::TokensPerDay,
-                    now,
-                    tokens.total(),
-                    &tx,
-                )
-                .await?;
-
-            let month = now.date_naive().month() as i32;
-            let year = now.date_naive().year();
-
-            // Update monthly usage
-            let monthly_usage = monthly_usage::Entity::find()
-                .filter(
-                    monthly_usage::Column::UserId
-                        .eq(user_id)
-                        .and(monthly_usage::Column::ModelId.eq(model.id))
-                        .and(monthly_usage::Column::Month.eq(month))
-                        .and(monthly_usage::Column::Year.eq(year)),
-                )
-                .one(&*tx)
-                .await?;
-
-            let monthly_usage = match monthly_usage {
-                Some(usage) => {
-                    monthly_usage::Entity::update(monthly_usage::ActiveModel {
-                        id: ActiveValue::unchanged(usage.id),
-                        input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
-                        cache_creation_input_tokens: ActiveValue::set(
-                            usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
-                        ),
-                        cache_read_input_tokens: ActiveValue::set(
-                            usage.cache_read_input_tokens + tokens.input_cache_read as i64,
-                        ),
-                        output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
-                        ..Default::default()
-                    })
-                    .exec(&*tx)
-                    .await?
-                }
-                None => {
-                    monthly_usage::ActiveModel {
-                        user_id: ActiveValue::set(user_id),
-                        model_id: ActiveValue::set(model.id),
-                        month: ActiveValue::set(month),
-                        year: ActiveValue::set(year),
-                        input_tokens: ActiveValue::set(tokens.input as i64),
-                        cache_creation_input_tokens: ActiveValue::set(
-                            tokens.input_cache_creation as i64,
-                        ),
-                        cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
-                        output_tokens: ActiveValue::set(tokens.output as i64),
-                        ..Default::default()
-                    }
-                    .insert(&*tx)
-                    .await?
-                }
-            };
-
-            let spending_this_month = calculate_spending(
-                model,
-                monthly_usage.input_tokens as usize,
-                monthly_usage.cache_creation_input_tokens as usize,
-                monthly_usage.cache_read_input_tokens as usize,
-                monthly_usage.output_tokens as usize,
-            );
-
-            if !is_staff
-                && spending_this_month > free_tier_monthly_spending_limit
-                && has_llm_subscription
-                && (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
-            {
-                billing_event::ActiveModel {
-                    id: ActiveValue::not_set(),
-                    idempotency_key: ActiveValue::not_set(),
-                    user_id: ActiveValue::set(user_id),
-                    model_id: ActiveValue::set(model.id),
-                    input_tokens: ActiveValue::set(tokens.input as i64),
-                    input_cache_creation_tokens: ActiveValue::set(
-                        tokens.input_cache_creation as i64,
-                    ),
-                    input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
-                    output_tokens: ActiveValue::set(tokens.output as i64),
-                }
-                .insert(&*tx)
-                .await?;
-            }
-
-            // Update lifetime usage
-            let lifetime_usage = lifetime_usage::Entity::find()
-                .filter(
-                    lifetime_usage::Column::UserId
-                        .eq(user_id)
-                        .and(lifetime_usage::Column::ModelId.eq(model.id)),
-                )
-                .one(&*tx)
-                .await?;
-
-            let lifetime_usage = match lifetime_usage {
-                Some(usage) => {
-                    lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
-                        id: ActiveValue::unchanged(usage.id),
-                        input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
-                        cache_creation_input_tokens: ActiveValue::set(
-                            usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
-                        ),
-                        cache_read_input_tokens: ActiveValue::set(
-                            usage.cache_read_input_tokens + tokens.input_cache_read as i64,
-                        ),
-                        output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
-                        ..Default::default()
-                    })
-                    .exec(&*tx)
-                    .await?
-                }
-                None => {
-                    lifetime_usage::ActiveModel {
-                        user_id: ActiveValue::set(user_id),
-                        model_id: ActiveValue::set(model.id),
-                        input_tokens: ActiveValue::set(tokens.input as i64),
-                        cache_creation_input_tokens: ActiveValue::set(
-                            tokens.input_cache_creation as i64,
-                        ),
-                        cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
-                        output_tokens: ActiveValue::set(tokens.output as i64),
-                        ..Default::default()
-                    }
-                    .insert(&*tx)
-                    .await?
-                }
-            };
-
-            let lifetime_spending = calculate_spending(
-                model,
-                lifetime_usage.input_tokens as usize,
-                lifetime_usage.cache_creation_input_tokens as usize,
-                lifetime_usage.cache_read_input_tokens as usize,
-                lifetime_usage.output_tokens as usize,
-            );
-
-            Ok(Usage {
-                requests_this_minute,
-                tokens_this_minute,
-                input_tokens_this_minute,
-                output_tokens_this_minute,
-                tokens_this_day,
-                tokens_this_month: TokenUsage {
-                    input: monthly_usage.input_tokens as usize,
-                    input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
-                    input_cache_read: monthly_usage.cache_read_input_tokens as usize,
-                    output: monthly_usage.output_tokens as usize,
-                },
-                spending_this_month,
-                lifetime_spending,
-            })
-        })
-        .await
-    }
-
-    /// Returns the active user count for the specified model.
-    pub async fn get_active_user_count(
-        &self,
-        provider: LanguageModelProvider,
-        model_name: &str,
-        now: DateTimeUtc,
-    ) -> Result<ActiveUserCount> {
-        self.transaction(|tx| async move {
-            let minute_since = now - Duration::minutes(5);
-            let day_since = now - Duration::days(5);
-
-            let model = self
-                .models
-                .get(&(provider, model_name.to_string()))
-                .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
-
-            let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
-
-            let users_in_recent_minutes = usage::Entity::find()
-                .filter(
-                    usage::Column::ModelId
-                        .eq(model.id)
-                        .and(usage::Column::MeasureId.eq(tokens_per_minute))
-                        .and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
-                        .and(usage::Column::IsStaff.eq(false)),
-                )
-                .select_only()
-                .column(usage::Column::UserId)
-                .group_by(usage::Column::UserId)
-                .count(&*tx)
-                .await? as usize;
-
-            let users_in_recent_days = usage::Entity::find()
-                .filter(
-                    usage::Column::ModelId
-                        .eq(model.id)
-                        .and(usage::Column::MeasureId.eq(tokens_per_minute))
-                        .and(usage::Column::Timestamp.gte(day_since.naive_utc()))
-                        .and(usage::Column::IsStaff.eq(false)),
-                )
-                .select_only()
-                .column(usage::Column::UserId)
-                .group_by(usage::Column::UserId)
-                .count(&*tx)
-                .await? as usize;
-
-            Ok(ActiveUserCount {
-                users_in_recent_minutes,
-                users_in_recent_days,
-            })
-        })
-        .await
-    }
-
-    async fn update_usage_for_measure(
-        &self,
-        user_id: UserId,
-        is_staff: bool,
-        model_id: ModelId,
-        usages: &[usage::Model],
-        usage_measure: UsageMeasure,
-        now: DateTimeUtc,
-        usage_to_add: usize,
-        tx: &DatabaseTransaction,
-    ) -> Result<usize> {
-        let now = now.naive_utc();
-        let measure_id = *self
-            .usage_measure_ids
-            .get(&usage_measure)
-            .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
-
-        let mut id = None;
-        let mut timestamp = now;
-        let mut buckets = vec![0_i64];
-
-        if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
-            id = Some(old_usage.id);
-            let (live_buckets, buckets_since) =
-                Self::get_live_buckets(old_usage, now, usage_measure);
-            if !live_buckets.is_empty() {
-                buckets.clear();
-                buckets.extend_from_slice(live_buckets);
-                buckets.extend(iter::repeat(0).take(buckets_since));
-                timestamp =
-                    old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
-            }
-        }
-
-        *buckets.last_mut().unwrap() += usage_to_add as i64;
-        let total_usage = buckets.iter().sum::<i64>() as usize;
-
-        let mut model = usage::ActiveModel {
-            user_id: ActiveValue::set(user_id),
-            is_staff: ActiveValue::set(is_staff),
-            model_id: ActiveValue::set(model_id),
-            measure_id: ActiveValue::set(measure_id),
-            timestamp: ActiveValue::set(timestamp),
-            buckets: ActiveValue::set(buckets),
-            ..Default::default()
-        };
-
-        if let Some(id) = id {
-            model.id = ActiveValue::unchanged(id);
-            model.update(tx).await?;
-        } else {
-            usage::Entity::insert(model)
-                .exec_without_returning(tx)
-                .await?;
-        }
-
-        Ok(total_usage)
-    }
-
-    fn get_usage_for_measure(
-        &self,
-        usages: &[usage::Model],
-        now: DateTimeUtc,
-        usage_measure: UsageMeasure,
-    ) -> Result<usize> {
-        let now = now.naive_utc();
-        let measure_id = *self
-            .usage_measure_ids
-            .get(&usage_measure)
-            .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
-        let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
-            return Ok(0);
-        };
-
-        let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
-        Ok(live_buckets.iter().sum::<i64>() as _)
-    }
-
-    fn get_live_buckets(
-        usage: &usage::Model,
-        now: chrono::NaiveDateTime,
-        measure: UsageMeasure,
-    ) -> (&[i64], usize) {
-        let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
-        let buckets_since_usage =
-            seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
-        let buckets_since_usage = buckets_since_usage.ceil() as usize;
-        let mut live_buckets = &[] as &[i64];
-        if buckets_since_usage < measure.bucket_count() {
-            let expired_bucket_count =
-                (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
-            live_buckets = &usage.buckets[expired_bucket_count..];
-            while live_buckets.first() == Some(&0) {
-                live_buckets = &live_buckets[1..];
-            }
-        }
-        (live_buckets, buckets_since_usage)
-    }
 }
 
 fn calculate_spending(
@@ -741,32 +110,3 @@ fn calculate_spending(
         + output_token_cost;
     Cents::new(spending as u32)
 }
-
-const MINUTE_BUCKET_COUNT: usize = 12;
-const DAY_BUCKET_COUNT: usize = 48;
-
-impl UsageMeasure {
-    fn bucket_count(&self) -> usize {
-        match self {
-            UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
-            UsageMeasure::TokensPerMinute
-            | UsageMeasure::InputTokensPerMinute
-            | UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
-            UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
-        }
-    }
-
-    fn total_duration(&self) -> Duration {
-        match self {
-            UsageMeasure::RequestsPerMinute => Duration::minutes(1),
-            UsageMeasure::TokensPerMinute
-            | UsageMeasure::InputTokensPerMinute
-            | UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
-            UsageMeasure::TokensPerDay => Duration::hours(24),
-        }
-    }
-
-    fn bucket_duration(&self) -> Duration {
-        self.total_duration() / self.bucket_count() as i32
-    }
-}

crates/collab/src/llm/db/tables.rs 🔗

@@ -1,8 +1,6 @@
 pub mod billing_event;
-pub mod lifetime_usage;
 pub mod model;
 pub mod monthly_usage;
 pub mod provider;
-pub mod revoked_access_token;
 pub mod usage;
 pub mod usage_measure;

crates/collab/src/llm/db/tables/lifetime_usage.rs 🔗

@@ -1,20 +0,0 @@
-use crate::{db::UserId, llm::db::ModelId};
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "lifetime_usages")]
-pub struct Model {
-    #[sea_orm(primary_key)]
-    pub id: i32,
-    pub user_id: UserId,
-    pub model_id: ModelId,
-    pub input_tokens: i64,
-    pub cache_creation_input_tokens: i64,
-    pub cache_read_input_tokens: i64,
-    pub output_tokens: i64,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/revoked_access_token.rs 🔗

@@ -1,19 +0,0 @@
-use chrono::NaiveDateTime;
-use sea_orm::entity::prelude::*;
-
-use crate::llm::db::RevokedAccessTokenId;
-
-/// A revoked access token.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "revoked_access_tokens")]
-pub struct Model {
-    #[sea_orm(primary_key)]
-    pub id: RevokedAccessTokenId,
-    pub jti: String,
-    pub revoked_at: NaiveDateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tests/billing_tests.rs 🔗

@@ -1,152 +0,0 @@
-use crate::{
-    Cents,
-    db::UserId,
-    llm::{
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
-    },
-    test_llm_db,
-};
-use chrono::{DateTime, Utc};
-use pretty_assertions::assert_eq;
-use rpc::LanguageModelProvider;
-
-test_llm_db!(
-    test_billing_limit_exceeded,
-    test_billing_limit_exceeded_postgres
-);
-
-async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
-    let provider = LanguageModelProvider::Anthropic;
-    let model = "fake-claude-limerick";
-    const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
-    const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
-
-    // Initialize the database and insert the model
-    db.initialize().await.unwrap();
-    db.insert_models(&[ModelParams {
-        provider,
-        name: model.to_string(),
-        max_requests_per_minute: 5,
-        max_tokens_per_minute: 10_000,
-        max_tokens_per_day: 50_000,
-        price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
-        price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
-    }])
-    .await
-    .unwrap();
-
-    // Set a fixed datetime for consistent testing
-    let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
-        .unwrap()
-        .with_timezone(&Utc);
-
-    let user_id = UserId::from_proto(123);
-
-    let max_monthly_spend = Cents::from_dollars(11);
-
-    // Record usage that brings us close to the limit but doesn't exceed it
-    // Let's say we use $10.50 worth of tokens
-    let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
-    let usage = TokenUsage {
-        input: tokens_to_use,
-        input_cache_creation: 0,
-        input_cache_read: 0,
-        output: 0,
-    };
-
-    // Verify that before we record any usage, there are 0 billing events
-    let billing_events = db.get_billing_events().await.unwrap();
-    assert_eq!(billing_events.len(), 0);
-
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        usage,
-        true,
-        max_monthly_spend,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    // Verify the recorded usage and spending
-    let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    // Verify that we exceeded the free tier usage
-    assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
-    assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
-
-    // Verify that there is one `billing_event` record
-    let billing_events = db.get_billing_events().await.unwrap();
-    assert_eq!(billing_events.len(), 1);
-
-    let (billing_event, _model) = &billing_events[0];
-    assert_eq!(billing_event.user_id, user_id);
-    assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
-    assert_eq!(billing_event.input_cache_creation_tokens, 0);
-    assert_eq!(billing_event.input_cache_read_tokens, 0);
-    assert_eq!(billing_event.output_tokens, 0);
-
-    // Record usage that puts us at $20.50
-    let usage_2 = TokenUsage {
-        input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
-        input_cache_creation: 0,
-        input_cache_read: 0,
-        output: 0,
-    };
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        usage_2,
-        true,
-        max_monthly_spend,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    // Verify the updated usage and spending
-    let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
-
-    // Verify that there are now two billing events
-    let billing_events = db.get_billing_events().await.unwrap();
-    assert_eq!(billing_events.len(), 2);
-
-    let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
-    let usage_exceeding = TokenUsage {
-        input: tokens_to_exceed,
-        input_cache_creation: 0,
-        input_cache_read: 0,
-        output: 0,
-    };
-
-    // This should still create a billing event as it's the first request that exceeds the limit
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        usage_exceeding,
-        true,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        max_monthly_spend,
-        now,
-    )
-    .await
-    .unwrap();
-    // Verify the updated usage and spending
-    let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
-
-    // Verify that we never exceed the user max spending for the user
-    // and avoid charging them.
-    let billing_events = db.get_billing_events().await.unwrap();
-    assert_eq!(billing_events.len(), 2);
-}

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -1,306 +0,0 @@
-use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
-use crate::{
-    Cents,
-    db::UserId,
-    llm::db::{
-        LlmDatabase, TokenUsage,
-        queries::{providers::ModelParams, usages::Usage},
-    },
-    test_llm_db,
-};
-use chrono::{DateTime, Duration, Utc};
-use pretty_assertions::assert_eq;
-use rpc::LanguageModelProvider;
-
-test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
-
-async fn test_tracking_usage(db: &mut LlmDatabase) {
-    let provider = LanguageModelProvider::Anthropic;
-    let model = "claude-3-5-sonnet";
-
-    db.initialize().await.unwrap();
-    db.insert_models(&[ModelParams {
-        provider,
-        name: model.to_string(),
-        max_requests_per_minute: 5,
-        max_tokens_per_minute: 10_000,
-        max_tokens_per_day: 50_000,
-        price_per_million_input_tokens: 50,
-        price_per_million_output_tokens: 50,
-    }])
-    .await
-    .unwrap();
-
-    // We're using a fixed datetime to prevent flakiness based on the clock.
-    let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
-        .unwrap()
-        .with_timezone(&Utc);
-    let user_id = UserId::from_proto(123);
-
-    let now = t0;
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 1000,
-            input_cache_creation: 0,
-            input_cache_read: 0,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let now = t0 + Duration::seconds(10);
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 2000,
-            input_cache_creation: 0,
-            input_cache_read: 0,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 2,
-            tokens_this_minute: 3000,
-            input_tokens_this_minute: 3000,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 3000,
-            tokens_this_month: TokenUsage {
-                input: 3000,
-                input_cache_creation: 0,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    let now = t0 + Duration::seconds(60);
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 1,
-            tokens_this_minute: 2000,
-            input_tokens_this_minute: 2000,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 3000,
-            tokens_this_month: TokenUsage {
-                input: 3000,
-                input_cache_creation: 0,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    let now = t0 + Duration::seconds(60);
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 3000,
-            input_cache_creation: 0,
-            input_cache_read: 0,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 2,
-            tokens_this_minute: 5000,
-            input_tokens_this_minute: 5000,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 6000,
-            tokens_this_month: TokenUsage {
-                input: 6000,
-                input_cache_creation: 0,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    let t1 = t0 + Duration::hours(24);
-    let now = t1;
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 0,
-            tokens_this_minute: 0,
-            input_tokens_this_minute: 0,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 5000,
-            tokens_this_month: TokenUsage {
-                input: 6000,
-                input_cache_creation: 0,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 4000,
-            input_cache_creation: 0,
-            input_cache_read: 0,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 1,
-            tokens_this_minute: 4000,
-            input_tokens_this_minute: 4000,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 9000,
-            tokens_this_month: TokenUsage {
-                input: 10000,
-                input_cache_creation: 0,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    // We're using a fixed datetime to prevent flakiness based on the clock.
-    let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
-        .unwrap()
-        .with_timezone(&Utc);
-
-    // Test cache creation input tokens
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 1000,
-            input_cache_creation: 500,
-            input_cache_read: 0,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 1,
-            tokens_this_minute: 1500,
-            input_tokens_this_minute: 1500,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 1500,
-            tokens_this_month: TokenUsage {
-                input: 1000,
-                input_cache_creation: 500,
-                input_cache_read: 0,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-
-    // Test cache read input tokens
-    db.record_usage(
-        user_id,
-        false,
-        provider,
-        model,
-        TokenUsage {
-            input: 1000,
-            input_cache_creation: 0,
-            input_cache_read: 300,
-            output: 0,
-        },
-        false,
-        Cents::ZERO,
-        FREE_TIER_MONTHLY_SPENDING_LIMIT,
-        now,
-    )
-    .await
-    .unwrap();
-
-    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
-    assert_eq!(
-        usage,
-        Usage {
-            requests_this_minute: 2,
-            tokens_this_minute: 2800,
-            input_tokens_this_minute: 2500,
-            output_tokens_this_minute: 0,
-            tokens_this_day: 2800,
-            tokens_this_month: TokenUsage {
-                input: 2000,
-                input_cache_creation: 500,
-                input_cache_read: 300,
-                output: 0,
-            },
-            spending_this_month: Cents::ZERO,
-            lifetime_spending: Cents::ZERO,
-        }
-    );
-}

crates/collab/src/main.rs 🔗

@@ -9,14 +9,14 @@ use axum::{
 
 use collab::api::CloudflareIpCountryHeader;
 use collab::api::billing::sync_llm_usage_with_stripe_periodically;
-use collab::llm::{db::LlmDatabase, log_usage_periodically};
+use collab::llm::db::LlmDatabase;
 use collab::migrations::run_database_migrations;
 use collab::user_backfiller::spawn_user_backfiller;
 use collab::{
     AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
     env, executor::Executor, rpc::ResultExt,
 };
-use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
+use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
 use db::Database;
 use std::{
     env::args,
@@ -74,11 +74,10 @@ async fn main() -> Result<()> {
             let mode = match args.next().as_deref() {
                 Some("collab") => ServiceMode::Collab,
                 Some("api") => ServiceMode::Api,
-                Some("llm") => ServiceMode::Llm,
                 Some("all") => ServiceMode::All,
                 _ => {
                     return Err(anyhow!(
-                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+                        "usage: collab <version | migrate | seed | serve <api|collab|all>>"
                     ))?;
                 }
             };
@@ -97,20 +96,9 @@ async fn main() -> Result<()> {
 
             let mut on_shutdown = None;
 
-            if mode.is_llm() {
-                setup_llm_database(&config).await?;
-
-                let state = LlmState::new(config.clone(), Executor::Production).await?;
-
-                log_usage_periodically(state.clone());
-
-                app = app
-                    .merge(collab::llm::routes())
-                    .layer(Extension(state.clone()));
-            }
-
             if mode.is_collab() || mode.is_api() {
                 setup_app_database(&config).await?;
+                setup_llm_database(&config).await?;
 
                 let state = AppState::new(config, Executor::Production).await?;
 
@@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
     format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
 }
 
-async fn handle_liveness_probe(
-    app_state: Option<Extension<Arc<AppState>>>,
-    llm_state: Option<Extension<Arc<LlmState>>>,
-) -> Result<String> {
+async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
     if let Some(state) = app_state {
         state.db.get_all_users(0, 1).await?;
     }
 
-    if let Some(llm_state) = llm_state {
-        llm_state.db.list_providers().await?;
-    }
-
     Ok("ok".to_string())
 }