authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<&str>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    if provider == LanguageModelProvider::Anthropic {
 30        if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
 31            return Ok(());
 32        }
 33
 34        if claims.has_llm_closed_beta_feature_flag
 35            && Some(model) == config.llm_closed_beta_model_name.as_deref()
 36        {
 37            return Ok(());
 38        }
 39    }
 40
 41    Err(Error::http(
 42        StatusCode::FORBIDDEN,
 43        format!("access to model {model:?} is not included in your plan"),
 44    ))
 45}
 46
 47fn authorize_access_for_country(
 48    config: &Config,
 49    country_code: Option<&str>,
 50    provider: LanguageModelProvider,
 51) -> Result<()> {
 52    // In development we won't have the `CF-IPCountry` header, so we can't check
 53    // the country code.
 54    //
 55    // This shouldn't be necessary, as anyone running in development will need to provide
 56    // their own API credentials in order to use an LLM provider.
 57    if config.is_development() {
 58        return Ok(());
 59    }
 60
 61    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 62    let country_code = match country_code {
 63        // `XX` - Used for clients without country code data.
 64        None | Some("XX") => Err(Error::http(
 65            StatusCode::BAD_REQUEST,
 66            "no country code".to_string(),
 67        ))?,
 68        // `T1` - Used for clients using the Tor network.
 69        Some("T1") => Err(Error::http(
 70            StatusCode::FORBIDDEN,
 71            format!("access to {provider:?} models is not available over Tor"),
 72        ))?,
 73        Some(country_code) => country_code,
 74    };
 75
 76    let is_country_supported_by_provider = match provider {
 77        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 78        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 79        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 80    };
 81    if !is_country_supported_by_provider {
 82        Err(Error::http(
 83            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 84            format!(
 85                "access to {provider:?} models is not available in your region ({country_code})"
 86            ),
 87        ))?
 88    }
 89
 90    Ok(())
 91}
 92
 93#[cfg(test)]
 94mod tests {
 95    use axum::response::IntoResponse;
 96    use pretty_assertions::assert_eq;
 97    use rpc::proto::Plan;
 98
 99    use super::*;
100
101    #[gpui::test]
102    async fn test_authorize_access_to_language_model_with_supported_country(
103        _cx: &mut gpui::TestAppContext,
104    ) {
105        let config = Config::test();
106
107        let claims = LlmTokenClaims {
108            user_id: 99,
109            plan: Plan::ZedPro,
110            is_staff: true,
111            ..Default::default()
112        };
113
114        let cases = vec![
115            (LanguageModelProvider::Anthropic, "US"), // United States
116            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
117            (LanguageModelProvider::OpenAi, "US"),    // United States
118            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
119            (LanguageModelProvider::Google, "US"),    // United States
120            (LanguageModelProvider::Google, "GB"),    // United Kingdom
121        ];
122
123        for (provider, country_code) in cases {
124            authorize_access_to_language_model(
125                &config,
126                &claims,
127                Some(country_code),
128                provider,
129                "the-model",
130            )
131            .unwrap_or_else(|_| {
132                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
133            })
134        }
135    }
136
137    #[gpui::test]
138    async fn test_authorize_access_to_language_model_with_unsupported_country(
139        _cx: &mut gpui::TestAppContext,
140    ) {
141        let config = Config::test();
142
143        let claims = LlmTokenClaims {
144            user_id: 99,
145            plan: Plan::ZedPro,
146            ..Default::default()
147        };
148
149        let cases = vec![
150            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
151            (LanguageModelProvider::Anthropic, "BY"), // Belarus
152            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
153            (LanguageModelProvider::Anthropic, "CN"), // China
154            (LanguageModelProvider::Anthropic, "CU"), // Cuba
155            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
156            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
157            (LanguageModelProvider::Anthropic, "IR"), // Iran
158            (LanguageModelProvider::Anthropic, "KP"), // North Korea
159            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
160            (LanguageModelProvider::Anthropic, "LY"), // Libya
161            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
162            (LanguageModelProvider::Anthropic, "RU"), // Russia
163            (LanguageModelProvider::Anthropic, "SO"), // Somalia
164            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
165            (LanguageModelProvider::Anthropic, "SD"), // Sudan
166            (LanguageModelProvider::Anthropic, "SY"), // Syria
167            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
168            (LanguageModelProvider::Anthropic, "YE"), // Yemen
169            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
170            (LanguageModelProvider::Google, "KP"),    // North Korea
171        ];
172
173        for (provider, country_code) in cases {
174            let error_response = authorize_access_to_language_model(
175                &config,
176                &claims,
177                Some(country_code),
178                provider,
179                "the-model",
180            )
181            .expect_err(&format!(
182                "expected authorization to return an error for {provider:?}: {country_code}"
183            ))
184            .into_response();
185
186            assert_eq!(
187                error_response.status(),
188                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
189            );
190            let response_body = hyper::body::to_bytes(error_response.into_body())
191                .await
192                .unwrap()
193                .to_vec();
194            assert_eq!(
195                String::from_utf8(response_body).unwrap(),
196                format!(
197                    "access to {provider:?} models is not available in your region ({country_code})"
198                )
199            );
200        }
201    }
202
203    #[gpui::test]
204    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
205        let config = Config::test();
206
207        let claims = LlmTokenClaims {
208            user_id: 99,
209            plan: Plan::ZedPro,
210            ..Default::default()
211        };
212
213        let cases = vec![
214            (LanguageModelProvider::Anthropic, "T1"), // Tor
215            (LanguageModelProvider::OpenAi, "T1"),    // Tor
216            (LanguageModelProvider::Google, "T1"),    // Tor
217        ];
218
219        for (provider, country_code) in cases {
220            let error_response = authorize_access_to_language_model(
221                &config,
222                &claims,
223                Some(country_code),
224                provider,
225                "the-model",
226            )
227            .expect_err(&format!(
228                "expected authorization to return an error for {provider:?}: {country_code}"
229            ))
230            .into_response();
231
232            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
233            let response_body = hyper::body::to_bytes(error_response.into_body())
234                .await
235                .unwrap()
236                .to_vec();
237            assert_eq!(
238                String::from_utf8(response_body).unwrap(),
239                format!("access to {provider:?} models is not available over Tor")
240            );
241        }
242    }
243
244    #[gpui::test]
245    async fn test_authorize_access_to_language_model_based_on_plan() {
246        let config = Config::test();
247
248        let test_cases = vec![
249            // Pro plan should have access to claude-3.5-sonnet
250            (
251                Plan::ZedPro,
252                LanguageModelProvider::Anthropic,
253                "claude-3-5-sonnet",
254                true,
255            ),
256            // Free plan should have access to claude-3.5-sonnet
257            (
258                Plan::Free,
259                LanguageModelProvider::Anthropic,
260                "claude-3-5-sonnet",
261                true,
262            ),
263            // Pro plan should NOT have access to other Anthropic models
264            (
265                Plan::ZedPro,
266                LanguageModelProvider::Anthropic,
267                "claude-3-opus",
268                false,
269            ),
270        ];
271
272        for (plan, provider, model, expected_access) in test_cases {
273            let claims = LlmTokenClaims {
274                plan,
275                ..Default::default()
276            };
277
278            let result =
279                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
280
281            if expected_access {
282                assert!(
283                    result.is_ok(),
284                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
285                    plan,
286                    provider,
287                    model
288                );
289            } else {
290                let error = result.expect_err(&format!(
291                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
292                    plan, provider, model
293                ));
294                let response = error.into_response();
295                assert_eq!(response.status(), StatusCode::FORBIDDEN);
296            }
297        }
298    }
299
300    #[gpui::test]
301    async fn test_authorize_access_to_language_model_for_staff() {
302        let config = Config::test();
303
304        let claims = LlmTokenClaims {
305            is_staff: true,
306            ..Default::default()
307        };
308
309        // Staff should have access to all models
310        let test_cases = vec![
311            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
312            (LanguageModelProvider::Anthropic, "claude-2"),
313            (LanguageModelProvider::Anthropic, "claude-123-agi"),
314            (LanguageModelProvider::OpenAi, "gpt-4"),
315            (LanguageModelProvider::Google, "gemini-pro"),
316        ];
317
318        for (provider, model) in test_cases {
319            let result =
320                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
321
322            assert!(
323                result.is_ok(),
324                "Expected staff to have access to provider {:?}, model {}",
325                provider,
326                model
327            );
328        }
329    }
330}