1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<&str>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 match provider {
30 LanguageModelProvider::Anthropic => {
31 if model == "claude-3-5-sonnet" {
32 return Ok(());
33 }
34
35 if claims.has_llm_closed_beta_feature_flag
36 && Some(model) == config.llm_closed_beta_model_name.as_deref()
37 {
38 return Ok(());
39 }
40 }
41 _ => {}
42 }
43
44 Err(Error::http(
45 StatusCode::FORBIDDEN,
46 format!("access to model {model:?} is not included in your plan"),
47 ))
48}
49
50fn authorize_access_for_country(
51 config: &Config,
52 country_code: Option<&str>,
53 provider: LanguageModelProvider,
54) -> Result<()> {
55 // In development we won't have the `CF-IPCountry` header, so we can't check
56 // the country code.
57 //
58 // This shouldn't be necessary, as anyone running in development will need to provide
59 // their own API credentials in order to use an LLM provider.
60 if config.is_development() {
61 return Ok(());
62 }
63
64 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
65 let country_code = match country_code {
66 // `XX` - Used for clients without country code data.
67 None | Some("XX") => Err(Error::http(
68 StatusCode::BAD_REQUEST,
69 "no country code".to_string(),
70 ))?,
71 // `T1` - Used for clients using the Tor network.
72 Some("T1") => Err(Error::http(
73 StatusCode::FORBIDDEN,
74 format!("access to {provider:?} models is not available over Tor"),
75 ))?,
76 Some(country_code) => country_code,
77 };
78
79 let is_country_supported_by_provider = match provider {
80 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
81 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
82 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
83 LanguageModelProvider::Zed => true,
84 };
85 if !is_country_supported_by_provider {
86 Err(Error::http(
87 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
88 format!(
89 "access to {provider:?} models is not available in your region ({country_code})"
90 ),
91 ))?
92 }
93
94 Ok(())
95}
96
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    /// Reads the full body of `response` and decodes it as UTF-8.
    async fn read_body(response: axum::response::Response) -> String {
        let bytes = hyper::body::to_bytes(response.into_body())
            .await
            .expect("failed to read response body");
        String::from_utf8(bytes.to_vec()).expect("response body is not valid UTF-8")
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff claims so that the model check passes for the placeholder model name and only
        // the country check is exercised.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = [
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = [
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            assert_eq!(
                read_body(error_response).await,
                format!("access to {provider:?} models is not available in your region ({country_code})")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = [
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            assert_eq!(
                read_body(error_response).await,
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff claims: the country check must reject the request before the model check runs.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let providers = [
            LanguageModelProvider::Anthropic,
            LanguageModelProvider::OpenAi,
            LanguageModelProvider::Google,
            LanguageModelProvider::Zed,
        ];

        for provider in providers {
            // `None`: no country header at all; `XX`: Cloudflare has no country data.
            for country_code in [None, Some("XX")] {
                let error_response = authorize_access_to_language_model(
                    &config,
                    &claims,
                    country_code,
                    provider,
                    "the-model",
                )
                .expect_err(&format!(
                    "expected authorization to return an error for {provider:?}: {country_code:?}"
                ))
                .into_response();

                assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
                assert_eq!(
                    read_body(error_response).await,
                    "no country code".to_string()
                );
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = [
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Free plan should NOT have access to other Anthropic models
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Non-staff users should NOT have access to non-Anthropic providers
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {plan:?}, provider {provider:?}, model {model}"
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {plan:?}, provider {provider:?}, model {model}"
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_closed_beta() {
        let mut config = Config::test();
        config.llm_closed_beta_model_name = Some("closed-beta-model".into());

        // (has closed-beta flag, model, expected access)
        let test_cases = [
            // Flagged users may use the configured closed-beta model.
            (true, "closed-beta-model", true),
            // Unflagged users may not.
            (false, "closed-beta-model", false),
            // The flag does not unlock arbitrary models.
            (true, "claude-3-opus", false),
        ];

        for (has_flag, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                has_llm_closed_beta_feature_flag: has_flag,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US"),
                LanguageModelProvider::Anthropic,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for closed-beta flag {has_flag}, model {model}"
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for closed-beta flag {has_flag}, model {model}"
                ));
                assert_eq!(error.into_response().status(), StatusCode::FORBIDDEN);
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = [
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {provider:?}, model {model}"
            );
        }
    }
}