1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<&str>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 if provider == LanguageModelProvider::Anthropic {
30 if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
31 return Ok(());
32 }
33
34 if claims.has_llm_closed_beta_feature_flag
35 && Some(model) == config.llm_closed_beta_model_name.as_deref()
36 {
37 return Ok(());
38 }
39 }
40
41 Err(Error::http(
42 StatusCode::FORBIDDEN,
43 format!("access to model {model:?} is not included in your plan"),
44 ))
45}
46
47fn authorize_access_for_country(
48 config: &Config,
49 country_code: Option<&str>,
50 provider: LanguageModelProvider,
51) -> Result<()> {
52 // In development we won't have the `CF-IPCountry` header, so we can't check
53 // the country code.
54 //
55 // This shouldn't be necessary, as anyone running in development will need to provide
56 // their own API credentials in order to use an LLM provider.
57 if config.is_development() {
58 return Ok(());
59 }
60
61 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
62 let country_code = match country_code {
63 // `XX` - Used for clients without country code data.
64 None | Some("XX") => Err(Error::http(
65 StatusCode::BAD_REQUEST,
66 "no country code".to_string(),
67 ))?,
68 // `T1` - Used for clients using the Tor network.
69 Some("T1") => Err(Error::http(
70 StatusCode::FORBIDDEN,
71 format!("access to {provider:?} models is not available over Tor"),
72 ))?,
73 Some(country_code) => country_code,
74 };
75
76 let is_country_supported_by_provider = match provider {
77 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
78 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
79 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
80 };
81 if !is_country_supported_by_provider {
82 Err(Error::http(
83 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
84 format!(
85 "access to {provider:?} models is not available in your region ({country_code})"
86 ),
87 ))?
88 }
89
90 Ok(())
91}
92
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!(
                    "access to {provider:?} models is not available in your region ({country_code})"
                )
            );
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    /// A missing `CF-IPCountry` header and Cloudflare's "unknown" value (`XX`)
    /// are both rejected as a bad request.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        for country_code in [None, Some("XX")] {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                country_code,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
            )
            .expect_err(&format!(
                "expected authorization to return an error for country code {country_code:?}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(String::from_utf8(response_body).unwrap(), "no country code");
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should have access to claude-3.7-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-7-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.7-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-7-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Non-staff users should NOT have access to non-Anthropic providers
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
            (Plan::Free, LanguageModelProvider::OpenAi, "gpt-4", false),
            (Plan::ZedPro, LanguageModelProvider::Google, "gemini-pro", false),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {plan:?}, provider {provider:?}, model {model}"
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {plan:?}, provider {provider:?}, model {model}"
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    /// The configured closed-beta model is available only to callers carrying the
    /// closed-beta feature flag. The flag does not unlock any other model.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_closed_beta() {
        let mut config = Config::test();
        config.llm_closed_beta_model_name = Some("closed-beta-model".into());

        let with_flag = LlmTokenClaims {
            has_llm_closed_beta_feature_flag: true,
            ..Default::default()
        };
        let without_flag = LlmTokenClaims::default();

        assert!(
            authorize_access_to_language_model(
                &config,
                &with_flag,
                Some("US"),
                LanguageModelProvider::Anthropic,
                "closed-beta-model",
            )
            .is_ok(),
            "Expected closed-beta flag holders to have access to the closed-beta model"
        );

        let denied_cases = [
            (&without_flag, "closed-beta-model"),
            (&with_flag, "some-other-model"),
        ];
        for (claims, model) in denied_cases {
            let response = authorize_access_to_language_model(
                &config,
                claims,
                Some("US"),
                LanguageModelProvider::Anthropic,
                model,
            )
            .expect_err(&format!(
                "Expected access to be denied for model {model} (closed-beta flag: {})",
                claims.has_llm_closed_beta_feature_flag
            ))
            .into_response();
            assert_eq!(response.status(), StatusCode::FORBIDDEN);
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {provider:?}, model {model}"
            );
        }
    }
}