assistant: Limit model access for Zed AI users to Claude-3.5-sonnet (#15904)

Bennet Bo Fenner and Thorsten created

This prevents users from accessing other models, such as OpenAI's GPT-4
or Google's Gemini-Pro.
Staff members can still access all models.

Co-authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

crates/collab/src/llm/authorization.rs | 125 +++++++++++++++++++++++++++
crates/collab/src/llm/token.rs         |   9 +
crates/collab/src/rpc.rs               |   1 
3 files changed, 130 insertions(+), 5 deletions(-)

Detailed changes

crates/collab/src/llm/authorization.rs 🔗

@@ -6,21 +6,40 @@ use crate::{Config, Error, Result};
 
 pub fn authorize_access_to_language_model(
     config: &Config,
-    _claims: &LlmTokenClaims,
+    claims: &LlmTokenClaims,
     country_code: Option<String>,
     provider: LanguageModelProvider,
     model: &str,
 ) -> Result<()> {
-    authorize_access_for_country(config, country_code, provider, model)?;
-
+    authorize_access_for_country(config, country_code, provider)?;
+    authorize_access_to_model(claims, provider, model)?;
     Ok(())
 }
 
+fn authorize_access_to_model(
+    claims: &LlmTokenClaims,
+    provider: LanguageModelProvider,
+    model: &str,
+) -> Result<()> {
+    if claims.is_staff {
+        return Ok(());
+    }
+
+    match (provider, model) {
+        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
+            Ok(())
+        }
+        _ => Err(Error::http(
+            StatusCode::FORBIDDEN,
+            format!("access to model {model:?} is not included in your plan"),
+        ))?,
+    }
+}
+
 fn authorize_access_for_country(
     config: &Config,
     country_code: Option<String>,
     provider: LanguageModelProvider,
-    _model: &str,
 ) -> Result<()> {
     // In development we won't have the `CF-IPCountry` header, so we can't check
     // the country code.
@@ -79,6 +98,7 @@ mod tests {
         let claims = LlmTokenClaims {
             user_id: 99,
             plan: Plan::ZedPro,
+            is_staff: true,
             ..Default::default()
         };
 
@@ -210,4 +230,101 @@ mod tests {
             );
         }
     }
+
+    #[gpui::test]
+    async fn test_authorize_access_to_language_model_based_on_plan() {
+        let config = Config::test();
+
+        let test_cases = vec![
+            // Pro plan should have access to claude-3.5-sonnet
+            (
+                Plan::ZedPro,
+                LanguageModelProvider::Anthropic,
+                "claude-3.5-sonnet",
+                true,
+            ),
+            // Free plan should have access to claude-3.5-sonnet
+            (
+                Plan::Free,
+                LanguageModelProvider::Anthropic,
+                "claude-3.5-sonnet",
+                true,
+            ),
+            // Pro plan should NOT have access to other Anthropic models
+            (
+                Plan::ZedPro,
+                LanguageModelProvider::Anthropic,
+                "claude-3-opus",
+                false,
+            ),
+        ];
+
+        for (plan, provider, model, expected_access) in test_cases {
+            let claims = LlmTokenClaims {
+                plan,
+                ..Default::default()
+            };
+
+            let result = authorize_access_to_language_model(
+                &config,
+                &claims,
+                Some("US".into()),
+                provider,
+                model,
+            );
+
+            if expected_access {
+                assert!(
+                    result.is_ok(),
+                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
+                    plan,
+                    provider,
+                    model
+                );
+            } else {
+                let error = result.expect_err(&format!(
+                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
+                    plan, provider, model
+                ));
+                let response = error.into_response();
+                assert_eq!(response.status(), StatusCode::FORBIDDEN);
+            }
+        }
+    }
+
+    #[gpui::test]
+    async fn test_authorize_access_to_language_model_for_staff() {
+        let config = Config::test();
+
+        let claims = LlmTokenClaims {
+            is_staff: true,
+            ..Default::default()
+        };
+
+        // Staff should have access to all models
+        let test_cases = vec![
+            (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
+            (LanguageModelProvider::Anthropic, "claude-2"),
+            (LanguageModelProvider::Anthropic, "claude-123-agi"),
+            (LanguageModelProvider::OpenAi, "gpt-4"),
+            (LanguageModelProvider::Google, "gemini-pro"),
+        ];
+
+        for (provider, model) in test_cases {
+            let result = authorize_access_to_language_model(
+                &config,
+                &claims,
+                Some("US".into()),
+                provider,
+                model,
+            );
+
+            assert!(
+                result.is_ok(),
+                "Expected staff to have access to provider {:?}, model {}",
+                provider,
+                model
+            );
+        }
+    }
 }

crates/collab/src/llm/token.rs 🔗

@@ -13,13 +13,19 @@ pub struct LlmTokenClaims {
     pub exp: u64,
     pub jti: String,
     pub user_id: u64,
+    pub is_staff: bool,
     pub plan: rpc::proto::Plan,
 }
 
 const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
 
 impl LlmTokenClaims {
-    pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
+    pub fn create(
+        user_id: UserId,
+        is_staff: bool,
+        plan: rpc::proto::Plan,
+        config: &Config,
+    ) -> Result<String> {
         let secret = config
             .llm_api_secret
             .as_ref()
@@ -31,6 +37,7 @@ impl LlmTokenClaims {
             exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
             jti: uuid::Uuid::new_v4().to_string(),
             user_id: user_id.to_proto(),
+            is_staff,
             plan,
         };
 

crates/collab/src/rpc.rs 🔗

@@ -5164,6 +5164,7 @@ async fn get_llm_api_token(
 
     let token = LlmTokenClaims::create(
         session.user_id(),
+        session.is_staff(),
         session.current_plan().await?,
         &session.app_state.config,
     )?;