1use reqwest::StatusCode;
2use rpc::LanguageModelProvider;
3
4use crate::llm::LlmTokenClaims;
5use crate::{Config, Error, Result};
6
7pub fn authorize_access_to_language_model(
8 config: &Config,
9 _claims: &LlmTokenClaims,
10 country_code: Option<String>,
11 provider: LanguageModelProvider,
12 model: &str,
13) -> Result<()> {
14 authorize_access_for_country(config, country_code, provider, model)?;
15
16 Ok(())
17}
18
19fn authorize_access_for_country(
20 config: &Config,
21 country_code: Option<String>,
22 provider: LanguageModelProvider,
23 _model: &str,
24) -> Result<()> {
25 // In development we won't have the `CF-IPCountry` header, so we can't check
26 // the country code.
27 //
28 // This shouldn't be necessary, as anyone running in development will need to provide
29 // their own API credentials in order to use an LLM provider.
30 if config.is_development() {
31 return Ok(());
32 }
33
34 // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
35 let country_code = match country_code.as_deref() {
36 // `XX` - Used for clients without country code data.
37 None | Some("XX") => Err(Error::http(
38 StatusCode::BAD_REQUEST,
39 "no country code".to_string(),
40 ))?,
41 // `T1` - Used for clients using the Tor network.
42 Some("T1") => Err(Error::http(
43 StatusCode::FORBIDDEN,
44 format!("access to {provider:?} models is not available over Tor"),
45 ))?,
46 Some(country_code) => country_code,
47 };
48
49 let is_country_supported_by_provider = match provider {
50 LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
51 LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
52 LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
53 LanguageModelProvider::Zed => true,
54 };
55 if !is_country_supported_by_provider {
56 Err(Error::http(
57 StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
58 format!("access to {provider:?} models is not available in your region"),
59 ))?
60 }
61
62 Ok(())
63}
64
#[cfg(test)]
mod tests {
    use axum::response::IntoResponse;
    use pretty_assertions::assert_eq;
    use rpc::proto::Plan;

    use super::*;

    /// Claims for a Zed Pro user, shared by every test in this module.
    fn pro_user_claims() -> LlmTokenClaims {
        LlmTokenClaims {
            user_id: 99,
            plan: Plan::ZedPro,
            ..Default::default()
        }
    }

    /// Runs the authorization check and asserts that it fails with `expected_status`
    /// and a response body equal to `expected_body`.
    async fn assert_authorization_error(
        config: &Config,
        claims: &LlmTokenClaims,
        provider: LanguageModelProvider,
        country_code: Option<&str>,
        expected_status: StatusCode,
        expected_body: String,
    ) {
        let result = authorize_access_to_language_model(
            config,
            claims,
            country_code.map(str::to_string),
            provider,
            "the-model",
        );
        // Match instead of `expect_err(&format!(..))` so the message is only built on failure.
        let error = match result {
            Ok(()) => panic!(
                "expected authorization to return an error for {provider:?}: {country_code:?}"
            ),
            Err(error) => error,
        };

        let error_response = error.into_response();
        assert_eq!(error_response.status(), expected_status);
        let response_body = hyper::body::to_bytes(error_response.into_body())
            .await
            .unwrap()
            .to_vec();
        assert_eq!(String::from_utf8(response_body).unwrap(), expected_body);
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_supported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();
        let claims = pro_user_claims();

        let cases = vec![
            (LanguageModelProvider::Anthropic, "US"), // United States
            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
            (LanguageModelProvider::OpenAi, "US"),    // United States
            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
            (LanguageModelProvider::Google, "US"),    // United States
            (LanguageModelProvider::Google, "GB"),    // United Kingdom
        ];

        for (provider, country_code) in cases {
            authorize_access_to_language_model(
                &config,
                &claims,
                Some(country_code.into()),
                provider,
                "the-model",
            )
            .unwrap_or_else(|_| {
                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
            })
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_unsupported_country(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();
        let claims = pro_user_claims();

        let cases = vec![
            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
            (LanguageModelProvider::Anthropic, "BY"), // Belarus
            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
            (LanguageModelProvider::Anthropic, "CN"), // China
            (LanguageModelProvider::Anthropic, "CU"), // Cuba
            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
            (LanguageModelProvider::Anthropic, "IR"), // Iran
            (LanguageModelProvider::Anthropic, "KP"), // North Korea
            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
            (LanguageModelProvider::Anthropic, "LY"), // Libya
            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
            (LanguageModelProvider::Anthropic, "RU"), // Russia
            (LanguageModelProvider::Anthropic, "SO"), // Somalia
            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
            (LanguageModelProvider::Anthropic, "SD"), // Sudan
            (LanguageModelProvider::Anthropic, "SY"), // Syria
            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
            (LanguageModelProvider::Anthropic, "YE"), // Yemen
            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
            (LanguageModelProvider::Google, "KP"),    // North Korea
        ];

        for (provider, country_code) in cases {
            assert_authorization_error(
                &config,
                &claims,
                provider,
                Some(country_code),
                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
                format!("access to {provider:?} models is not available in your region"),
            )
            .await;
        }
    }

    #[gpui::test]
    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
        let config = Config::test();
        let claims = pro_user_claims();

        let cases = vec![
            (LanguageModelProvider::Anthropic, "T1"), // Tor
            (LanguageModelProvider::OpenAi, "T1"),    // Tor
            (LanguageModelProvider::Google, "T1"),    // Tor
            (LanguageModelProvider::Zed, "T1"),       // Tor
        ];

        for (provider, country_code) in cases {
            assert_authorization_error(
                &config,
                &claims,
                provider,
                Some(country_code),
                StatusCode::FORBIDDEN,
                format!("access to {provider:?} models is not available over Tor"),
            )
            .await;
        }
    }

    /// A missing header and Cloudflare's `XX` ("no country data") are both rejected
    /// with `400 Bad Request`, regardless of provider.
    #[gpui::test]
    async fn test_authorize_access_to_language_model_without_country_code(
        _cx: &mut gpui::TestAppContext,
    ) {
        let config = Config::test();
        let claims = pro_user_claims();

        let providers = [
            LanguageModelProvider::Anthropic,
            LanguageModelProvider::OpenAi,
            LanguageModelProvider::Google,
            LanguageModelProvider::Zed,
        ];

        for provider in providers {
            for country_code in [None, Some("XX")] {
                assert_authorization_error(
                    &config,
                    &claims,
                    provider,
                    country_code,
                    StatusCode::BAD_REQUEST,
                    "no country code".to_string(),
                )
                .await;
            }
        }
    }
}