authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<String>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    claims: &LlmTokenClaims,
 21    provider: LanguageModelProvider,
 22    model: &str,
 23) -> Result<()> {
 24    if claims.is_staff {
 25        return Ok(());
 26    }
 27
 28    match (provider, model) {
 29        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
 30        _ => Err(Error::http(
 31            StatusCode::FORBIDDEN,
 32            format!("access to model {model:?} is not included in your plan"),
 33        ))?,
 34    }
 35}
 36
 37fn authorize_access_for_country(
 38    config: &Config,
 39    country_code: Option<String>,
 40    provider: LanguageModelProvider,
 41) -> Result<()> {
 42    // In development we won't have the `CF-IPCountry` header, so we can't check
 43    // the country code.
 44    //
 45    // This shouldn't be necessary, as anyone running in development will need to provide
 46    // their own API credentials in order to use an LLM provider.
 47    if config.is_development() {
 48        return Ok(());
 49    }
 50
 51    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 52    let country_code = match country_code.as_deref() {
 53        // `XX` - Used for clients without country code data.
 54        None | Some("XX") => Err(Error::http(
 55            StatusCode::BAD_REQUEST,
 56            "no country code".to_string(),
 57        ))?,
 58        // `T1` - Used for clients using the Tor network.
 59        Some("T1") => Err(Error::http(
 60            StatusCode::FORBIDDEN,
 61            format!("access to {provider:?} models is not available over Tor"),
 62        ))?,
 63        Some(country_code) => country_code,
 64    };
 65
 66    let is_country_supported_by_provider = match provider {
 67        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 68        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 69        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 70        LanguageModelProvider::Zed => true,
 71    };
 72    if !is_country_supported_by_provider {
 73        Err(Error::http(
 74            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 75            format!("access to {provider:?} models is not available in your region"),
 76        ))?
 77    }
 78
 79    Ok(())
 80}
 81
 82#[cfg(test)]
 83mod tests {
 84    use axum::response::IntoResponse;
 85    use pretty_assertions::assert_eq;
 86    use rpc::proto::Plan;
 87
 88    use super::*;
 89
 90    #[gpui::test]
 91    async fn test_authorize_access_to_language_model_with_supported_country(
 92        _cx: &mut gpui::TestAppContext,
 93    ) {
 94        let config = Config::test();
 95
 96        let claims = LlmTokenClaims {
 97            user_id: 99,
 98            plan: Plan::ZedPro,
 99            is_staff: true,
100            ..Default::default()
101        };
102
103        let cases = vec![
104            (LanguageModelProvider::Anthropic, "US"), // United States
105            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
106            (LanguageModelProvider::OpenAi, "US"),    // United States
107            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
108            (LanguageModelProvider::Google, "US"),    // United States
109            (LanguageModelProvider::Google, "GB"),    // United Kingdom
110        ];
111
112        for (provider, country_code) in cases {
113            authorize_access_to_language_model(
114                &config,
115                &claims,
116                Some(country_code.into()),
117                provider,
118                "the-model",
119            )
120            .unwrap_or_else(|_| {
121                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
122            })
123        }
124    }
125
126    #[gpui::test]
127    async fn test_authorize_access_to_language_model_with_unsupported_country(
128        _cx: &mut gpui::TestAppContext,
129    ) {
130        let config = Config::test();
131
132        let claims = LlmTokenClaims {
133            user_id: 99,
134            plan: Plan::ZedPro,
135            ..Default::default()
136        };
137
138        let cases = vec![
139            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
140            (LanguageModelProvider::Anthropic, "BY"), // Belarus
141            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
142            (LanguageModelProvider::Anthropic, "CN"), // China
143            (LanguageModelProvider::Anthropic, "CU"), // Cuba
144            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
145            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
146            (LanguageModelProvider::Anthropic, "IR"), // Iran
147            (LanguageModelProvider::Anthropic, "KP"), // North Korea
148            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
149            (LanguageModelProvider::Anthropic, "LY"), // Libya
150            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
151            (LanguageModelProvider::Anthropic, "RU"), // Russia
152            (LanguageModelProvider::Anthropic, "SO"), // Somalia
153            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
154            (LanguageModelProvider::Anthropic, "SD"), // Sudan
155            (LanguageModelProvider::Anthropic, "SY"), // Syria
156            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
157            (LanguageModelProvider::Anthropic, "YE"), // Yemen
158            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
159            (LanguageModelProvider::Google, "KP"),    // North Korea
160        ];
161
162        for (provider, country_code) in cases {
163            let error_response = authorize_access_to_language_model(
164                &config,
165                &claims,
166                Some(country_code.into()),
167                provider,
168                "the-model",
169            )
170            .expect_err(&format!(
171                "expected authorization to return an error for {provider:?}: {country_code}"
172            ))
173            .into_response();
174
175            assert_eq!(
176                error_response.status(),
177                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
178            );
179            let response_body = hyper::body::to_bytes(error_response.into_body())
180                .await
181                .unwrap()
182                .to_vec();
183            assert_eq!(
184                String::from_utf8(response_body).unwrap(),
185                format!("access to {provider:?} models is not available in your region")
186            );
187        }
188    }
189
190    #[gpui::test]
191    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
192        let config = Config::test();
193
194        let claims = LlmTokenClaims {
195            user_id: 99,
196            plan: Plan::ZedPro,
197            ..Default::default()
198        };
199
200        let cases = vec![
201            (LanguageModelProvider::Anthropic, "T1"), // Tor
202            (LanguageModelProvider::OpenAi, "T1"),    // Tor
203            (LanguageModelProvider::Google, "T1"),    // Tor
204            (LanguageModelProvider::Zed, "T1"),       // Tor
205        ];
206
207        for (provider, country_code) in cases {
208            let error_response = authorize_access_to_language_model(
209                &config,
210                &claims,
211                Some(country_code.into()),
212                provider,
213                "the-model",
214            )
215            .expect_err(&format!(
216                "expected authorization to return an error for {provider:?}: {country_code}"
217            ))
218            .into_response();
219
220            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
221            let response_body = hyper::body::to_bytes(error_response.into_body())
222                .await
223                .unwrap()
224                .to_vec();
225            assert_eq!(
226                String::from_utf8(response_body).unwrap(),
227                format!("access to {provider:?} models is not available over Tor")
228            );
229        }
230    }
231
232    #[gpui::test]
233    async fn test_authorize_access_to_language_model_based_on_plan() {
234        let config = Config::test();
235
236        let test_cases = vec![
237            // Pro plan should have access to claude-3.5-sonnet
238            (
239                Plan::ZedPro,
240                LanguageModelProvider::Anthropic,
241                "claude-3-5-sonnet",
242                true,
243            ),
244            // Free plan should have access to claude-3.5-sonnet
245            (
246                Plan::Free,
247                LanguageModelProvider::Anthropic,
248                "claude-3-5-sonnet",
249                true,
250            ),
251            // Pro plan should NOT have access to other Anthropic models
252            (
253                Plan::ZedPro,
254                LanguageModelProvider::Anthropic,
255                "claude-3-opus",
256                false,
257            ),
258        ];
259
260        for (plan, provider, model, expected_access) in test_cases {
261            let claims = LlmTokenClaims {
262                plan,
263                ..Default::default()
264            };
265
266            let result = authorize_access_to_language_model(
267                &config,
268                &claims,
269                Some("US".into()),
270                provider,
271                model,
272            );
273
274            if expected_access {
275                assert!(
276                    result.is_ok(),
277                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
278                    plan,
279                    provider,
280                    model
281                );
282            } else {
283                let error = result.expect_err(&format!(
284                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
285                    plan, provider, model
286                ));
287                let response = error.into_response();
288                assert_eq!(response.status(), StatusCode::FORBIDDEN);
289            }
290        }
291    }
292
293    #[gpui::test]
294    async fn test_authorize_access_to_language_model_for_staff() {
295        let config = Config::test();
296
297        let claims = LlmTokenClaims {
298            is_staff: true,
299            ..Default::default()
300        };
301
302        // Staff should have access to all models
303        let test_cases = vec![
304            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
305            (LanguageModelProvider::Anthropic, "claude-2"),
306            (LanguageModelProvider::Anthropic, "claude-123-agi"),
307            (LanguageModelProvider::OpenAi, "gpt-4"),
308            (LanguageModelProvider::Google, "gemini-pro"),
309        ];
310
311        for (provider, model) in test_cases {
312            let result = authorize_access_to_language_model(
313                &config,
314                &claims,
315                Some("US".into()),
316                provider,
317                model,
318            );
319
320            assert!(
321                result.is_ok(),
322                "Expected staff to have access to provider {:?}, model {}",
323                provider,
324                model
325            );
326        }
327    }
328}