1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 claims: &LlmTokenClaims,
10 country_code: Option<&str>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider)?;
15 authorize_access_to_model(config, claims, provider, model)?;
16 Ok(())
17}
18
19fn authorize_access_to_model(
20 config: &Config,
21 claims: &LlmTokenClaims,
22 provider: LanguageModelProvider,
23 model: &str,
24) -> Result<()> {
25 if claims.is_staff {
26 return Ok(());
27 }
28
29 if provider == LanguageModelProvider::Anthropic {
30 if model == "claude-3-5-sonnet" {
31 return Ok(());
32 }
33
34 if claims.has_llm_closed_beta_feature_flag
35 && Some(model) == config.llm_closed_beta_model_name.as_deref()
36 {
37 return Ok(());
38 }
39 }
40
41 Err(Error::http(
42 StatusCode::FORBIDDEN,
43 format!("access to model {model:?} is not included in your plan"),
44 ))
45}
46
47fn authorize_access_for_country(
48 config: &Config,
49 country_code: Option<&str>,
50 provider: LanguageModelProvider,
51) -> Result<()> {
52 // In development we won't have the `CF-IPCountry` header, so we can't check
53 // the country code.
54 //
55 // This shouldn't be necessary, as anyone running in development will need to provide
56 // their own API credentials in order to use an LLM provider.
57 if config.is_development() {
58 return Ok(());
59 }
60
61 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
62 let country_code = match country_code {
63 // `XX` - Used for clients without country code data.
64 None | Some("XX") => Err(Error::http(
65 StatusCode::BAD_REQUEST,
66 "no country code".to_string(),
67 ))?,
68 // `T1` - Used for clients using the Tor network.
69 Some("T1") => Err(Error::http(
70 StatusCode::FORBIDDEN,
71 format!("access to {provider:?} models is not available over Tor"),
72 ))?,
73 Some(country_code) => country_code,
74 };
75
76 let is_country_supported_by_provider = match provider {
77 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
78 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
79 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
80 LanguageModelProvider::Zed => true,
81 };
82 if !is_country_supported_by_provider {
83 Err(Error::http(
84 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
85 format!(
86 "access to {provider:?} models is not available in your region ({country_code})"
87 ),
88 ))?
89 }
90
91 Ok(())
92}
93
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    /// Staff users in supported countries are authorized for every provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    /// Unsupported countries are rejected with `451` before the model is checked.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(
                error_response.status(),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
            );
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available in your region ({country_code})")
            );
        }
    }

    /// Tor clients (`T1`) are rejected with `403` for every provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();

        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        };

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            let error_response = authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code),
                provider,
                "the-model",
            )
            .expect_err(&format!(
                "expected authorization to return an error for {provider:?}: {country_code}"
            ))
            .into_response();

            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
            let response_body = hyper::body::to_bytes(error_response.into_body())
                .await
                .unwrap()
                .to_vec();
            assert_eq!(
                String::from_utf8(response_body).unwrap(),
                format!("access to {provider:?} models is not available over Tor")
            );
        }
    }

    /// A missing country (no header) or Cloudflare's `XX` is rejected with `400`
    /// for every provider, even for staff.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();

        // Staff bypass the model check, so any error here must come from the
        // country check.
        let claims = LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            is_staff: true,
            ..Default::default()
        };

        let providers = [
            LanguageModelProvider::Anthropic,
            LanguageModelProvider::OpenAi,
            LanguageModelProvider::Google,
            LanguageModelProvider::Zed,
        ];

        for country_code in [None, Some("XX")] {
            for provider in providers {
                let error_response = authorize_access_to_language_model(
                    &config,
                    &claims,
                    country_code,
                    provider,
                    "the-model",
                )
                .expect_err(&format!(
                    "expected authorization to return an error for {provider:?}: {country_code:?}"
                ))
                .into_response();

                assert_eq!(error_response.status(), StatusCode::BAD_REQUEST);
                let response_body = hyper::body::to_bytes(error_response.into_body())
                    .await
                    .unwrap()
                    .to_vec();
                assert_eq!(
                    String::from_utf8(response_body).unwrap(),
                    "no country code".to_string()
                );
            }
        }
    }

    /// Non-staff access depends only on the provider/model pair.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_based_on_plan() {
        let config = Config::test();

        let test_cases = vec![
            // Pro plan should have access to claude-3.5-sonnet
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Free plan should have access to claude-3.5-sonnet
            (
                Plan::Free,
                LanguageModelProvider::Anthropic,
                "claude-3-5-sonnet",
                true,
            ),
            // Pro plan should NOT have access to other Anthropic models
            (
                Plan::ZedPro,
                LanguageModelProvider::Anthropic,
                "claude-3-opus",
                false,
            ),
            // Non-staff users should NOT have access to models from other providers
            (Plan::ZedPro, LanguageModelProvider::OpenAi, "gpt-4", false),
        ];

        for (plan, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                plan,
                ..Default::default()
            };

            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for plan {plan:?}, provider {provider:?}, model {model}"
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for plan {plan:?}, provider {provider:?}, model {model}"
                ));
                let response = error.into_response();
                assert_eq!(response.status(), StatusCode::FORBIDDEN);
            }
        }
    }

    /// The configured closed-beta model is available only to flagged users, and
    /// only via Anthropic.
    #[gpui::test]
    async fn test_authorize_access_to_closed_beta_model() {
        let mut config = Config::test();
        config.llm_closed_beta_model_name = Some("closed-beta-model".into());

        let test_cases = vec![
            // Flagged users may use the closed-beta model
            (
                true,
                LanguageModelProvider::Anthropic,
                "closed-beta-model",
                true,
            ),
            // Unflagged users may not
            (
                false,
                LanguageModelProvider::Anthropic,
                "closed-beta-model",
                false,
            ),
            // The flag does not unlock unrelated models
            (true, LanguageModelProvider::Anthropic, "claude-3-opus", false),
            // The closed-beta exemption only applies to Anthropic
            (
                true,
                LanguageModelProvider::OpenAi,
                "closed-beta-model",
                false,
            ),
        ];

        for (has_flag, provider, model, expected_access) in test_cases {
            let claims = LlmTokenClaims {
                has_llm_closed_beta_feature_flag: has_flag,
                ..Default::default()
            };

            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            if expected_access {
                assert!(
                    result.is_ok(),
                    "Expected access to be granted for flag {has_flag}, provider {provider:?}, model {model}"
                );
            } else {
                let error = result.expect_err(&format!(
                    "Expected access to be denied for flag {has_flag}, provider {provider:?}, model {model}"
                ));
                assert_eq!(error.into_response().status(), StatusCode::FORBIDDEN);
            }
        }
    }

    /// Staff bypass model entitlement checks entirely.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_for_staff() {
        let config = Config::test();

        let claims = LlmTokenClaims {
            is_staff: true,
            ..Default::default()
        };

        // Staff should have access to all models
        let test_cases = vec![
            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
            (LanguageModelProvider::Google, "gemini-pro"),
        ];

        for (provider, model) in test_cases {
            let result =
                authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);

            assert!(
                result.is_ok(),
                "Expected staff to have access to provider {provider:?}, model {model}"
            );
        }
    }
}