authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<String>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    claims: &LlmTokenClaims,
 21    provider: LanguageModelProvider,
 22    model: &str,
 23) -> Result<()> {
 24    if claims.is_staff {
 25        return Ok(());
 26    }
 27
 28    match (provider, model) {
 29        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
 30            Ok(())
 31        }
 32        _ => Err(Error::http(
 33            StatusCode::FORBIDDEN,
 34            format!("access to model {model:?} is not included in your plan"),
 35        ))?,
 36    }
 37}
 38
 39fn authorize_access_for_country(
 40    config: &Config,
 41    country_code: Option<String>,
 42    provider: LanguageModelProvider,
 43) -> Result<()> {
 44    // In development we won't have the `CF-IPCountry` header, so we can't check
 45    // the country code.
 46    //
 47    // This shouldn't be necessary, as anyone running in development will need to provide
 48    // their own API credentials in order to use an LLM provider.
 49    if config.is_development() {
 50        return Ok(());
 51    }
 52
 53    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 54    let country_code = match country_code.as_deref() {
 55        // `XX` - Used for clients without country code data.
 56        None | Some("XX") => Err(Error::http(
 57            StatusCode::BAD_REQUEST,
 58            "no country code".to_string(),
 59        ))?,
 60        // `T1` - Used for clients using the Tor network.
 61        Some("T1") => Err(Error::http(
 62            StatusCode::FORBIDDEN,
 63            format!("access to {provider:?} models is not available over Tor"),
 64        ))?,
 65        Some(country_code) => country_code,
 66    };
 67
 68    let is_country_supported_by_provider = match provider {
 69        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 70        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 71        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 72        LanguageModelProvider::Zed => true,
 73    };
 74    if !is_country_supported_by_provider {
 75        Err(Error::http(
 76            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 77            format!("access to {provider:?} models is not available in your region"),
 78        ))?
 79    }
 80
 81    Ok(())
 82}
 83
 84#[cfg(test)]
 85mod tests {
 86    use axum::response::IntoResponse;
 87    use pretty_assertions::assert_eq;
 88    use rpc::proto::Plan;
 89
 90    use super::*;
 91
 92    #[gpui::test]
 93    async fn test_authorize_access_to_language_model_with_supported_country(
 94        _cx: &mut gpui::TestAppContext,
 95    ) {
 96        let config = Config::test();
 97
 98        let claims = LlmTokenClaims {
 99            user_id: 99,
100            plan: Plan::ZedPro,
101            is_staff: true,
102            ..Default::default()
103        };
104
105        let cases = vec![
106            (LanguageModelProvider::Anthropic, "US"), // United States
107            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
108            (LanguageModelProvider::OpenAi, "US"),    // United States
109            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
110            (LanguageModelProvider::Google, "US"),    // United States
111            (LanguageModelProvider::Google, "GB"),    // United Kingdom
112        ];
113
114        for (provider, country_code) in cases {
115            authorize_access_to_language_model(
116                &config,
117                &claims,
118                Some(country_code.into()),
119                provider,
120                "the-model",
121            )
122            .unwrap_or_else(|_| {
123                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
124            })
125        }
126    }
127
128    #[gpui::test]
129    async fn test_authorize_access_to_language_model_with_unsupported_country(
130        _cx: &mut gpui::TestAppContext,
131    ) {
132        let config = Config::test();
133
134        let claims = LlmTokenClaims {
135            user_id: 99,
136            plan: Plan::ZedPro,
137            ..Default::default()
138        };
139
140        let cases = vec![
141            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
142            (LanguageModelProvider::Anthropic, "BY"), // Belarus
143            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
144            (LanguageModelProvider::Anthropic, "CN"), // China
145            (LanguageModelProvider::Anthropic, "CU"), // Cuba
146            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
147            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
148            (LanguageModelProvider::Anthropic, "IR"), // Iran
149            (LanguageModelProvider::Anthropic, "KP"), // North Korea
150            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
151            (LanguageModelProvider::Anthropic, "LY"), // Libya
152            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
153            (LanguageModelProvider::Anthropic, "RU"), // Russia
154            (LanguageModelProvider::Anthropic, "SO"), // Somalia
155            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
156            (LanguageModelProvider::Anthropic, "SD"), // Sudan
157            (LanguageModelProvider::Anthropic, "SY"), // Syria
158            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
159            (LanguageModelProvider::Anthropic, "YE"), // Yemen
160            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
161            (LanguageModelProvider::Google, "KP"),    // North Korea
162        ];
163
164        for (provider, country_code) in cases {
165            let error_response = authorize_access_to_language_model(
166                &config,
167                &claims,
168                Some(country_code.into()),
169                provider,
170                "the-model",
171            )
172            .expect_err(&format!(
173                "expected authorization to return an error for {provider:?}: {country_code}"
174            ))
175            .into_response();
176
177            assert_eq!(
178                error_response.status(),
179                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
180            );
181            let response_body = hyper::body::to_bytes(error_response.into_body())
182                .await
183                .unwrap()
184                .to_vec();
185            assert_eq!(
186                String::from_utf8(response_body).unwrap(),
187                format!("access to {provider:?} models is not available in your region")
188            );
189        }
190    }
191
192    #[gpui::test]
193    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
194        let config = Config::test();
195
196        let claims = LlmTokenClaims {
197            user_id: 99,
198            plan: Plan::ZedPro,
199            ..Default::default()
200        };
201
202        let cases = vec![
203            (LanguageModelProvider::Anthropic, "T1"), // Tor
204            (LanguageModelProvider::OpenAi, "T1"),    // Tor
205            (LanguageModelProvider::Google, "T1"),    // Tor
206            (LanguageModelProvider::Zed, "T1"),       // Tor
207        ];
208
209        for (provider, country_code) in cases {
210            let error_response = authorize_access_to_language_model(
211                &config,
212                &claims,
213                Some(country_code.into()),
214                provider,
215                "the-model",
216            )
217            .expect_err(&format!(
218                "expected authorization to return an error for {provider:?}: {country_code}"
219            ))
220            .into_response();
221
222            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
223            let response_body = hyper::body::to_bytes(error_response.into_body())
224                .await
225                .unwrap()
226                .to_vec();
227            assert_eq!(
228                String::from_utf8(response_body).unwrap(),
229                format!("access to {provider:?} models is not available over Tor")
230            );
231        }
232    }
233
234    #[gpui::test]
235    async fn test_authorize_access_to_language_model_based_on_plan() {
236        let config = Config::test();
237
238        let test_cases = vec![
239            // Pro plan should have access to claude-3.5-sonnet
240            (
241                Plan::ZedPro,
242                LanguageModelProvider::Anthropic,
243                "claude-3-5-sonnet",
244                true,
245            ),
246            // Free plan should have access to claude-3.5-sonnet
247            (
248                Plan::Free,
249                LanguageModelProvider::Anthropic,
250                "claude-3-5-sonnet",
251                true,
252            ),
253            // Pro plan should NOT have access to other Anthropic models
254            (
255                Plan::ZedPro,
256                LanguageModelProvider::Anthropic,
257                "claude-3-opus",
258                false,
259            ),
260        ];
261
262        for (plan, provider, model, expected_access) in test_cases {
263            let claims = LlmTokenClaims {
264                plan,
265                ..Default::default()
266            };
267
268            let result = authorize_access_to_language_model(
269                &config,
270                &claims,
271                Some("US".into()),
272                provider,
273                model,
274            );
275
276            if expected_access {
277                assert!(
278                    result.is_ok(),
279                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
280                    plan,
281                    provider,
282                    model
283                );
284            } else {
285                let error = result.expect_err(&format!(
286                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
287                    plan, provider, model
288                ));
289                let response = error.into_response();
290                assert_eq!(response.status(), StatusCode::FORBIDDEN);
291            }
292        }
293    }
294
295    #[gpui::test]
296    async fn test_authorize_access_to_language_model_for_staff() {
297        let config = Config::test();
298
299        let claims = LlmTokenClaims {
300            is_staff: true,
301            ..Default::default()
302        };
303
304        // Staff should have access to all models
305        let test_cases = vec![
306            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
307            (LanguageModelProvider::Anthropic, "claude-2"),
308            (LanguageModelProvider::Anthropic, "claude-123-agi"),
309            (LanguageModelProvider::OpenAi, "gpt-4"),
310            (LanguageModelProvider::Google, "gemini-pro"),
311        ];
312
313        for (provider, model) in test_cases {
314            let result = authorize_access_to_language_model(
315                &config,
316                &claims,
317                Some("US".into()),
318                provider,
319                model,
320            );
321
322            assert!(
323                result.is_ok(),
324                "Expected staff to have access to provider {:?}, model {}",
325                provider,
326                model
327            );
328        }
329    }
330}