1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<&str>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 if provider == LanguageModelProvider::Anthropic {
30 if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
31 return Ok(());
32 }
33
34 if claims.has_llm_closed_beta_feature_flag
35 && Some(model) == config.llm_closed_beta_model_name.as_deref()
36 {
37 return Ok(());
38 }
39 }
40
41 Err(Error::http(
42 StatusCode::FORBIDDEN,
43 format!("access to model {model:?} is not included in your plan"),
44 ))
45}
46
47fn authorize_access_for_country(
48 config: &Config,
49 country_code: Option<&str>,
50 provider: LanguageModelProvider,
51) -> Result<()> {
52 // In development we won't have the `CF-IPCountry` header, so we can't check
53 // the country code.
54 //
55 // This shouldn't be necessary, as anyone running in development will need to provide
56 // their own API credentials in order to use an LLM provider.
57 if config.is_development() {
58 return Ok(());
59 }
60
61 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
62 let country_code = match country_code {
63 // `XX` - Used for clients without country code data.
64 None | Some("XX") => Err(Error::http(
65 StatusCode::BAD_REQUEST,
66 "no country code".to_string(),
67 ))?,
68 // `T1` - Used for clients using the Tor network.
69 Some("T1") => Err(Error::http(
70 StatusCode::FORBIDDEN,
71 format!("access to {provider:?} models is not available over Tor"),
72 ))?,
73 Some(country_code) => country_code,
74 };
75
76 let is_country_supported_by_provider = match provider {
77 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
78 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
79 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
80 };
81 if !is_country_supported_by_provider {
82 Err(Error::http(
83 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
84 format!(
85 "access to {provider:?} models is not available in your region ({country_code})"
86 ),
87 ))?
88 }
89
90 Ok(())
91}
92
#[cfg(test)]
mod tests {
    // Tests for `authorize_access_to_language_model`, covering region checks
    // (supported / unsupported / Tor / missing country code) and per-model checks
    // (plan-independent allow-list, non-Anthropic providers, staff bypass).
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // `is_staff: true` bypasses the per-model check, so only the region check is
        // exercised here (the placeholder model "the-model" would otherwise be denied).
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available in your region ({country_code})")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let providers = [
            LanguageModelProvider::Anthropic,
            LanguageModelProvider::OpenAi,
            LanguageModelProvider::Google,
        ];

        for provider in providers {
            // A missing header and Cloudflare's `XX` ("no country data") must both be
            // rejected as a bad request before any model check happens.
            for country_code in [None, Some("XX")] {
                let error_response = authorize_access_to_language_model(
                    &config,
                    &claims,
                    country_code,
                    provider,
                    "the-model",
                )
                .expect_err(&format!(
                    "expected authorization to return an error for {provider:?}: {country_code:?}"
                ))
                .into_response();

                assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
                let response_body = hyper::body::to_bytes(error_response.into_body())
                    .await
                    .unwrap()
                    .to_vec();
                assert_eq!(
                    String::from_utf8(response_body).unwrap(),
                    "no country code".to_string()
                );
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3-5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3-5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should have access to claude-3-7-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-7-sonnet",
                true,
            ),
            // Free plan should have access to claude-3-7-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-7-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Non-staff users should NOT have access to models from other providers
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
            (
                Plan::ZedPro,
                LanguageModelProvider::Google,
                "gemini-pro",
                false,
            ),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {:?}, provider {:?}, model {}",
                    plan,
                    provider,
                    model
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {:?}, provider {:?}, model {}",
                    plan, provider, model
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {:?}, model {}",
                provider,
                model
            );
        }
    }
}