1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<String>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 claims: &LlmTokenClaims,
21 provider: LanguageModelProvider,
22 model: &str,
23) -> Result<()> {
24 if claims.is_staff {
25 return Ok(());
26 }
27
28 match (provider, model) {
29 (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
30 _ => Err(Error::http(
31 StatusCode::FORBIDDEN,
32 format!("access to model {model:?} is not included in your plan"),
33 ))?,
34 }
35}
36
37fn authorize_access_for_country(
38 config: &Config,
39 country_code: Option<String>,
40 provider: LanguageModelProvider,
41) -> Result<()> {
42 // In development we won't have the `CF-IPCountry` header, so we can't check
43 // the country code.
44 //
45 // This shouldn't be necessary, as anyone running in development will need to provide
46 // their own API credentials in order to use an LLM provider.
47 if config.is_development() {
48 return Ok(());
49 }
50
51 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
52 let country_code = match country_code.as_deref() {
53 // `XX` - Used for clients without country code data.
54 None | Some("XX") => Err(Error::http(
55 StatusCode::BAD_REQUEST,
56 "no country code".to_string(),
57 ))?,
58 // `T1` - Used for clients using the Tor network.
59 Some("T1") => Err(Error::http(
60 StatusCode::FORBIDDEN,
61 format!("access to {provider:?} models is not available over Tor"),
62 ))?,
63 Some(country_code) => country_code,
64 };
65
66 let is_country_supported_by_provider = match provider {
67 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
68 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
69 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
70 LanguageModelProvider::Zed => true,
71 };
72 if !is_country_supported_by_provider {
73 Err(Error::http(
74 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
75 format!("access to {provider:?} models is not available in your region"),
76 ))?
77 }
78
79 Ok(())
80}
81
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    /// Reads the full body of an error response as UTF-8 text.
    async fn response_body_text(response: axum::response::Response) -> String {
        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        String::from_utf8(body.to_vec()).unwrap()
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff bypass the model check, so only the country check is exercised.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
            (LanguageModelProvider::Zed, "US"),       // United States
            (LanguageModelProvider::Zed, "GB"),       // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            // The country check runs before the model check, so the region
            // error is reported even though "the-model" is not in the plan.
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            assert_eq!(
                response_body_text(error_response).await,
                format!("access to {provider:?} models is not available in your region")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        // `None` means the `CF-IPCountry` header was absent; `XX` means
        // Cloudflare had no country data. Both are rejected for every provider.
        let cases = vec![
            (LanguageModelProvider::Anthropic, None),
            (LanguageModelProvider::Anthropic, Some("XX")),
            (LanguageModelProvider::Zed, None),
            (LanguageModelProvider::Zed, Some("XX")),
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                country_code.map(String::from),
                provider,
                "claude-3-5-sonnet",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code:?}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
            assert_eq!(response_body_text(error_response).await, "no country code");
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            assert_eq!(
                response_body_text(error_response).await,
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Pro plan should NOT have access to OpenAI models
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
            // Pro plan should NOT have access to Google models
            (
                Plan::ZedPro,
                LanguageModelProvider::Google,
                "gemini-pro",
                false,
            ),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
                    plan,
                    provider,
                    model
                );
            } else {
                let response = result
                    .expect_err(&format!(
                        "Expected access to be denied for plan {:?}, provider {:?}, model {}",
                        plan, provider, model
                    ))
                    .into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
                assert_eq!(
                    response_body_text(response).await,
                    format!("access to model {model:?} is not included in your plan")
                );
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {:?}, model {}",
                provider,
                model
            );
        }
    }
}