authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    _claims: &LlmTokenClaims,
 10    country_code: Option<String>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider, model)?;
 15
 16    Ok(())
 17}
 18
 19fn authorize_access_for_country(
 20    config: &Config,
 21    country_code: Option<String>,
 22    provider: LanguageModelProvider,
 23    _model: &str,
 24) -> Result<()> {
 25    // In development we won't have the `CF-IPCountry` header, so we can't check
 26    // the country code.
 27    //
 28    // This shouldn't be necessary, as anyone running in development will need to provide
 29    // their own API credentials in order to use an LLM provider.
 30    if config.is_development() {
 31        return Ok(());
 32    }
 33
 34    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 35    let country_code = match country_code.as_deref() {
 36        // `XX` - Used for clients without country code data.
 37        None | Some("XX") => Err(Error::http(
 38            StatusCode::BAD_REQUEST,
 39            "no country code".to_string(),
 40        ))?,
 41        // `T1` - Used for clients using the Tor network.
 42        Some("T1") => Err(Error::http(
 43            StatusCode::FORBIDDEN,
 44            format!("access to {provider:?} models is not available over Tor"),
 45        ))?,
 46        Some(country_code) => country_code,
 47    };
 48
 49    let is_country_supported_by_provider = match provider {
 50        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 51        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 52        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 53        LanguageModelProvider::Zed => true,
 54    };
 55    if !is_country_supported_by_provider {
 56        Err(Error::http(
 57            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 58            format!("access to {provider:?} models is not available in your region"),
 59        ))?
 60    }
 61
 62    Ok(())
 63}
 64
 65#[cfg(test)]
 66mod tests {
 67    use axum::response::IntoResponse;
 68    use pretty_assertions::assert_eq;
 69    use rpc::proto::Plan;
 70
 71    use super::*;
 72
 73    #[gpui::test]
 74    async fn test_authorize_access_to_language_model_with_supported_country(
 75        _cx: &mut gpui::TestAppContext,
 76    ) {
 77        let config = Config::test();
 78
 79        let claims = LlmTokenClaims {
 80            user_id: 99,
 81            plan: Plan::ZedPro,
 82            ..Default::default()
 83        };
 84
 85        let cases = vec![
 86            (LanguageModelProvider::Anthropic, "US"), // United States
 87            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
 88            (LanguageModelProvider::OpenAi, "US"),    // United States
 89            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
 90            (LanguageModelProvider::Google, "US"),    // United States
 91            (LanguageModelProvider::Google, "GB"),    // United Kingdom
 92        ];
 93
 94        for (provider, country_code) in cases {
 95            authorize_access_to_language_model(
 96                &config,
 97                &claims,
 98                Some(country_code.into()),
 99                provider,
100                "the-model",
101            )
102            .unwrap_or_else(|_| {
103                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
104            })
105        }
106    }
107
108    #[gpui::test]
109    async fn test_authorize_access_to_language_model_with_unsupported_country(
110        _cx: &mut gpui::TestAppContext,
111    ) {
112        let config = Config::test();
113
114        let claims = LlmTokenClaims {
115            user_id: 99,
116            plan: Plan::ZedPro,
117            ..Default::default()
118        };
119
120        let cases = vec![
121            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
122            (LanguageModelProvider::Anthropic, "BY"), // Belarus
123            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
124            (LanguageModelProvider::Anthropic, "CN"), // China
125            (LanguageModelProvider::Anthropic, "CU"), // Cuba
126            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
127            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
128            (LanguageModelProvider::Anthropic, "IR"), // Iran
129            (LanguageModelProvider::Anthropic, "KP"), // North Korea
130            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
131            (LanguageModelProvider::Anthropic, "LY"), // Libya
132            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
133            (LanguageModelProvider::Anthropic, "RU"), // Russia
134            (LanguageModelProvider::Anthropic, "SO"), // Somalia
135            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
136            (LanguageModelProvider::Anthropic, "SD"), // Sudan
137            (LanguageModelProvider::Anthropic, "SY"), // Syria
138            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
139            (LanguageModelProvider::Anthropic, "YE"), // Yemen
140            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
141            (LanguageModelProvider::Google, "KP"),    // North Korea
142        ];
143
144        for (provider, country_code) in cases {
145            let error_response = authorize_access_to_language_model(
146                &config,
147                &claims,
148                Some(country_code.into()),
149                provider,
150                "the-model",
151            )
152            .expect_err(&format!(
153                "expected authorization to return an error for {provider:?}: {country_code}"
154            ))
155            .into_response();
156
157            assert_eq!(
158                error_response.status(),
159                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
160            );
161            let response_body = hyper::body::to_bytes(error_response.into_body())
162                .await
163                .unwrap()
164                .to_vec();
165            assert_eq!(
166                String::from_utf8(response_body).unwrap(),
167                format!("access to {provider:?} models is not available in your region")
168            );
169        }
170    }
171
172    #[gpui::test]
173    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
174        let config = Config::test();
175
176        let claims = LlmTokenClaims {
177            user_id: 99,
178            plan: Plan::ZedPro,
179            ..Default::default()
180        };
181
182        let cases = vec![
183            (LanguageModelProvider::Anthropic, "T1"), // Tor
184            (LanguageModelProvider::OpenAi, "T1"),    // Tor
185            (LanguageModelProvider::Google, "T1"),    // Tor
186            (LanguageModelProvider::Zed, "T1"),       // Tor
187        ];
188
189        for (provider, country_code) in cases {
190            let error_response = authorize_access_to_language_model(
191                &config,
192                &claims,
193                Some(country_code.into()),
194                provider,
195                "the-model",
196            )
197            .expect_err(&format!(
198                "expected authorization to return an error for {provider:?}: {country_code}"
199            ))
200            .into_response();
201
202            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
203            let response_body = hyper::body::to_bytes(error_response.into_body())
204                .await
205                .unwrap()
206                .to_vec();
207            assert_eq!(
208                String::from_utf8(response_body).unwrap(),
209                format!("access to {provider:?} models is not available over Tor")
210            );
211        }
212    }
213}