1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<String>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 claims: &LlmTokenClaims,
21 provider: LanguageModelProvider,
22 model: &str,
23) -> Result<()> {
24 if claims.is_staff {
25 return Ok(());
26 }
27
28 match (provider, model) {
29 (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
30 Ok(())
31 }
32 _ => Err(Error::http(
33 StatusCode::FORBIDDEN,
34 format!("access to model {model:?} is not included in your plan"),
35 ))?,
36 }
37}
38
39fn authorize_access_for_country(
40 config: &Config,
41 country_code: Option<String>,
42 provider: LanguageModelProvider,
43) -> Result<()> {
44 // In development we won't have the `CF-IPCountry` header, so we can't check
45 // the country code.
46 //
47 // This shouldn't be necessary, as anyone running in development will need to provide
48 // their own API credentials in order to use an LLM provider.
49 if config.is_development() {
50 return Ok(());
51 }
52
53 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
54 let country_code = match country_code.as_deref() {
55 // `XX` - Used for clients without country code data.
56 None | Some("XX") => Err(Error::http(
57 StatusCode::BAD_REQUEST,
58 "no country code".to_string(),
59 ))?,
60 // `T1` - Used for clients using the Tor network.
61 Some("T1") => Err(Error::http(
62 StatusCode::FORBIDDEN,
63 format!("access to {provider:?} models is not available over Tor"),
64 ))?,
65 Some(country_code) => country_code,
66 };
67
68 let is_country_supported_by_provider = match provider {
69 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
70 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
71 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
72 LanguageModelProvider::Zed => true,
73 };
74 if !is_country_supported_by_provider {
75 Err(Error::http(
76 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
77 format!("access to {provider:?} models is not available in your region"),
78 ))?
79 }
80
81 Ok(())
82}
83
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    /// Staff users in supported countries are authorized for every provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff bypasses the model check, so only the country check matters here.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    /// Unsupported countries yield `451 Unavailable For Legal Reasons`.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available in your region")
            );
        }
    }

    /// Requests over Tor (`T1`) are rejected with `403 Forbidden` for every provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    /// A missing country code, or Cloudflare's `XX` placeholder, yields
    /// `400 Bad Request` regardless of provider or staff status.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff is used so that a failure can only come from the country check.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, None),
            (LanguageModelProvider::Anthropic, Some("XX")),
            (LanguageModelProvider::OpenAi, None),
            (LanguageModelProvider::OpenAi, Some("XX")),
            (LanguageModelProvider::Google, None),
            (LanguageModelProvider::Google, Some("XX")),
            (LanguageModelProvider::Zed, None),
            (LanguageModelProvider::Zed, Some("XX")),
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                country_code.map(Into::into),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code:?}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                "no country code".to_string()
            );
        }
    }

    /// Non-staff users may only use `claude-3.5-sonnet`, independent of plan.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3.5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3.5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Free plan should NOT have access to other Anthropic models
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Non-staff users should NOT have access to other providers' models
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
            (Plan::Free, LanguageModelProvider::Google, "gemini-pro", false),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
                    plan,
                    provider,
                    model
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
                    plan, provider, model
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    /// Staff bypasses the plan-based model restriction entirely.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result = authorize_access_to_language_model(
                &config,
                &claims,
                Some("US".into()),
                provider,
                model,
            );

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {:?}, model {}",
                provider,
                model
            );
        }
    }
}