1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<String>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 match provider {
30 LanguageModelProvider::Anthropic => {
31 if model == "claude-3-5-sonnet" {
32 return Ok(());
33 }
34
35 if claims.has_llm_closed_beta_feature_flag
36 && Some(model) == config.llm_closed_beta_model_name.as_deref()
37 {
38 return Ok(());
39 }
40 }
41 _ => {}
42 }
43
44 Err(Error::http(
45 StatusCode::FORBIDDEN,
46 format!("access to model {model:?} is not included in your plan"),
47 ))
48}
49
50fn authorize_access_for_country(
51 config: &Config,
52 country_code: Option<String>,
53 provider: LanguageModelProvider,
54) -> Result<()> {
55 // In development we won't have the `CF-IPCountry` header, so we can't check
56 // the country code.
57 //
58 // This shouldn't be necessary, as anyone running in development will need to provide
59 // their own API credentials in order to use an LLM provider.
60 if config.is_development() {
61 return Ok(());
62 }
63
64 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
65 let country_code = match country_code.as_deref() {
66 // `XX` - Used for clients without country code data.
67 None | Some("XX") => Err(Error::http(
68 StatusCode::BAD_REQUEST,
69 "no country code".to_string(),
70 ))?,
71 // `T1` - Used for clients using the Tor network.
72 Some("T1") => Err(Error::http(
73 StatusCode::FORBIDDEN,
74 format!("access to {provider:?} models is not available over Tor"),
75 ))?,
76 Some(country_code) => country_code,
77 };
78
79 let is_country_supported_by_provider = match provider {
80 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
81 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
82 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
83 LanguageModelProvider::Zed => true,
84 };
85 if !is_country_supported_by_provider {
86 Err(Error::http(
87 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
88 format!(
89 "access to {provider:?} models is not available in your region ({country_code})"
90 ),
91 ))?
92 }
93
94 Ok(())
95}
96
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available in your region ({country_code})")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        // A missing `CF-IPCountry` header and Cloudflare's `XX` ("unknown")
        // value must both be rejected, for every provider.
        let cases = vec![
            (LanguageModelProvider::Anthropic, None),
            (LanguageModelProvider::Anthropic, Some("XX")), // Unknown
            (LanguageModelProvider::OpenAi, None),
            (LanguageModelProvider::Google, Some("XX")), // Unknown
            (LanguageModelProvider::Zed, None),
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                country_code.map(String::from),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code:?}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                "no country code".to_string()
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Free plan should NOT have access to other Anthropic models
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
                    plan,
                    provider,
                    model
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
                    plan, provider, model
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {:?}, model {}",
                provider,
                model
            );
        }
    }
}