Add logic for closed beta LLM models (#16482)

Created by Max Brunsfeld and Marshall.

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/k8s/collab.template.yml       |  5 +
crates/collab/src/lib.rs                    |  2 
crates/collab/src/llm/authorization.rs      | 27 ++++-
crates/collab/src/llm/db/queries/usages.rs  |  7 
crates/collab/src/llm/token.rs              |  4 
crates/collab/src/rpc.rs                    |  6 +
crates/collab/src/tests/test_server.rs      |  1 
crates/feature_flags/src/feature_flags.rs   |  5 +
crates/language_model/src/provider/cloud.rs | 94 ++++++++++++++--------
9 files changed, 104 insertions(+), 47 deletions(-)

Detailed changes

crates/collab/k8s/collab.template.yml 🔗

@@ -139,6 +139,11 @@ spec:
                 secretKeyRef:
                   name: anthropic
                   key: staff_api_key
+            - name: LLM_CLOSED_BETA_MODEL_NAME
+              valueFrom:
+                secretKeyRef:
+                  name: llm-closed-beta
+                  key: model_name
             - name: GOOGLE_AI_API_KEY
               valueFrom:
                 secretKeyRef:

crates/collab/src/lib.rs 🔗

@@ -168,6 +168,7 @@ pub struct Config {
     pub google_ai_api_key: Option<Arc<str>>,
     pub anthropic_api_key: Option<Arc<str>>,
     pub anthropic_staff_api_key: Option<Arc<str>>,
+    pub llm_closed_beta_model_name: Option<Arc<str>>,
     pub qwen2_7b_api_key: Option<Arc<str>>,
     pub qwen2_7b_api_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
@@ -219,6 +220,7 @@ impl Config {
             google_ai_api_key: None,
             anthropic_api_key: None,
             anthropic_staff_api_key: None,
+            llm_closed_beta_model_name: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,

crates/collab/src/llm/authorization.rs 🔗

@@ -12,11 +12,12 @@ pub fn authorize_access_to_language_model(
     model: &str,
 ) -> Result<()> {
     authorize_access_for_country(config, country_code, provider)?;
-    authorize_access_to_model(claims, provider, model)?;
+    authorize_access_to_model(config, claims, provider, model)?;
     Ok(())
 }
 
 fn authorize_access_to_model(
+    config: &Config,
     claims: &LlmTokenClaims,
     provider: LanguageModelProvider,
     model: &str,
@@ -25,13 +26,25 @@ fn authorize_access_to_model(
         return Ok(());
     }
 
-    match (provider, model) {
-        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
-        _ => Err(Error::http(
-            StatusCode::FORBIDDEN,
-            format!("access to model {model:?} is not included in your plan"),
-        ))?,
+    match provider {
+        LanguageModelProvider::Anthropic => {
+            if model == "claude-3-5-sonnet" {
+                return Ok(());
+            }
+
+            if claims.has_llm_closed_beta_feature_flag
+                && Some(model) == config.llm_closed_beta_model_name.as_deref()
+            {
+                return Ok(());
+            }
+        }
+        _ => {}
     }
+
+    Err(Error::http(
+        StatusCode::FORBIDDEN,
+        format!("access to model {model:?} is not included in your plan"),
+    ))
 }
 
 fn authorize_access_for_country(

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -82,12 +82,13 @@ impl LlmDatabase {
             let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
 
             let mut results = Vec::new();
-            for (provider, model) in self.models.keys().cloned() {
+            for ((provider, model_name), model) in self.models.iter() {
                 let mut usages = usage::Entity::find()
                     .filter(
                         usage::Column::Timestamp
                             .gte(past_minute.naive_utc())
                             .and(usage::Column::IsStaff.eq(false))
+                            .and(usage::Column::ModelId.eq(model.id))
                             .and(
                                 usage::Column::MeasureId
                                     .eq(requests_per_minute)
@@ -125,8 +126,8 @@ impl LlmDatabase {
                 }
 
                 results.push(ApplicationWideUsage {
-                    provider,
-                    model,
+                    provider: *provider,
+                    model: model_name.clone(),
                     requests_this_minute,
                     tokens_this_minute,
                 })

crates/collab/src/llm/token.rs 🔗

@@ -20,6 +20,8 @@ pub struct LlmTokenClaims {
     #[serde(default)]
     pub github_user_login: Option<String>,
     pub is_staff: bool,
+    #[serde(default)]
+    pub has_llm_closed_beta_feature_flag: bool,
     pub plan: rpc::proto::Plan,
 }
 
@@ -30,6 +32,7 @@ impl LlmTokenClaims {
         user_id: UserId,
         github_user_login: String,
         is_staff: bool,
+        has_llm_closed_beta_feature_flag: bool,
         plan: rpc::proto::Plan,
         config: &Config,
     ) -> Result<String> {
@@ -46,6 +49,7 @@ impl LlmTokenClaims {
             user_id: user_id.to_proto(),
             github_user_login: Some(github_user_login),
             is_staff,
+            has_llm_closed_beta_feature_flag,
             plan,
         };
 

crates/collab/src/rpc.rs 🔗

@@ -4918,7 +4918,10 @@ async fn get_llm_api_token(
     let db = session.db().await;
 
     let flags = db.get_user_flags(session.user_id()).await?;
-    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
+    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
+    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
+
+    if !session.is_staff() && !has_language_models_feature_flag {
         Err(anyhow!("permission denied"))?
     }
 
@@ -4943,6 +4946,7 @@ async fn get_llm_api_token(
         user.id,
         user.github_login.clone(),
         session.is_staff(),
+        has_llm_closed_beta_feature_flag,
         session.current_plan(db).await?,
         &session.app_state.config,
     )?;

crates/collab/src/tests/test_server.rs 🔗

@@ -667,6 +667,7 @@ impl TestServer {
                 google_ai_api_key: None,
                 anthropic_api_key: None,
                 anthropic_staff_api_key: None,
+                llm_closed_beta_model_name: None,
                 clickhouse_url: None,
                 clickhouse_user: None,
                 clickhouse_password: None,

crates/feature_flags/src/feature_flags.rs 🔗

@@ -43,6 +43,11 @@ impl FeatureFlag for LanguageModels {
     const NAME: &'static str = "language-models";
 }
 
+pub struct LlmClosedBeta {}
+impl FeatureFlag for LlmClosedBeta {
+    const NAME: &'static str = "llm-closed-beta";
+}
+
 pub struct ZedPro {}
 impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";

crates/language_model/src/provider/cloud.rs 🔗

@@ -8,7 +8,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, ZedPro};
+use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
 use futures::{
     future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
     TryStreamExt as _,
@@ -26,7 +26,10 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
-use std::{future, sync::Arc};
+use std::{
+    future,
+    sync::{Arc, LazyLock},
+};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
@@ -37,6 +40,18 @@ use super::anthropic::count_anthropic_tokens;
 pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "Zed";
 
+const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
+    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
+
+fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
+    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
+        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
+            .map(|json| serde_json::from_str(json).unwrap())
+            .unwrap_or(Vec::new())
+    });
+    ADDITIONAL_MODELS.as_slice()
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
     pub available_models: Vec<AvailableModel>,
@@ -200,40 +215,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             for model in ZedModel::iter() {
                 models.insert(model.id().to_string(), CloudModel::Zed(model));
             }
-
-            // Override with available models from settings
-            for model in &AllLanguageModelSettings::get_global(cx)
-                .zed_dot_dev
-                .available_models
-            {
-                let model = match model.provider {
-                    AvailableProvider::Anthropic => {
-                        CloudModel::Anthropic(anthropic::Model::Custom {
-                            name: model.name.clone(),
-                            display_name: model.display_name.clone(),
-                            max_tokens: model.max_tokens,
-                            tool_override: model.tool_override.clone(),
-                            cache_configuration: model.cache_configuration.as_ref().map(|config| {
-                                anthropic::AnthropicModelCacheConfiguration {
-                                    max_cache_anchors: config.max_cache_anchors,
-                                    should_speculate: config.should_speculate,
-                                    min_total_token: config.min_total_token,
-                                }
-                            }),
-                            max_output_tokens: model.max_output_tokens,
-                        })
-                    }
-                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                };
-                models.insert(model.id().to_string(), model.clone());
-            }
         } else {
             models.insert(
                 anthropic::Model::Claude3_5Sonnet.id().to_string(),
@@ -241,6 +222,47 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             );
         }
 
+        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
+            zed_cloud_provider_additional_models()
+        } else {
+            &[]
+        };
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .zed_dot_dev
+            .available_models
+            .iter()
+            .chain(llm_closed_beta_models)
+            .cloned()
+        {
+            let model = match model.provider {
+                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
+                    max_output_tokens: model.max_output_tokens,
+                }),
+                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+            };
+            models.insert(model.id().to_string(), model.clone());
+        }
+
         models
             .into_values()
             .map(|model| {