authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<&str>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    if provider == LanguageModelProvider::Anthropic {
 30        if model == "claude-3-5-sonnet" {
 31            return Ok(());
 32        }
 33
 34        if claims.has_llm_closed_beta_feature_flag
 35            && Some(model) == config.llm_closed_beta_model_name.as_deref()
 36        {
 37            return Ok(());
 38        }
 39    }
 40
 41    Err(Error::http(
 42        StatusCode::FORBIDDEN,
 43        format!("access to model {model:?} is not included in your plan"),
 44    ))
 45}
 46
 47fn authorize_access_for_country(
 48    config: &Config,
 49    country_code: Option<&str>,
 50    provider: LanguageModelProvider,
 51) -> Result<()> {
 52    // In development we won't have the `CF-IPCountry` header, so we can't check
 53    // the country code.
 54    //
 55    // This shouldn't be necessary, as anyone running in development will need to provide
 56    // their own API credentials in order to use an LLM provider.
 57    if config.is_development() {
 58        return Ok(());
 59    }
 60
 61    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 62    let country_code = match country_code {
 63        // `XX` - Used for clients without country code data.
 64        None | Some("XX") => Err(Error::http(
 65            StatusCode::BAD_REQUEST,
 66            "no country code".to_string(),
 67        ))?,
 68        // `T1` - Used for clients using the Tor network.
 69        Some("T1") => Err(Error::http(
 70            StatusCode::FORBIDDEN,
 71            format!("access to {provider:?} models is not available over Tor"),
 72        ))?,
 73        Some(country_code) => country_code,
 74    };
 75
 76    let is_country_supported_by_provider = match provider {
 77        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 78        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 79        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 80        LanguageModelProvider::Zed => true,
 81    };
 82    if !is_country_supported_by_provider {
 83        Err(Error::http(
 84            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 85            format!(
 86                "access to {provider:?} models is not available in your region ({country_code})"
 87            ),
 88        ))?
 89    }
 90
 91    Ok(())
 92}
 93
 94#[cfg(test)]
 95mod tests {
 96    use axum::response::IntoResponse;
 97    use pretty_assertions::assert_eq;
 98    use rpc::proto::Plan;
 99
100    use super::*;
101
102    #[gpui::test]
103    async fn test_authorize_access_to_language_model_with_supported_country(
104        _cx: &mut gpui::TestAppContext,
105    ) {
106        let config = Config::test();
107
108        let claims = LlmTokenClaims {
109            user_id: 99,
110            plan: Plan::ZedPro,
111            is_staff: true,
112            ..Default::default()
113        };
114
115        let cases = vec![
116            (LanguageModelProvider::Anthropic, "US"), // United States
117            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
118            (LanguageModelProvider::OpenAi, "US"),    // United States
119            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
120            (LanguageModelProvider::Google, "US"),    // United States
121            (LanguageModelProvider::Google, "GB"),    // United Kingdom
122        ];
123
124        for (provider, country_code) in cases {
125            authorize_access_to_language_model(
126                &config,
127                &claims,
128                Some(country_code),
129                provider,
130                "the-model",
131            )
132            .unwrap_or_else(|_| {
133                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
134            })
135        }
136    }
137
138    #[gpui::test]
139    async fn test_authorize_access_to_language_model_with_unsupported_country(
140        _cx: &mut gpui::TestAppContext,
141    ) {
142        let config = Config::test();
143
144        let claims = LlmTokenClaims {
145            user_id: 99,
146            plan: Plan::ZedPro,
147            ..Default::default()
148        };
149
150        let cases = vec![
151            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
152            (LanguageModelProvider::Anthropic, "BY"), // Belarus
153            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
154            (LanguageModelProvider::Anthropic, "CN"), // China
155            (LanguageModelProvider::Anthropic, "CU"), // Cuba
156            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
157            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
158            (LanguageModelProvider::Anthropic, "IR"), // Iran
159            (LanguageModelProvider::Anthropic, "KP"), // North Korea
160            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
161            (LanguageModelProvider::Anthropic, "LY"), // Libya
162            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
163            (LanguageModelProvider::Anthropic, "RU"), // Russia
164            (LanguageModelProvider::Anthropic, "SO"), // Somalia
165            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
166            (LanguageModelProvider::Anthropic, "SD"), // Sudan
167            (LanguageModelProvider::Anthropic, "SY"), // Syria
168            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
169            (LanguageModelProvider::Anthropic, "YE"), // Yemen
170            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
171            (LanguageModelProvider::Google, "KP"),    // North Korea
172        ];
173
174        for (provider, country_code) in cases {
175            let error_response = authorize_access_to_language_model(
176                &config,
177                &claims,
178                Some(country_code),
179                provider,
180                "the-model",
181            )
182            .expect_err(&format!(
183                "expected authorization to return an error for {provider:?}: {country_code}"
184            ))
185            .into_response();
186
187            assert_eq!(
188                error_response.status(),
189                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
190            );
191            let response_body = hyper::body::to_bytes(error_response.into_body())
192                .await
193                .unwrap()
194                .to_vec();
195            assert_eq!(
196                String::from_utf8(response_body).unwrap(),
197                format!("access to {provider:?} models is not available in your region ({country_code})")
198            );
199        }
200    }
201
202    #[gpui::test]
203    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
204        let config = Config::test();
205
206        let claims = LlmTokenClaims {
207            user_id: 99,
208            plan: Plan::ZedPro,
209            ..Default::default()
210        };
211
212        let cases = vec![
213            (LanguageModelProvider::Anthropic, "T1"), // Tor
214            (LanguageModelProvider::OpenAi, "T1"),    // Tor
215            (LanguageModelProvider::Google, "T1"),    // Tor
216            (LanguageModelProvider::Zed, "T1"),       // Tor
217        ];
218
219        for (provider, country_code) in cases {
220            let error_response = authorize_access_to_language_model(
221                &config,
222                &claims,
223                Some(country_code),
224                provider,
225                "the-model",
226            )
227            .expect_err(&format!(
228                "expected authorization to return an error for {provider:?}: {country_code}"
229            ))
230            .into_response();
231
232            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
233            let response_body = hyper::body::to_bytes(error_response.into_body())
234                .await
235                .unwrap()
236                .to_vec();
237            assert_eq!(
238                String::from_utf8(response_body).unwrap(),
239                format!("access to {provider:?} models is not available over Tor")
240            );
241        }
242    }
243
244    #[gpui::test]
245    async fn test_authorize_access_to_language_model_based_on_plan() {
246        let config = Config::test();
247
248        let test_cases = vec![
249            // Pro plan should have access to claude-3.5-sonnet
250            (
251                Plan::ZedPro,
252                LanguageModelProvider::Anthropic,
253                "claude-3-5-sonnet",
254                true,
255            ),
256            // Free plan should have access to claude-3.5-sonnet
257            (
258                Plan::Free,
259                LanguageModelProvider::Anthropic,
260                "claude-3-5-sonnet",
261                true,
262            ),
263            // Pro plan should NOT have access to other Anthropic models
264            (
265                Plan::ZedPro,
266                LanguageModelProvider::Anthropic,
267                "claude-3-opus",
268                false,
269            ),
270        ];
271
272        for (plan, provider, model, expected_access) in test_cases {
273            let claims = LlmTokenClaims {
274                plan,
275                ..Default::default()
276            };
277
278            let result =
279                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
280
281            if expected_access {
282                assert!(
283                    result.is_ok(),
284                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
285                    plan,
286                    provider,
287                    model
288                );
289            } else {
290                let error = result.expect_err(&format!(
291                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
292                    plan, provider, model
293                ));
294                let response = error.into_response();
295                assert_eq!(response.status(), StatusCode::FORBIDDEN);
296            }
297        }
298    }
299
300    #[gpui::test]
301    async fn test_authorize_access_to_language_model_for_staff() {
302        let config = Config::test();
303
304        let claims = LlmTokenClaims {
305            is_staff: true,
306            ..Default::default()
307        };
308
309        // Staff should have access to all models
310        let test_cases = vec![
311            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
312            (LanguageModelProvider::Anthropic, "claude-2"),
313            (LanguageModelProvider::Anthropic, "claude-123-agi"),
314            (LanguageModelProvider::OpenAi, "gpt-4"),
315            (LanguageModelProvider::Google, "gemini-pro"),
316        ];
317
318        for (provider, model) in test_cases {
319            let result =
320                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
321
322            assert!(
323                result.is_ok(),
324                "Expected staff to have access to provider {:?}, model {}",
325                provider,
326                model
327            );
328        }
329    }
330}