1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<String>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 match provider {
30 LanguageModelProvider::Anthropic => {
31 if model == "claude-3-5-sonnet" {
32 return Ok(());
33 }
34
35 if claims.has_llm_closed_beta_feature_flag
36 && Some(model) == config.llm_closed_beta_model_name.as_deref()
37 {
38 return Ok(());
39 }
40 }
41 _ => {}
42 }
43
44 Err(Error::http(
45 StatusCode::FORBIDDEN,
46 format!("access to model {model:?} is not included in your plan"),
47 ))
48}
49
50fn authorize_access_for_country(
51 config: &Config,
52 country_code: Option<String>,
53 provider: LanguageModelProvider,
54) -> Result<()> {
55 // In development we won't have the `CF-IPCountry` header, so we can't check
56 // the country code.
57 //
58 // This shouldn't be necessary, as anyone running in development will need to provide
59 // their own API credentials in order to use an LLM provider.
60 if config.is_development() {
61 return Ok(());
62 }
63
64 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
65 let country_code = match country_code.as_deref() {
66 // `XX` - Used for clients without country code data.
67 None | Some("XX") => Err(Error::http(
68 StatusCode::BAD_REQUEST,
69 "no country code".to_string(),
70 ))?,
71 // `T1` - Used for clients using the Tor network.
72 Some("T1") => Err(Error::http(
73 StatusCode::FORBIDDEN,
74 format!("access to {provider:?} models is not available over Tor"),
75 ))?,
76 Some(country_code) => country_code,
77 };
78
79 let is_country_supported_by_provider = match provider {
80 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
81 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
82 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
83 LanguageModelProvider::Zed => true,
84 };
85 if !is_country_supported_by_provider {
86 Err(Error::http(
87 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
88 format!("access to {provider:?} models is not available in your region"),
89 ))?
90 }
91
92 Ok(())
93}
94
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available in your region")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    /// Requests with no `CF-IPCountry` header, or with Cloudflare's `XX`
    /// ("no country data") value, must be rejected with `400 Bad Request`
    /// regardless of provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, None),
            (LanguageModelProvider::Anthropic, Some("XX")),
            (LanguageModelProvider::OpenAi, Some("XX")),
            (LanguageModelProvider::Google, Some("XX")),
            (LanguageModelProvider::Zed, Some("XX")),
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                country_code.map(String::from),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code:?}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                "no country code".to_string()
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
                    plan,
                    provider,
                    model
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
                    plan, provider, model
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    /// The closed beta feature flag grants access to the configured closed
    /// beta model only, and only to users who hold the flag.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_closed_beta() {
        let mut config = Config::test();
        config.llm_closed_beta_model_name = Some("closed-beta-model".into());

        let test_cases = vec![
            // Users with the closed beta feature flag should have access to the closed beta model
            (true, "closed-beta-model", true),
            // Users without the closed beta feature flag should NOT have access to it
            (false, "closed-beta-model", false),
            // The closed beta feature flag should NOT grant access to other models
            (true, "claude-3-opus", false),
        ];

        for (has_llm_closed_beta_feature_flag, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                has_llm_closed_beta_feature_flag,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                LanguageModelProvider::Anthropic,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for model {} (closed beta flag: {})",
                    model,
                    has_llm_closed_beta_feature_flag
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for model {} (closed beta flag: {})",
                    model, has_llm_closed_beta_feature_flag
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {:?}, model {}",
                provider,
                model
            );
        }
    }
}