authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<&str>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    match provider {
 30        LanguageModelProvider::Anthropic => {
 31            if model == "claude-3-5-sonnet" {
 32                return Ok(());
 33            }
 34
 35            if claims.has_llm_closed_beta_feature_flag
 36                && Some(model) == config.llm_closed_beta_model_name.as_deref()
 37            {
 38                return Ok(());
 39            }
 40        }
 41        _ => {}
 42    }
 43
 44    Err(Error::http(
 45        StatusCode::FORBIDDEN,
 46        format!("access to model {model:?} is not included in your plan"),
 47    ))
 48}
 49
 50fn authorize_access_for_country(
 51    config: &Config,
 52    country_code: Option<&str>,
 53    provider: LanguageModelProvider,
 54) -> Result<()> {
 55    // In development we won't have the `CF-IPCountry` header, so we can't check
 56    // the country code.
 57    //
 58    // This shouldn't be necessary, as anyone running in development will need to provide
 59    // their own API credentials in order to use an LLM provider.
 60    if config.is_development() {
 61        return Ok(());
 62    }
 63
 64    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 65    let country_code = match country_code {
 66        // `XX` - Used for clients without country code data.
 67        None | Some("XX") => Err(Error::http(
 68            StatusCode::BAD_REQUEST,
 69            "no country code".to_string(),
 70        ))?,
 71        // `T1` - Used for clients using the Tor network.
 72        Some("T1") => Err(Error::http(
 73            StatusCode::FORBIDDEN,
 74            format!("access to {provider:?} models is not available over Tor"),
 75        ))?,
 76        Some(country_code) => country_code,
 77    };
 78
 79    let is_country_supported_by_provider = match provider {
 80        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 81        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 82        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 83        LanguageModelProvider::Zed => true,
 84    };
 85    if !is_country_supported_by_provider {
 86        Err(Error::http(
 87            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 88            format!(
 89                "access to {provider:?} models is not available in your region ({country_code})"
 90            ),
 91        ))?
 92    }
 93
 94    Ok(())
 95}
 96
 97#[cfg(test)]
 98mod tests {
 99    use axum::response::IntoResponse;
100    use pretty_assertions::assert_eq;
101    use rpc::proto::Plan;
102
103    use super::*;
104
105    #[gpui::test]
106    async fn test_authorize_access_to_language_model_with_supported_country(
107        _cx: &mut gpui::TestAppContext,
108    ) {
109        let config = Config::test();
110
111        let claims = LlmTokenClaims {
112            user_id: 99,
113            plan: Plan::ZedPro,
114            is_staff: true,
115            ..Default::default()
116        };
117
118        let cases = vec![
119            (LanguageModelProvider::Anthropic, "US"), // United States
120            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
121            (LanguageModelProvider::OpenAi, "US"),    // United States
122            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
123            (LanguageModelProvider::Google, "US"),    // United States
124            (LanguageModelProvider::Google, "GB"),    // United Kingdom
125        ];
126
127        for (provider, country_code) in cases {
128            authorize_access_to_language_model(
129                &config,
130                &claims,
131                Some(country_code),
132                provider,
133                "the-model",
134            )
135            .unwrap_or_else(|_| {
136                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
137            })
138        }
139    }
140
141    #[gpui::test]
142    async fn test_authorize_access_to_language_model_with_unsupported_country(
143        _cx: &mut gpui::TestAppContext,
144    ) {
145        let config = Config::test();
146
147        let claims = LlmTokenClaims {
148            user_id: 99,
149            plan: Plan::ZedPro,
150            ..Default::default()
151        };
152
153        let cases = vec![
154            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
155            (LanguageModelProvider::Anthropic, "BY"), // Belarus
156            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
157            (LanguageModelProvider::Anthropic, "CN"), // China
158            (LanguageModelProvider::Anthropic, "CU"), // Cuba
159            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
160            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
161            (LanguageModelProvider::Anthropic, "IR"), // Iran
162            (LanguageModelProvider::Anthropic, "KP"), // North Korea
163            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
164            (LanguageModelProvider::Anthropic, "LY"), // Libya
165            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
166            (LanguageModelProvider::Anthropic, "RU"), // Russia
167            (LanguageModelProvider::Anthropic, "SO"), // Somalia
168            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
169            (LanguageModelProvider::Anthropic, "SD"), // Sudan
170            (LanguageModelProvider::Anthropic, "SY"), // Syria
171            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
172            (LanguageModelProvider::Anthropic, "YE"), // Yemen
173            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
174            (LanguageModelProvider::Google, "KP"),    // North Korea
175        ];
176
177        for (provider, country_code) in cases {
178            let error_response = authorize_access_to_language_model(
179                &config,
180                &claims,
181                Some(country_code),
182                provider,
183                "the-model",
184            )
185            .expect_err(&format!(
186                "expected authorization to return an error for {provider:?}: {country_code}"
187            ))
188            .into_response();
189
190            assert_eq!(
191                error_response.status(),
192                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
193            );
194            let response_body = hyper::body::to_bytes(error_response.into_body())
195                .await
196                .unwrap()
197                .to_vec();
198            assert_eq!(
199                String::from_utf8(response_body).unwrap(),
200                format!("access to {provider:?} models is not available in your region ({country_code})")
201            );
202        }
203    }
204
205    #[gpui::test]
206    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
207        let config = Config::test();
208
209        let claims = LlmTokenClaims {
210            user_id: 99,
211            plan: Plan::ZedPro,
212            ..Default::default()
213        };
214
215        let cases = vec![
216            (LanguageModelProvider::Anthropic, "T1"), // Tor
217            (LanguageModelProvider::OpenAi, "T1"),    // Tor
218            (LanguageModelProvider::Google, "T1"),    // Tor
219            (LanguageModelProvider::Zed, "T1"),       // Tor
220        ];
221
222        for (provider, country_code) in cases {
223            let error_response = authorize_access_to_language_model(
224                &config,
225                &claims,
226                Some(country_code),
227                provider,
228                "the-model",
229            )
230            .expect_err(&format!(
231                "expected authorization to return an error for {provider:?}: {country_code}"
232            ))
233            .into_response();
234
235            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
236            let response_body = hyper::body::to_bytes(error_response.into_body())
237                .await
238                .unwrap()
239                .to_vec();
240            assert_eq!(
241                String::from_utf8(response_body).unwrap(),
242                format!("access to {provider:?} models is not available over Tor")
243            );
244        }
245    }
246
247    #[gpui::test]
248    async fn test_authorize_access_to_language_model_based_on_plan() {
249        let config = Config::test();
250
251        let test_cases = vec![
252            // Pro plan should have access to claude-3.5-sonnet
253            (
254                Plan::ZedPro,
255                LanguageModelProvider::Anthropic,
256                "claude-3-5-sonnet",
257                true,
258            ),
259            // Free plan should have access to claude-3.5-sonnet
260            (
261                Plan::Free,
262                LanguageModelProvider::Anthropic,
263                "claude-3-5-sonnet",
264                true,
265            ),
266            // Pro plan should NOT have access to other Anthropic models
267            (
268                Plan::ZedPro,
269                LanguageModelProvider::Anthropic,
270                "claude-3-opus",
271                false,
272            ),
273        ];
274
275        for (plan, provider, model, expected_access) in test_cases {
276            let claims = LlmTokenClaims {
277                plan,
278                ..Default::default()
279            };
280
281            let result =
282                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
283
284            if expected_access {
285                assert!(
286                    result.is_ok(),
287                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
288                    plan,
289                    provider,
290                    model
291                );
292            } else {
293                let error = result.expect_err(&format!(
294                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
295                    plan, provider, model
296                ));
297                let response = error.into_response();
298                assert_eq!(response.status(), StatusCode::FORBIDDEN);
299            }
300        }
301    }
302
303    #[gpui::test]
304    async fn test_authorize_access_to_language_model_for_staff() {
305        let config = Config::test();
306
307        let claims = LlmTokenClaims {
308            is_staff: true,
309            ..Default::default()
310        };
311
312        // Staff should have access to all models
313        let test_cases = vec![
314            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
315            (LanguageModelProvider::Anthropic, "claude-2"),
316            (LanguageModelProvider::Anthropic, "claude-123-agi"),
317            (LanguageModelProvider::OpenAi, "gpt-4"),
318            (LanguageModelProvider::Google, "gemini-pro"),
319        ];
320
321        for (provider, model) in test_cases {
322            let result =
323                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
324
325            assert!(
326                result.is_ok(),
327                "Expected staff to have access to provider {:?}, model {}",
328                provider,
329                model
330            );
331        }
332    }
333}