authorization.rs

  1use reqwest::StatusCode;
  2use rpc::LanguageModelProvider;
  3
  4use crate::llm::LlmTokenClaims;
  5use crate::{Config, Error, Result};
  6
  7pub fn authorize_access_to_language_model(
  8    config: &Config,
  9    claims: &LlmTokenClaims,
 10    country_code: Option<String>,
 11    provider: LanguageModelProvider,
 12    model: &str,
 13) -> Result<()> {
 14    authorize_access_for_country(config, country_code, provider)?;
 15    authorize_access_to_model(config, claims, provider, model)?;
 16    Ok(())
 17}
 18
 19fn authorize_access_to_model(
 20    config: &Config,
 21    claims: &LlmTokenClaims,
 22    provider: LanguageModelProvider,
 23    model: &str,
 24) -> Result<()> {
 25    if claims.is_staff {
 26        return Ok(());
 27    }
 28
 29    match provider {
 30        LanguageModelProvider::Anthropic => {
 31            if model == "claude-3-5-sonnet" {
 32                return Ok(());
 33            }
 34
 35            if claims.has_llm_closed_beta_feature_flag
 36                && Some(model) == config.llm_closed_beta_model_name.as_deref()
 37            {
 38                return Ok(());
 39            }
 40        }
 41        _ => {}
 42    }
 43
 44    Err(Error::http(
 45        StatusCode::FORBIDDEN,
 46        format!("access to model {model:?} is not included in your plan"),
 47    ))
 48}
 49
 50fn authorize_access_for_country(
 51    config: &Config,
 52    country_code: Option<String>,
 53    provider: LanguageModelProvider,
 54) -> Result<()> {
 55    // In development we won't have the `CF-IPCountry` header, so we can't check
 56    // the country code.
 57    //
 58    // This shouldn't be necessary, as anyone running in development will need to provide
 59    // their own API credentials in order to use an LLM provider.
 60    if config.is_development() {
 61        return Ok(());
 62    }
 63
 64    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
 65    let country_code = match country_code.as_deref() {
 66        // `XX` - Used for clients without country code data.
 67        None | Some("XX") => Err(Error::http(
 68            StatusCode::BAD_REQUEST,
 69            "no country code".to_string(),
 70        ))?,
 71        // `T1` - Used for clients using the Tor network.
 72        Some("T1") => Err(Error::http(
 73            StatusCode::FORBIDDEN,
 74            format!("access to {provider:?} models is not available over Tor"),
 75        ))?,
 76        Some(country_code) => country_code,
 77    };
 78
 79    let is_country_supported_by_provider = match provider {
 80        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
 81        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
 82        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
 83        LanguageModelProvider::Zed => true,
 84    };
 85    if !is_country_supported_by_provider {
 86        Err(Error::http(
 87            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
 88            format!(
 89                "access to {provider:?} models is not available in your region ({country_code})"
 90            ),
 91        ))?
 92    }
 93
 94    Ok(())
 95}
 96
 97#[cfg(test)]
 98mod tests {
 99    use axum::response::IntoResponse;
100    use pretty_assertions::assert_eq;
101    use rpc::proto::Plan;
102
103    use super::*;
104
105    #[gpui::test]
106    async fn test_authorize_access_to_language_model_with_supported_country(
107        _cx: &mut gpui::TestAppContext,
108    ) {
109        let config = Config::test();
110
111        let claims = LlmTokenClaims {
112            user_id: 99,
113            plan: Plan::ZedPro,
114            is_staff: true,
115            ..Default::default()
116        };
117
118        let cases = vec![
119            (LanguageModelProvider::Anthropic, "US"), // United States
120            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
121            (LanguageModelProvider::OpenAi, "US"),    // United States
122            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
123            (LanguageModelProvider::Google, "US"),    // United States
124            (LanguageModelProvider::Google, "GB"),    // United Kingdom
125        ];
126
127        for (provider, country_code) in cases {
128            authorize_access_to_language_model(
129                &config,
130                &claims,
131                Some(country_code.into()),
132                provider,
133                "the-model",
134            )
135            .unwrap_or_else(|_| {
136                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
137            })
138        }
139    }
140
141    #[gpui::test]
142    async fn test_authorize_access_to_language_model_with_unsupported_country(
143        _cx: &mut gpui::TestAppContext,
144    ) {
145        let config = Config::test();
146
147        let claims = LlmTokenClaims {
148            user_id: 99,
149            plan: Plan::ZedPro,
150            ..Default::default()
151        };
152
153        let cases = vec![
154            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
155            (LanguageModelProvider::Anthropic, "BY"), // Belarus
156            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
157            (LanguageModelProvider::Anthropic, "CN"), // China
158            (LanguageModelProvider::Anthropic, "CU"), // Cuba
159            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
160            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
161            (LanguageModelProvider::Anthropic, "IR"), // Iran
162            (LanguageModelProvider::Anthropic, "KP"), // North Korea
163            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
164            (LanguageModelProvider::Anthropic, "LY"), // Libya
165            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
166            (LanguageModelProvider::Anthropic, "RU"), // Russia
167            (LanguageModelProvider::Anthropic, "SO"), // Somalia
168            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
169            (LanguageModelProvider::Anthropic, "SD"), // Sudan
170            (LanguageModelProvider::Anthropic, "SY"), // Syria
171            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
172            (LanguageModelProvider::Anthropic, "YE"), // Yemen
173            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
174            (LanguageModelProvider::Google, "KP"),    // North Korea
175        ];
176
177        for (provider, country_code) in cases {
178            let error_response = authorize_access_to_language_model(
179                &config,
180                &claims,
181                Some(country_code.into()),
182                provider,
183                "the-model",
184            )
185            .expect_err(&format!(
186                "expected authorization to return an error for {provider:?}: {country_code}"
187            ))
188            .into_response();
189
190            assert_eq!(
191                error_response.status(),
192                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
193            );
194            let response_body = hyper::body::to_bytes(error_response.into_body())
195                .await
196                .unwrap()
197                .to_vec();
198            assert_eq!(
199                String::from_utf8(response_body).unwrap(),
200                format!("access to {provider:?} models is not available in your region ({country_code})")
201            );
202        }
203    }
204
205    #[gpui::test]
206    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
207        let config = Config::test();
208
209        let claims = LlmTokenClaims {
210            user_id: 99,
211            plan: Plan::ZedPro,
212            ..Default::default()
213        };
214
215        let cases = vec![
216            (LanguageModelProvider::Anthropic, "T1"), // Tor
217            (LanguageModelProvider::OpenAi, "T1"),    // Tor
218            (LanguageModelProvider::Google, "T1"),    // Tor
219            (LanguageModelProvider::Zed, "T1"),       // Tor
220        ];
221
222        for (provider, country_code) in cases {
223            let error_response = authorize_access_to_language_model(
224                &config,
225                &claims,
226                Some(country_code.into()),
227                provider,
228                "the-model",
229            )
230            .expect_err(&format!(
231                "expected authorization to return an error for {provider:?}: {country_code}"
232            ))
233            .into_response();
234
235            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
236            let response_body = hyper::body::to_bytes(error_response.into_body())
237                .await
238                .unwrap()
239                .to_vec();
240            assert_eq!(
241                String::from_utf8(response_body).unwrap(),
242                format!("access to {provider:?} models is not available over Tor")
243            );
244        }
245    }
246
247    #[gpui::test]
248    async fn test_authorize_access_to_language_model_based_on_plan() {
249        let config = Config::test();
250
251        let test_cases = vec![
252            // Pro plan should have access to claude-3.5-sonnet
253            (
254                Plan::ZedPro,
255                LanguageModelProvider::Anthropic,
256                "claude-3-5-sonnet",
257                true,
258            ),
259            // Free plan should have access to claude-3.5-sonnet
260            (
261                Plan::Free,
262                LanguageModelProvider::Anthropic,
263                "claude-3-5-sonnet",
264                true,
265            ),
266            // Pro plan should NOT have access to other Anthropic models
267            (
268                Plan::ZedPro,
269                LanguageModelProvider::Anthropic,
270                "claude-3-opus",
271                false,
272            ),
273        ];
274
275        for (plan, provider, model, expected_access) in test_cases {
276            let claims = LlmTokenClaims {
277                plan,
278                ..Default::default()
279            };
280
281            let result = authorize_access_to_language_model(
282                &config,
283                &claims,
284                Some("US".into()),
285                provider,
286                model,
287            );
288
289            if expected_access {
290                assert!(
291                    result.is_ok(),
292                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
293                    plan,
294                    provider,
295                    model
296                );
297            } else {
298                let error = result.expect_err(&format!(
299                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
300                    plan, provider, model
301                ));
302                let response = error.into_response();
303                assert_eq!(response.status(), StatusCode::FORBIDDEN);
304            }
305        }
306    }
307
308    #[gpui::test]
309    async fn test_authorize_access_to_language_model_for_staff() {
310        let config = Config::test();
311
312        let claims = LlmTokenClaims {
313            is_staff: true,
314            ..Default::default()
315        };
316
317        // Staff should have access to all models
318        let test_cases = vec![
319            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
320            (LanguageModelProvider::Anthropic, "claude-2"),
321            (LanguageModelProvider::Anthropic, "claude-123-agi"),
322            (LanguageModelProvider::OpenAi, "gpt-4"),
323            (LanguageModelProvider::Google, "gemini-pro"),
324        ];
325
326        for (provider, model) in test_cases {
327            let result = authorize_access_to_language_model(
328                &config,
329                &claims,
330                Some("US".into()),
331                provider,
332                model,
333            );
334
335            assert!(
336                result.is_ok(),
337                "Expected staff to have access to provider {:?}, model {}",
338                provider,
339                model
340            );
341        }
342    }
343}