collab: Use a separate Anthropic API key for Zed staff (#16128)

Created by Marshall Bowers and Max

This PR makes it so Zed staff can use a separate Anthropic API key for
the LLM service.

We also added an `is_staff` column to the `usages` table so that we can
exclude staff usage from the "active users" metrics that influence the
rate limits.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/k8s/collab.template.yml                                  |  5 
crates/collab/migrations_llm/20240812184444_add_is_staff_to_usages.sql |  1 
crates/collab/src/lib.rs                                               |  2 
crates/collab/src/llm.rs                                               | 19 
crates/collab/src/llm/db/queries/usages.rs                             | 21 
crates/collab/src/llm/db/tables/usage.rs                               |  1 
crates/collab/src/llm/db/tests/usage_tests.rs                          |  8 
crates/collab/src/tests/test_server.rs                                 |  1 
8 files changed, 47 insertions(+), 11 deletions(-)

Detailed changes

crates/collab/k8s/collab.template.yml 🔗

@@ -134,6 +134,11 @@ spec:
                 secretKeyRef:
                   name: anthropic
                   key: api_key
+            - name: ANTHROPIC_STAFF_API_KEY
+              valueFrom:
+                secretKeyRef:
+                  name: anthropic
+                  key: staff_api_key
             - name: GOOGLE_AI_API_KEY
               valueFrom:
                 secretKeyRef:

crates/collab/src/lib.rs 🔗

@@ -166,6 +166,7 @@ pub struct Config {
     pub openai_api_key: Option<Arc<str>>,
     pub google_ai_api_key: Option<Arc<str>>,
     pub anthropic_api_key: Option<Arc<str>>,
+    pub anthropic_staff_api_key: Option<Arc<str>>,
     pub qwen2_7b_api_key: Option<Arc<str>>,
     pub qwen2_7b_api_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
@@ -216,6 +217,7 @@ impl Config {
             openai_api_key: None,
             google_ai_api_key: None,
             anthropic_api_key: None,
+            anthropic_staff_api_key: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,

crates/collab/src/llm.rs 🔗

@@ -171,11 +171,19 @@ async fn perform_completion(
 
     let stream = match params.provider {
         LanguageModelProvider::Anthropic => {
-            let api_key = state
-                .config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
+            let api_key = if claims.is_staff {
+                state
+                    .config
+                    .anthropic_staff_api_key
+                    .as_ref()
+                    .context("no Anthropic AI staff API key configured on the server")?
+            } else {
+                state
+                    .config
+                    .anthropic_api_key
+                    .as_ref()
+                    .context("no Anthropic AI API key configured on the server")?
+            };
 
             let mut request: anthropic::Request =
                 serde_json::from_str(&params.provider_request.get())?;
@@ -473,6 +481,7 @@ impl<S> Drop for TokenCountingStream<S> {
                 .db
                 .record_usage(
                     claims.user_id as i32,
+                    claims.is_staff,
                     provider,
                     &model,
                     input_token_count,

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -108,9 +108,11 @@ impl LlmDatabase {
         .await
     }
 
+    #[allow(clippy::too_many_arguments)]
     pub async fn record_usage(
         &self,
         user_id: i32,
+        is_staff: bool,
         provider: LanguageModelProvider,
         model_name: &str,
         input_token_count: usize,
@@ -132,6 +134,7 @@ impl LlmDatabase {
             let requests_this_minute = self
                 .update_usage_for_measure(
                     user_id,
+                    is_staff,
                     model.id,
                     &usages,
                     UsageMeasure::RequestsPerMinute,
@@ -143,6 +146,7 @@ impl LlmDatabase {
             let tokens_this_minute = self
                 .update_usage_for_measure(
                     user_id,
+                    is_staff,
                     model.id,
                     &usages,
                     UsageMeasure::TokensPerMinute,
@@ -154,6 +158,7 @@ impl LlmDatabase {
             let tokens_this_day = self
                 .update_usage_for_measure(
                     user_id,
+                    is_staff,
                     model.id,
                     &usages,
                     UsageMeasure::TokensPerDay,
@@ -165,6 +170,7 @@ impl LlmDatabase {
             let input_tokens_this_month = self
                 .update_usage_for_measure(
                     user_id,
+                    is_staff,
                     model.id,
                     &usages,
                     UsageMeasure::InputTokensPerMonth,
@@ -176,6 +182,7 @@ impl LlmDatabase {
             let output_tokens_this_month = self
                 .update_usage_for_measure(
                     user_id,
+                    is_staff,
                     model.id,
                     &usages,
                     UsageMeasure::OutputTokensPerMonth,
@@ -205,7 +212,11 @@ impl LlmDatabase {
             let day_since = now - Duration::days(5);
 
             let users_in_recent_minutes = usage::Entity::find()
-                .filter(usage::Column::Timestamp.gte(minute_since.naive_utc()))
+                .filter(
+                    usage::Column::Timestamp
+                        .gte(minute_since.naive_utc())
+                        .and(usage::Column::IsStaff.eq(false)),
+                )
                 .select_only()
                 .column(usage::Column::UserId)
                 .group_by(usage::Column::UserId)
@@ -213,7 +224,11 @@ impl LlmDatabase {
                 .await? as usize;
 
             let users_in_recent_days = usage::Entity::find()
-                .filter(usage::Column::Timestamp.gte(day_since.naive_utc()))
+                .filter(
+                    usage::Column::Timestamp
+                        .gte(day_since.naive_utc())
+                        .and(usage::Column::IsStaff.eq(false)),
+                )
                 .select_only()
                 .column(usage::Column::UserId)
                 .group_by(usage::Column::UserId)
@@ -232,6 +247,7 @@ impl LlmDatabase {
     async fn update_usage_for_measure(
         &self,
         user_id: i32,
+        is_staff: bool,
         model_id: ModelId,
         usages: &[usage::Model],
         usage_measure: UsageMeasure,
@@ -267,6 +283,7 @@ impl LlmDatabase {
 
         let mut model = usage::ActiveModel {
             user_id: ActiveValue::set(user_id),
+            is_staff: ActiveValue::set(is_staff),
             model_id: ActiveValue::set(model_id),
             measure_id: ActiveValue::set(measure_id),
             timestamp: ActiveValue::set(timestamp),

crates/collab/src/llm/db/tables/usage.rs 🔗

@@ -15,6 +15,7 @@ pub struct Model {
     pub measure_id: UsageMeasureId,
     pub timestamp: DateTime,
     pub buckets: Vec<i64>,
+    pub is_staff: bool,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -29,12 +29,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     let user_id = 123;
 
     let now = t0;
-    db.record_usage(user_id, provider, model, 1000, 0, now)
+    db.record_usage(user_id, false, provider, model, 1000, 0, now)
         .await
         .unwrap();
 
     let now = t0 + Duration::seconds(10);
-    db.record_usage(user_id, provider, model, 2000, 0, now)
+    db.record_usage(user_id, false, provider, model, 2000, 0, now)
         .await
         .unwrap();
 
@@ -66,7 +66,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
     );
 
     let now = t0 + Duration::seconds(60);
-    db.record_usage(user_id, provider, model, 3000, 0, now)
+    db.record_usage(user_id, false, provider, model, 3000, 0, now)
         .await
         .unwrap();
 
@@ -98,7 +98,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
         }
     );
 
-    db.record_usage(user_id, provider, model, 4000, 0, now)
+    db.record_usage(user_id, false, provider, model, 4000, 0, now)
         .await
         .unwrap();
 

crates/collab/src/tests/test_server.rs 🔗

@@ -666,6 +666,7 @@ impl TestServer {
                 openai_api_key: None,
                 google_ai_api_key: None,
                 anthropic_api_key: None,
+                anthropic_staff_api_key: None,
                 clickhouse_url: None,
                 clickhouse_user: None,
                 clickhouse_password: None,