authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<String>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    match provider {
 30        LanguageModelProvider::Anthropic => {
 31            if model == "claude-3-5-sonnet" {
 32                return Ok(());
 33            }
 34
 35            if claims.has_llm_closed_beta_feature_flag
 36                && Some(model) == config.llm_closed_beta_model_name.as_deref()
 37            {
 38                return Ok(());
 39            }
 40        }
 41        _ => {}
 42    }
 43
 44    Err(Error::http(
 45        StatusCode::FORBIDDEN,
 46        format!("access to model {model:?} is not included in your plan"),
 47    ))
 48}
 49
 50fn authorize_access_for_country(
 51    config: &Config,
 52    country_code: Option<String>,
 53    provider: LanguageModelProvider,
 54) -> Result<()> {
 55    // In development we won't have the `CF-IPCountry` header, so we can't check
 56    // the country code.
 57    //
 58    // This shouldn't be necessary, as anyone running in development will need to provide
 59    // their own API credentials in order to use an LLM provider.
 60    if config.is_development() {
 61        return Ok(());
 62    }
 63
 64    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 65    let country_code = match country_code.as_deref() {
 66        // `XX` - Used for clients without country code data.
 67        None | Some("XX") => Err(Error::http(
 68            StatusCode::BAD_REQUEST,
 69            "no country code".to_string(),
 70        ))?,
 71        // `T1` - Used for clients using the Tor network.
 72        Some("T1") => Err(Error::http(
 73            StatusCode::FORBIDDEN,
 74            format!("access to {provider:?} models is not available over Tor"),
 75        ))?,
 76        Some(country_code) => country_code,
 77    };
 78
 79    let is_country_supported_by_provider = match provider {
 80        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 81        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 82        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 83        LanguageModelProvider::Zed => true,
 84    };
 85    if !is_country_supported_by_provider {
 86        Err(Error::http(
 87            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 88            format!("access to {provider:?} models is not available in your region"),
 89        ))?
 90    }
 91
 92    Ok(())
 93}
 94
 95#[cfg(test)]
 96mod tests {
 97    use axum::response::IntoResponse;
 98    use pretty_assertions::assert_eq;
 99    use rpc::proto::Plan;
100
101    use super::*;
102
103    #[gpui::test]
104    async fn test_authorize_access_to_language_model_with_supported_country(
105        _cx: &mut gpui::TestAppContext,
106    ) {
107        let config = Config::test();
108
109        let claims = LlmTokenClaims {
110            user_id: 99,
111            plan: Plan::ZedPro,
112            is_staff: true,
113            ..Default::default()
114        };
115
116        let cases = vec![
117            (LanguageModelProvider::Anthropic, "US"), // United States
118            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
119            (LanguageModelProvider::OpenAi, "US"),    // United States
120            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
121            (LanguageModelProvider::Google, "US"),    // United States
122            (LanguageModelProvider::Google, "GB"),    // United Kingdom
123        ];
124
125        for (provider, country_code) in cases {
126            authorize_access_to_language_model(
127                &config,
128                &claims,
129                Some(country_code.into()),
130                provider,
131                "the-model",
132            )
133            .unwrap_or_else(|_| {
134                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
135            })
136        }
137    }
138
139    #[gpui::test]
140    async fn test_authorize_access_to_language_model_with_unsupported_country(
141        _cx: &mut gpui::TestAppContext,
142    ) {
143        let config = Config::test();
144
145        let claims = LlmTokenClaims {
146            user_id: 99,
147            plan: Plan::ZedPro,
148            ..Default::default()
149        };
150
151        let cases = vec![
152            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
153            (LanguageModelProvider::Anthropic, "BY"), // Belarus
154            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
155            (LanguageModelProvider::Anthropic, "CN"), // China
156            (LanguageModelProvider::Anthropic, "CU"), // Cuba
157            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
158            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
159            (LanguageModelProvider::Anthropic, "IR"), // Iran
160            (LanguageModelProvider::Anthropic, "KP"), // North Korea
161            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
162            (LanguageModelProvider::Anthropic, "LY"), // Libya
163            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
164            (LanguageModelProvider::Anthropic, "RU"), // Russia
165            (LanguageModelProvider::Anthropic, "SO"), // Somalia
166            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
167            (LanguageModelProvider::Anthropic, "SD"), // Sudan
168            (LanguageModelProvider::Anthropic, "SY"), // Syria
169            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
170            (LanguageModelProvider::Anthropic, "YE"), // Yemen
171            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
172            (LanguageModelProvider::Google, "KP"),    // North Korea
173        ];
174
175        for (provider, country_code) in cases {
176            let error_response = authorize_access_to_language_model(
177                &config,
178                &claims,
179                Some(country_code.into()),
180                provider,
181                "the-model",
182            )
183            .expect_err(&format!(
184                "expected authorization to return an error for {provider:?}: {country_code}"
185            ))
186            .into_response();
187
188            assert_eq!(
189                error_response.status(),
190                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
191            );
192            let response_body = hyper::body::to_bytes(error_response.into_body())
193                .await
194                .unwrap()
195                .to_vec();
196            assert_eq!(
197                String::from_utf8(response_body).unwrap(),
198                format!("access to {provider:?} models is not available in your region")
199            );
200        }
201    }
202
203    #[gpui::test]
204    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
205        let config = Config::test();
206
207        let claims = LlmTokenClaims {
208            user_id: 99,
209            plan: Plan::ZedPro,
210            ..Default::default()
211        };
212
213        let cases = vec![
214            (LanguageModelProvider::Anthropic, "T1"), // Tor
215            (LanguageModelProvider::OpenAi, "T1"),    // Tor
216            (LanguageModelProvider::Google, "T1"),    // Tor
217            (LanguageModelProvider::Zed, "T1"),       // Tor
218        ];
219
220        for (provider, country_code) in cases {
221            let error_response = authorize_access_to_language_model(
222                &config,
223                &claims,
224                Some(country_code.into()),
225                provider,
226                "the-model",
227            )
228            .expect_err(&format!(
229                "expected authorization to return an error for {provider:?}: {country_code}"
230            ))
231            .into_response();
232
233            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
234            let response_body = hyper::body::to_bytes(error_response.into_body())
235                .await
236                .unwrap()
237                .to_vec();
238            assert_eq!(
239                String::from_utf8(response_body).unwrap(),
240                format!("access to {provider:?} models is not available over Tor")
241            );
242        }
243    }
244
245    #[gpui::test]
246    async fn test_authorize_access_to_language_model_based_on_plan() {
247        let config = Config::test();
248
249        let test_cases = vec![
250            // Pro plan should have access to claude-3.5-sonnet
251            (
252                Plan::ZedPro,
253                LanguageModelProvider::Anthropic,
254                "claude-3-5-sonnet",
255                true,
256            ),
257            // Free plan should have access to claude-3.5-sonnet
258            (
259                Plan::Free,
260                LanguageModelProvider::Anthropic,
261                "claude-3-5-sonnet",
262                true,
263            ),
264            // Pro plan should NOT have access to other Anthropic models
265            (
266                Plan::ZedPro,
267                LanguageModelProvider::Anthropic,
268                "claude-3-opus",
269                false,
270            ),
271        ];
272
273        for (plan, provider, model, expected_access) in test_cases {
274            let claims = LlmTokenClaims {
275                plan,
276                ..Default::default()
277            };
278
279            let result = authorize_access_to_language_model(
280                &config,
281                &claims,
282                Some("US".into()),
283                provider,
284                model,
285            );
286
287            if expected_access {
288                assert!(
289                    result.is_ok(),
290                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
291                    plan,
292                    provider,
293                    model
294                );
295            } else {
296                let error = result.expect_err(&format!(
297                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
298                    plan, provider, model
299                ));
300                let response = error.into_response();
301                assert_eq!(response.status(), StatusCode::FORBIDDEN);
302            }
303        }
304    }
305
306    #[gpui::test]
307    async fn test_authorize_access_to_language_model_for_staff() {
308        let config = Config::test();
309
310        let claims = LlmTokenClaims {
311            is_staff: true,
312            ..Default::default()
313        };
314
315        // Staff should have access to all models
316        let test_cases = vec![
317            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
318            (LanguageModelProvider::Anthropic, "claude-2"),
319            (LanguageModelProvider::Anthropic, "claude-123-agi"),
320            (LanguageModelProvider::OpenAi, "gpt-4"),
321            (LanguageModelProvider::Google, "gemini-pro"),
322        ];
323
324        for (provider, model) in test_cases {
325            let result = authorize_access_to_language_model(
326                &config,
327                &claims,
328                Some("US".into()),
329                provider,
330                model,
331            );
332
333            assert!(
334                result.is_ok(),
335                "Expected staff to have access to provider {:?}, model {}",
336                provider,
337                model
338            );
339        }
340    }
341}