authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<&str>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    if provider == LanguageModelProvider::Anthropic {
 30        if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
 31            return Ok(());
 32        }
 33
 34        if claims.has_llm_closed_beta_feature_flag
 35            && Some(model) == config.llm_closed_beta_model_name.as_deref()
 36        {
 37            return Ok(());
 38        }
 39    }
 40
 41    Err(Error::http(
 42        StatusCode::FORBIDDEN,
 43        format!("access to model {model:?} is not included in your plan"),
 44    ))
 45}
 46
 47fn authorize_access_for_country(
 48    config: &Config,
 49    country_code: Option<&str>,
 50    provider: LanguageModelProvider,
 51) -> Result<()> {
 52    // In development we won't have the `CF-IPCountry` header, so we can't check
 53    // the country code.
 54    //
 55    // This shouldn't be necessary, as anyone running in development will need to provide
 56    // their own API credentials in order to use an LLM provider.
 57    if config.is_development() {
 58        return Ok(());
 59    }
 60
 61    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 62    let country_code = match country_code {
 63        // `XX` - Used for clients without country code data.
 64        None | Some("XX") => Err(Error::http(
 65            StatusCode::BAD_REQUEST,
 66            "no country code".to_string(),
 67        ))?,
 68        // `T1` - Used for clients using the Tor network.
 69        Some("T1") => Err(Error::http(
 70            StatusCode::FORBIDDEN,
 71            format!("access to {provider:?} models is not available over Tor"),
 72        ))?,
 73        Some(country_code) => country_code,
 74    };
 75
 76    let is_country_supported_by_provider = match provider {
 77        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 78        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 79        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 80    };
 81    if !is_country_supported_by_provider {
 82        Err(Error::http(
 83            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 84            format!(
 85                "access to {provider:?} models is not available in your region ({country_code})"
 86            ),
 87        ))?
 88    }
 89
 90    Ok(())
 91}
 92
 93#[cfg(test)]
 94mod tests {
 95    use axum::response::IntoResponse;
 96    use pretty_assertions::assert_eq;
 97    use rpc::proto::Plan;
 98
 99    use super::*;
100
101    #[gpui::test]
102    async fn test_authorize_access_to_language_model_with_supported_country(
103        _cx: &mut gpui::TestAppContext,
104    ) {
105        let config = Config::test();
106
107        let claims = LlmTokenClaims {
108            user_id: 99,
109            plan: Plan::ZedPro,
110            is_staff: true,
111            ..Default::default()
112        };
113
114        let cases = vec![
115            (LanguageModelProvider::Anthropic, "US"), // United States
116            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
117            (LanguageModelProvider::OpenAi, "US"),    // United States
118            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
119            (LanguageModelProvider::Google, "US"),    // United States
120            (LanguageModelProvider::Google, "GB"),    // United Kingdom
121        ];
122
123        for (provider, country_code) in cases {
124            authorize_access_to_language_model(
125                &config,
126                &claims,
127                Some(country_code),
128                provider,
129                "the-model",
130            )
131            .unwrap_or_else(|_| {
132                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
133            })
134        }
135    }
136
137    #[gpui::test]
138    async fn test_authorize_access_to_language_model_with_unsupported_country(
139        _cx: &mut gpui::TestAppContext,
140    ) {
141        let config = Config::test();
142
143        let claims = LlmTokenClaims {
144            user_id: 99,
145            plan: Plan::ZedPro,
146            ..Default::default()
147        };
148
149        let cases = vec![
150            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
151            (LanguageModelProvider::Anthropic, "BY"), // Belarus
152            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
153            (LanguageModelProvider::Anthropic, "CN"), // China
154            (LanguageModelProvider::Anthropic, "CU"), // Cuba
155            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
156            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
157            (LanguageModelProvider::Anthropic, "IR"), // Iran
158            (LanguageModelProvider::Anthropic, "KP"), // North Korea
159            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
160            (LanguageModelProvider::Anthropic, "LY"), // Libya
161            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
162            (LanguageModelProvider::Anthropic, "RU"), // Russia
163            (LanguageModelProvider::Anthropic, "SO"), // Somalia
164            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
165            (LanguageModelProvider::Anthropic, "SD"), // Sudan
166            (LanguageModelProvider::Anthropic, "SY"), // Syria
167            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
168            (LanguageModelProvider::Anthropic, "YE"), // Yemen
169            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
170            (LanguageModelProvider::Google, "KP"),    // North Korea
171        ];
172
173        for (provider, country_code) in cases {
174            let error_response = authorize_access_to_language_model(
175                &config,
176                &claims,
177                Some(country_code),
178                provider,
179                "the-model",
180            )
181            .expect_err(&format!(
182                "expected authorization to return an error for {provider:?}: {country_code}"
183            ))
184            .into_response();
185
186            assert_eq!(
187                error_response.status(),
188                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
189            );
190            let response_body = hyper::body::to_bytes(error_response.into_body())
191                .await
192                .unwrap()
193                .to_vec();
194            assert_eq!(
195                String::from_utf8(response_body).unwrap(),
196                format!("access to {provider:?} models is not available in your region ({country_code})")
197            );
198        }
199    }
200
201    #[gpui::test]
202    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
203        let config = Config::test();
204
205        let claims = LlmTokenClaims {
206            user_id: 99,
207            plan: Plan::ZedPro,
208            ..Default::default()
209        };
210
211        let cases = vec![
212            (LanguageModelProvider::Anthropic, "T1"), // Tor
213            (LanguageModelProvider::OpenAi, "T1"),    // Tor
214            (LanguageModelProvider::Google, "T1"),    // Tor
215        ];
216
217        for (provider, country_code) in cases {
218            let error_response = authorize_access_to_language_model(
219                &config,
220                &claims,
221                Some(country_code),
222                provider,
223                "the-model",
224            )
225            .expect_err(&format!(
226                "expected authorization to return an error for {provider:?}: {country_code}"
227            ))
228            .into_response();
229
230            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
231            let response_body = hyper::body::to_bytes(error_response.into_body())
232                .await
233                .unwrap()
234                .to_vec();
235            assert_eq!(
236                String::from_utf8(response_body).unwrap(),
237                format!("access to {provider:?} models is not available over Tor")
238            );
239        }
240    }
241
242    #[gpui::test]
243    async fn test_authorize_access_to_language_model_based_on_plan() {
244        let config = Config::test();
245
246        let test_cases = vec![
247            // Pro plan should have access to claude-3.5-sonnet
248            (
249                Plan::ZedPro,
250                LanguageModelProvider::Anthropic,
251                "claude-3-5-sonnet",
252                true,
253            ),
254            // Free plan should have access to claude-3.5-sonnet
255            (
256                Plan::Free,
257                LanguageModelProvider::Anthropic,
258                "claude-3-5-sonnet",
259                true,
260            ),
261            // Pro plan should NOT have access to other Anthropic models
262            (
263                Plan::ZedPro,
264                LanguageModelProvider::Anthropic,
265                "claude-3-opus",
266                false,
267            ),
268        ];
269
270        for (plan, provider, model, expected_access) in test_cases {
271            let claims = LlmTokenClaims {
272                plan,
273                ..Default::default()
274            };
275
276            let result =
277                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
278
279            if expected_access {
280                assert!(
281                    result.is_ok(),
282                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
283                    plan,
284                    provider,
285                    model
286                );
287            } else {
288                let error = result.expect_err(&format!(
289                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
290                    plan, provider, model
291                ));
292                let response = error.into_response();
293                assert_eq!(response.status(), StatusCode::FORBIDDEN);
294            }
295        }
296    }
297
298    #[gpui::test]
299    async fn test_authorize_access_to_language_model_for_staff() {
300        let config = Config::test();
301
302        let claims = LlmTokenClaims {
303            is_staff: true,
304            ..Default::default()
305        };
306
307        // Staff should have access to all models
308        let test_cases = vec![
309            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
310            (LanguageModelProvider::Anthropic, "claude-2"),
311            (LanguageModelProvider::Anthropic, "claude-123-agi"),
312            (LanguageModelProvider::OpenAi, "gpt-4"),
313            (LanguageModelProvider::Google, "gemini-pro"),
314        ];
315
316        for (provider, model) in test_cases {
317            let result =
318                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
319
320            assert!(
321                result.is_ok(),
322                "Expected staff to have access to provider {:?}, model {}",
323                provider,
324                model
325            );
326        }
327    }
328}