@@ -6,21 +6,40 @@ use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
- _claims: &LlmTokenClaims,
+ claims: &LlmTokenClaims,
country_code: Option<String>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
- authorize_access_for_country(config, country_code, provider, model)?;
-
+ authorize_access_for_country(config, country_code, provider)?;
+ authorize_access_to_model(claims, provider, model)?;
Ok(())
}
+fn authorize_access_to_model(
+ claims: &LlmTokenClaims,
+ provider: LanguageModelProvider,
+ model: &str,
+) -> Result<()> {
+ if claims.is_staff {
+ return Ok(());
+ }
+
+ match (provider, model) {
+ (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
+ Ok(())
+ }
+ _ => Err(Error::http(
+ StatusCode::FORBIDDEN,
+ format!("access to model {model:?} is not included in your plan"),
+ ))?,
+ }
+}
+
fn authorize_access_for_country(
config: &Config,
country_code: Option<String>,
provider: LanguageModelProvider,
- _model: &str,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
@@ -79,6 +98,7 @@ mod tests {
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
+ is_staff: true,
..Default::default()
};
@@ -210,4 +230,101 @@ mod tests {
);
}
}
+
+ #[gpui::test]
+ async fn test_authorize_access_to_language_model_based_on_plan() {
+ let config = Config::test();
+
+ let test_cases = vec![
+ // Pro plan should have access to claude-3.5-sonnet
+ (
+ Plan::ZedPro,
+ LanguageModelProvider::Anthropic,
+ "claude-3.5-sonnet",
+ true,
+ ),
+ // Free plan should have access to claude-3.5-sonnet
+ (
+ Plan::Free,
+ LanguageModelProvider::Anthropic,
+ "claude-3.5-sonnet",
+ true,
+ ),
+ // Pro plan should NOT have access to other Anthropic models
+ (
+ Plan::ZedPro,
+ LanguageModelProvider::Anthropic,
+ "claude-3-opus",
+ false,
+ ),
+ ];
+
+ for (plan, provider, model, expected_access) in test_cases {
+ let claims = LlmTokenClaims {
+ plan,
+ ..Default::default()
+ };
+
+ let result = authorize_access_to_language_model(
+ &config,
+ &claims,
+ Some("US".into()),
+ provider,
+ model,
+ );
+
+ if expected_access {
+ assert!(
+ result.is_ok(),
+ "Expected access to be granted for plan {:?}, provider {:?}, model {}",
+ plan,
+ provider,
+ model
+ );
+ } else {
+ let error = result.expect_err(&format!(
+ "Expected access to be denied for plan {:?}, provider {:?}, model {}",
+ plan, provider, model
+ ));
+ let response = error.into_response();
+ assert_eq!(response.status(), StatusCode::FORBIDDEN);
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_authorize_access_to_language_model_for_staff() {
+ let config = Config::test();
+
+ let claims = LlmTokenClaims {
+ is_staff: true,
+ ..Default::default()
+ };
+
+ // Staff should have access to all models
+ let test_cases = vec![
+ (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
+ (LanguageModelProvider::Anthropic, "claude-2"),
+ (LanguageModelProvider::Anthropic, "claude-123-agi"),
+ (LanguageModelProvider::OpenAi, "gpt-4"),
+ (LanguageModelProvider::Google, "gemini-pro"),
+ ];
+
+ for (provider, model) in test_cases {
+ let result = authorize_access_to_language_model(
+ &config,
+ &claims,
+ Some("US".into()),
+ provider,
+ model,
+ );
+
+ assert!(
+ result.is_ok(),
+ "Expected staff to have access to provider {:?}, model {}",
+ provider,
+ model
+ );
+ }
+ }
}
@@ -13,13 +13,19 @@ pub struct LlmTokenClaims {
pub exp: u64,
pub jti: String,
pub user_id: u64,
+ pub is_staff: bool,
pub plan: rpc::proto::Plan,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
- pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
+ pub fn create(
+ user_id: UserId,
+ is_staff: bool,
+ plan: rpc::proto::Plan,
+ config: &Config,
+ ) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
@@ -31,6 +37,7 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user_id.to_proto(),
+ is_staff,
plan,
};