Cargo.lock 🔗
@@ -2464,6 +2464,7 @@ dependencies = [
"headless",
"hex",
"http_client",
+ "hyper",
"indoc",
"jsonwebtoken",
"language",
Marshall Bowers created
This PR updates the LLM service to authorize access to language model
providers based on the requester's country.
We detect the country using Cloudflare's
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.
The country code is then checked against the list of supported countries
for the given LLM provider. Countries that are not supported will
receive an `HTTP 451: Unavailable For Legal Reasons` response.
Release Notes:
- N/A
Cargo.lock | 1
Cargo.toml | 1
crates/anthropic/src/anthropic.rs | 4
crates/anthropic/src/supported_countries.rs | 194 +++++++++++++++++++
crates/collab/Cargo.toml | 1
crates/collab/src/lib.rs | 40 +++
crates/collab/src/llm.rs | 15 +
crates/collab/src/llm/authorization.rs | 213 +++++++++++++++++++++
crates/google_ai/src/google_ai.rs | 4
crates/google_ai/src/supported_countries.rs | 232 +++++++++++++++++++++++
crates/open_ai/src/open_ai.rs | 4
crates/open_ai/src/supported_countries.rs | 207 ++++++++++++++++++++
typos.toml | 6
13 files changed, 921 insertions(+), 1 deletion(-)
@@ -2464,6 +2464,7 @@ dependencies = [
"headless",
"hex",
"http_client",
+ "hyper",
"indoc",
"jsonwebtoken",
"language",
@@ -340,6 +340,7 @@ git2 = { version = "0.19", default-features = false }
globset = "0.4"
heed = { version = "0.20.1", features = ["read-txn-no-tls"] }
hex = "0.4.3"
+hyper = "0.14"
html5ever = "0.27.0"
ignore = "0.4.22"
image = "0.25.1"
@@ -1,3 +1,5 @@
+mod supported_countries;
+
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -6,6 +8,8 @@ use serde::{Deserialize, Serialize};
use std::time::Duration;
use strum::EnumIter;
+pub use supported_countries::*;
+
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
@@ -0,0 +1,194 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by Anthropic.
+///
+/// https://www.anthropic.com/supported-countries
+pub fn is_supported_country(country_code: &str) -> bool {
+ SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by Anthropic.
+///
+/// https://www.anthropic.com/supported-countries
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+ vec![
+ "AL", // Albania
+ "DZ", // Algeria
+ "AD", // Andorra
+ "AO", // Angola
+ "AG", // Antigua and Barbuda
+ "AR", // Argentina
+ "AM", // Armenia
+ "AU", // Australia
+ "AT", // Austria
+ "AZ", // Azerbaijan
+ "BS", // Bahamas
+ "BH", // Bahrain
+ "BD", // Bangladesh
+ "BB", // Barbados
+ "BE", // Belgium
+ "BZ", // Belize
+ "BJ", // Benin
+ "BT", // Bhutan
+ "BO", // Bolivia
+ "BA", // Bosnia and Herzegovina
+ "BW", // Botswana
+ "BR", // Brazil
+ "BN", // Brunei
+ "BG", // Bulgaria
+ "BF", // Burkina Faso
+ "BI", // Burundi
+ "CV", // Cabo Verde
+ "KH", // Cambodia
+ "CM", // Cameroon
+ "CA", // Canada
+ "TD", // Chad
+ "CL", // Chile
+ "CO", // Colombia
+ "KM", // Comoros
+ "CG", // Congo (Brazzaville)
+ "CR", // Costa Rica
+ "CI", // Côte d'Ivoire
+ "HR", // Croatia
+ "CY", // Cyprus
+ "CZ", // Czechia (Czech Republic)
+ "DK", // Denmark
+ "DJ", // Djibouti
+ "DM", // Dominica
+ "DO", // Dominican Republic
+ "EC", // Ecuador
+ "EG", // Egypt
+ "SV", // El Salvador
+ "GQ", // Equatorial Guinea
+ "EE", // Estonia
+ "SZ", // Eswatini
+ "FJ", // Fiji
+ "FI", // Finland
+ "FR", // France
+ "GA", // Gabon
+ "GM", // Gambia
+ "GE", // Georgia
+ "DE", // Germany
+ "GH", // Ghana
+ "GR", // Greece
+ "GD", // Grenada
+ "GT", // Guatemala
+ "GN", // Guinea
+ "GW", // Guinea-Bissau
+ "GY", // Guyana
+ "HT", // Haiti
+ "HN", // Honduras
+ "HU", // Hungary
+ "IS", // Iceland
+ "IN", // India
+ "ID", // Indonesia
+ "IQ", // Iraq
+ "IE", // Ireland
+ "IL", // Israel
+ "IT", // Italy
+ "JM", // Jamaica
+ "JP", // Japan
+ "JO", // Jordan
+ "KZ", // Kazakhstan
+ "KE", // Kenya
+ "KI", // Kiribati
+ "KW", // Kuwait
+ "KG", // Kyrgyzstan
+ "LA", // Laos
+ "LV", // Latvia
+ "LB", // Lebanon
+ "LS", // Lesotho
+ "LR", // Liberia
+ "LI", // Liechtenstein
+ "LT", // Lithuania
+ "LU", // Luxembourg
+ "MG", // Madagascar
+ "MW", // Malawi
+ "MY", // Malaysia
+ "MV", // Maldives
+ "MT", // Malta
+ "MH", // Marshall Islands
+ "MR", // Mauritania
+ "MU", // Mauritius
+ "MX", // Mexico
+ "FM", // Micronesia
+ "MD", // Moldova
+ "MC", // Monaco
+ "MN", // Mongolia
+ "ME", // Montenegro
+ "MA", // Morocco
+ "MZ", // Mozambique
+ "NA", // Namibia
+ "NR", // Nauru
+ "NP", // Nepal
+ "NL", // Netherlands
+ "NZ", // New Zealand
+ "NE", // Niger
+ "NG", // Nigeria
+ "MK", // North Macedonia
+ "NO", // Norway
+ "OM", // Oman
+ "PK", // Pakistan
+ "PW", // Palau
+ "PS", // Palestine
+ "PA", // Panama
+ "PG", // Papua New Guinea
+ "PY", // Paraguay
+ "PE", // Peru
+ "PH", // Philippines
+ "PL", // Poland
+ "PT", // Portugal
+ "QA", // Qatar
+ "RO", // Romania
+ "RW", // Rwanda
+ "KN", // Saint Kitts and Nevis
+ "LC", // Saint Lucia
+ "VC", // Saint Vincent and the Grenadines
+ "WS", // Samoa
+ "SM", // San Marino
+ "ST", // São Tomé and Príncipe
+ "SA", // Saudi Arabia
+ "SN", // Senegal
+ "RS", // Serbia
+ "SC", // Seychelles
+ "SL", // Sierra Leone
+ "SG", // Singapore
+ "SK", // Slovakia
+ "SI", // Slovenia
+ "SB", // Solomon Islands
+ "ZA", // South Africa
+ "KR", // South Korea
+ "ES", // Spain
+ "LK", // Sri Lanka
+ "SR", // Suriname
+ "SE", // Sweden
+ "CH", // Switzerland
+ "TW", // Taiwan
+ "TJ", // Tajikistan
+ "TZ", // Tanzania
+ "TH", // Thailand
+ "TL", // Timor-Leste
+ "TG", // Togo
+ "TO", // Tonga
+ "TT", // Trinidad and Tobago
+ "TN", // Tunisia
+ "TR", // Türkiye (Turkey)
+ "TM", // Turkmenistan
+ "TV", // Tuvalu
+ "UG", // Uganda
+ "UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
+ "AE", // United Arab Emirates
+ "GB", // United Kingdom
+ "US", // United States of America
+ "UY", // Uruguay
+ "UZ", // Uzbekistan
+ "VU", // Vanuatu
+ "VA", // Vatican City
+ "VN", // Vietnam
+ "ZM", // Zambia
+ "ZW", // Zimbabwe
+ ]
+ .into_iter()
+ .collect()
+});
@@ -90,6 +90,7 @@ fs = { workspace = true, features = ["test-support"] }
git = { workspace = true, features = ["test-support"] }
git_hosting_providers.workspace = true
gpui = { workspace = true, features = ["test-support"] }
+hyper.workspace = true
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
@@ -185,6 +185,46 @@ impl Config {
_ => "https://zed.dev",
}
}
+
+ #[cfg(test)]
+ pub fn test() -> Self {
+ Self {
+ http_port: 0,
+ database_url: "".into(),
+ database_max_connections: 0,
+ api_token: "".into(),
+ invite_link_prefix: "".into(),
+ live_kit_server: None,
+ live_kit_key: None,
+ live_kit_secret: None,
+ llm_api_secret: None,
+ rust_log: None,
+ log_json: None,
+ zed_environment: "test".into(),
+ blob_store_url: None,
+ blob_store_region: None,
+ blob_store_access_key: None,
+ blob_store_secret_key: None,
+ blob_store_bucket: None,
+ openai_api_key: None,
+ google_ai_api_key: None,
+ anthropic_api_key: None,
+ clickhouse_url: None,
+ clickhouse_user: None,
+ clickhouse_password: None,
+ clickhouse_database: None,
+ zed_client_checksum_seed: None,
+ slack_panics_webhook: None,
+ auto_join_channel_id: None,
+ migrations_path: None,
+ seed_path: None,
+ stripe_api_key: None,
+ stripe_price_id: None,
+ supermaven_admin_api_key: None,
+ qwen2_7b_api_key: None,
+ qwen2_7b_api_url: None,
+ }
+ }
}
/// The service mode that collab should run in.
@@ -1,7 +1,11 @@
+mod authorization;
mod token;
+use crate::api::CloudflareIpCountryHeader;
+use crate::llm::authorization::authorize_access_to_language_model;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::Context as _;
+use axum::TypedHeader;
use axum::{
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
@@ -91,9 +95,18 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>,
- Extension(_claims): Extension<LlmTokenClaims>,
+ Extension(claims): Extension<LlmTokenClaims>,
+ country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
+ authorize_access_to_language_model(
+ &state.config,
+ &claims,
+ country_code_header.map(|header| header.to_string()),
+ params.provider,
+ ¶ms.model,
+ )?;
+
match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = state
@@ -0,0 +1,213 @@
+use reqwest::StatusCode;
+use rpc::LanguageModelProvider;
+
+use crate::llm::LlmTokenClaims;
+use crate::{Config, Error, Result};
+
+pub fn authorize_access_to_language_model(
+ config: &Config,
+ _claims: &LlmTokenClaims,
+ country_code: Option<String>,
+ provider: LanguageModelProvider,
+ model: &str,
+) -> Result<()> {
+ authorize_access_for_country(config, country_code, provider, model)?;
+
+ Ok(())
+}
+
+fn authorize_access_for_country(
+ config: &Config,
+ country_code: Option<String>,
+ provider: LanguageModelProvider,
+ _model: &str,
+) -> Result<()> {
+ // In development we won't have the `CF-IPCountry` header, so we can't check
+ // the country code.
+ //
+ // This shouldn't be necessary, as anyone running in development will need to provide
+ // their own API credentials in order to use an LLM provider.
+ if config.is_development() {
+ return Ok(());
+ }
+
+ // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
+ let country_code = match country_code.as_deref() {
+ // `XX` - Used for clients without country code data.
+ None | Some("XX") => Err(Error::http(
+ StatusCode::BAD_REQUEST,
+ "no country code".to_string(),
+ ))?,
+ // `T1` - Used for clients using the Tor network.
+ Some("T1") => Err(Error::http(
+ StatusCode::FORBIDDEN,
+ format!("access to {provider:?} models is not available over Tor"),
+ ))?,
+ Some(country_code) => country_code,
+ };
+
+ let is_country_supported_by_provider = match provider {
+ LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
+ LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
+ LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
+ LanguageModelProvider::Zed => true,
+ };
+ if !is_country_supported_by_provider {
+ Err(Error::http(
+ StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
+ format!("access to {provider:?} models is not available in your region"),
+ ))?
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use axum::response::IntoResponse;
+ use pretty_assertions::assert_eq;
+ use rpc::proto::Plan;
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_authorize_access_to_language_model_with_supported_country(
+ _cx: &mut gpui::TestAppContext,
+ ) {
+ let config = Config::test();
+
+ let claims = LlmTokenClaims {
+ user_id: 99,
+ plan: Plan::ZedPro,
+ ..Default::default()
+ };
+
+ let cases = vec![
+ (LanguageModelProvider::Anthropic, "US"), // United States
+ (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
+ (LanguageModelProvider::OpenAi, "US"), // United States
+ (LanguageModelProvider::OpenAi, "GB"), // United Kingdom
+ (LanguageModelProvider::Google, "US"), // United States
+ (LanguageModelProvider::Google, "GB"), // United Kingdom
+ ];
+
+ for (provider, country_code) in cases {
+ authorize_access_to_language_model(
+ &config,
+ &claims,
+ Some(country_code.into()),
+ provider,
+ "the-model",
+ )
+ .unwrap_or_else(|_| {
+ panic!("expected authorization to return Ok for {provider:?}: {country_code}")
+ })
+ }
+ }
+
+ #[gpui::test]
+ async fn test_authorize_access_to_language_model_with_unsupported_country(
+ _cx: &mut gpui::TestAppContext,
+ ) {
+ let config = Config::test();
+
+ let claims = LlmTokenClaims {
+ user_id: 99,
+ plan: Plan::ZedPro,
+ ..Default::default()
+ };
+
+ let cases = vec![
+ (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
+ (LanguageModelProvider::Anthropic, "BY"), // Belarus
+ (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
+ (LanguageModelProvider::Anthropic, "CN"), // China
+ (LanguageModelProvider::Anthropic, "CU"), // Cuba
+ (LanguageModelProvider::Anthropic, "ER"), // Eritrea
+ (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
+ (LanguageModelProvider::Anthropic, "IR"), // Iran
+ (LanguageModelProvider::Anthropic, "KP"), // North Korea
+ (LanguageModelProvider::Anthropic, "XK"), // Kosovo
+ (LanguageModelProvider::Anthropic, "LY"), // Libya
+ (LanguageModelProvider::Anthropic, "MM"), // Myanmar
+ (LanguageModelProvider::Anthropic, "RU"), // Russia
+ (LanguageModelProvider::Anthropic, "SO"), // Somalia
+ (LanguageModelProvider::Anthropic, "SS"), // South Sudan
+ (LanguageModelProvider::Anthropic, "SD"), // Sudan
+ (LanguageModelProvider::Anthropic, "SY"), // Syria
+ (LanguageModelProvider::Anthropic, "VE"), // Venezuela
+ (LanguageModelProvider::Anthropic, "YE"), // Yemen
+ (LanguageModelProvider::OpenAi, "KP"), // North Korea
+ (LanguageModelProvider::Google, "KP"), // North Korea
+ ];
+
+ for (provider, country_code) in cases {
+ let error_response = authorize_access_to_language_model(
+ &config,
+ &claims,
+ Some(country_code.into()),
+ provider,
+ "the-model",
+ )
+ .expect_err(&format!(
+ "expected authorization to return an error for {provider:?}: {country_code}"
+ ))
+ .into_response();
+
+ assert_eq!(
+ error_response.status(),
+ StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
+ );
+ let response_body = hyper::body::to_bytes(error_response.into_body())
+ .await
+ .unwrap()
+ .to_vec();
+ assert_eq!(
+ String::from_utf8(response_body).unwrap(),
+ format!("access to {provider:?} models is not available in your region")
+ );
+ }
+ }
+
+ #[gpui::test]
+ async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
+ let config = Config::test();
+
+ let claims = LlmTokenClaims {
+ user_id: 99,
+ plan: Plan::ZedPro,
+ ..Default::default()
+ };
+
+ let cases = vec![
+ (LanguageModelProvider::Anthropic, "T1"), // Tor
+ (LanguageModelProvider::OpenAi, "T1"), // Tor
+ (LanguageModelProvider::Google, "T1"), // Tor
+ (LanguageModelProvider::Zed, "T1"), // Tor
+ ];
+
+ for (provider, country_code) in cases {
+ let error_response = authorize_access_to_language_model(
+ &config,
+ &claims,
+ Some(country_code.into()),
+ provider,
+ "the-model",
+ )
+ .expect_err(&format!(
+ "expected authorization to return an error for {provider:?}: {country_code}"
+ ))
+ .into_response();
+
+ assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
+ let response_body = hyper::body::to_bytes(error_response.into_body())
+ .await
+ .unwrap()
+ .to_vec();
+ assert_eq!(
+ String::from_utf8(response_body).unwrap(),
+ format!("access to {provider:?} models is not available over Tor")
+ );
+ }
+ }
+}
@@ -1,8 +1,12 @@
+mod supported_countries;
+
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::HttpClient;
use serde::{Deserialize, Serialize};
+pub use supported_countries::*;
+
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content(
@@ -0,0 +1,232 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by Google Gemini.
+///
+/// https://ai.google.dev/gemini-api/docs/available-regions
+pub fn is_supported_country(country_code: &str) -> bool {
+ SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by Google Gemini.
+///
+/// https://ai.google.dev/gemini-api/docs/available-regions
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+ vec![
+ "DZ", // Algeria
+ "AS", // American Samoa
+ "AO", // Angola
+ "AI", // Anguilla
+ "AQ", // Antarctica
+ "AG", // Antigua and Barbuda
+ "AR", // Argentina
+ "AM", // Armenia
+ "AW", // Aruba
+ "AU", // Australia
+ "AT", // Austria
+ "AZ", // Azerbaijan
+ "BS", // The Bahamas
+ "BH", // Bahrain
+ "BD", // Bangladesh
+ "BB", // Barbados
+ "BE", // Belgium
+ "BZ", // Belize
+ "BJ", // Benin
+ "BM", // Bermuda
+ "BT", // Bhutan
+ "BO", // Bolivia
+ "BW", // Botswana
+ "BR", // Brazil
+ "IO", // British Indian Ocean Territory
+ "VG", // British Virgin Islands
+ "BN", // Brunei
+ "BG", // Bulgaria
+ "BF", // Burkina Faso
+ "BI", // Burundi
+ "CV", // Cabo Verde
+ "KH", // Cambodia
+ "CM", // Cameroon
+ "CA", // Canada
+ "BQ", // Caribbean Netherlands
+ "KY", // Cayman Islands
+ "CF", // Central African Republic
+ "TD", // Chad
+ "CL", // Chile
+ "CX", // Christmas Island
+ "CC", // Cocos (Keeling) Islands
+ "CO", // Colombia
+ "KM", // Comoros
+ "CK", // Cook Islands
+ "CI", // Côte d'Ivoire
+ "CR", // Costa Rica
+ "HR", // Croatia
+ "CW", // Curaçao
+ "CZ", // Czech Republic
+ "CD", // Democratic Republic of the Congo
+ "DK", // Denmark
+ "DJ", // Djibouti
+ "DM", // Dominica
+ "DO", // Dominican Republic
+ "EC", // Ecuador
+ "EG", // Egypt
+ "SV", // El Salvador
+ "GQ", // Equatorial Guinea
+ "ER", // Eritrea
+ "EE", // Estonia
+ "SZ", // Eswatini
+ "ET", // Ethiopia
+ "FK", // Falkland Islands (Islas Malvinas)
+ "FJ", // Fiji
+ "FI", // Finland
+ "FR", // France
+ "GA", // Gabon
+ "GM", // The Gambia
+ "GE", // Georgia
+ "DE", // Germany
+ "GH", // Ghana
+ "GI", // Gibraltar
+ "GR", // Greece
+ "GD", // Grenada
+ "GU", // Guam
+ "GT", // Guatemala
+ "GG", // Guernsey
+ "GN", // Guinea
+ "GW", // Guinea-Bissau
+ "GY", // Guyana
+ "HT", // Haiti
+ "HM", // Heard Island and McDonald Islands
+ "HN", // Honduras
+ "HU", // Hungary
+ "IS", // Iceland
+ "IN", // India
+ "ID", // Indonesia
+ "IQ", // Iraq
+ "IE", // Ireland
+ "IM", // Isle of Man
+ "IL", // Israel
+ "IT", // Italy
+ "JM", // Jamaica
+ "JP", // Japan
+ "JE", // Jersey
+ "JO", // Jordan
+ "KZ", // Kazakhstan
+ "KE", // Kenya
+ "KI", // Kiribati
+ "KG", // Kyrgyzstan
+ "KW", // Kuwait
+ "LA", // Laos
+ "LV", // Latvia
+ "LB", // Lebanon
+ "LS", // Lesotho
+ "LR", // Liberia
+ "LY", // Libya
+ "LI", // Liechtenstein
+ "LT", // Lithuania
+ "LU", // Luxembourg
+ "MG", // Madagascar
+ "MW", // Malawi
+ "MY", // Malaysia
+ "MV", // Maldives
+ "ML", // Mali
+ "MT", // Malta
+ "MH", // Marshall Islands
+ "MR", // Mauritania
+ "MU", // Mauritius
+ "MX", // Mexico
+ "FM", // Micronesia
+ "MN", // Mongolia
+ "MS", // Montserrat
+ "MA", // Morocco
+ "MZ", // Mozambique
+ "NA", // Namibia
+ "NR", // Nauru
+ "NP", // Nepal
+ "NL", // Netherlands
+ "NC", // New Caledonia
+ "NZ", // New Zealand
+ "NI", // Nicaragua
+ "NE", // Niger
+ "NG", // Nigeria
+ "NU", // Niue
+ "NF", // Norfolk Island
+ "MP", // Northern Mariana Islands
+ "NO", // Norway
+ "OM", // Oman
+ "PK", // Pakistan
+ "PW", // Palau
+ "PS", // Palestine
+ "PA", // Panama
+ "PG", // Papua New Guinea
+ "PY", // Paraguay
+ "PE", // Peru
+ "PH", // Philippines
+ "PN", // Pitcairn Islands
+ "PL", // Poland
+ "PT", // Portugal
+ "PR", // Puerto Rico
+ "QA", // Qatar
+ "CY", // Republic of Cyprus
+ "CG", // Republic of the Congo
+ "RO", // Romania
+ "RW", // Rwanda
+ "BL", // Saint Barthélemy
+ "KN", // Saint Kitts and Nevis
+ "LC", // Saint Lucia
+ "PM", // Saint Pierre and Miquelon
+ "VC", // Saint Vincent and the Grenadines
+ "SH", // Saint Helena, Ascension and Tristan da Cunha
+ "WS", // Samoa
+ "ST", // São Tomé and Príncipe
+ "SA", // Saudi Arabia
+ "SN", // Senegal
+ "SC", // Seychelles
+ "SL", // Sierra Leone
+ "SG", // Singapore
+ "SK", // Slovakia
+ "SI", // Slovenia
+ "SB", // Solomon Islands
+ "SO", // Somalia
+ "ZA", // South Africa
+ "GS", // South Georgia and the South Sandwich Islands
+ "KR", // South Korea
+ "SS", // South Sudan
+ "ES", // Spain
+ "LK", // Sri Lanka
+ "SD", // Sudan
+ "SR", // Suriname
+ "SE", // Sweden
+ "CH", // Switzerland
+ "TW", // Taiwan
+ "TJ", // Tajikistan
+ "TZ", // Tanzania
+ "TH", // Thailand
+ "TL", // Timor-Leste
+ "TG", // Togo
+ "TK", // Tokelau
+ "TO", // Tonga
+ "TT", // Trinidad and Tobago
+ "TN", // Tunisia
+ "TR", // Türkiye
+ "TM", // Turkmenistan
+ "TC", // Turks and Caicos Islands
+ "TV", // Tuvalu
+ "UG", // Uganda
+ "GB", // United Kingdom
+ "AE", // United Arab Emirates
+ "US", // United States
+ "UM", // United States Minor Outlying Islands
+ "VI", // U.S. Virgin Islands
+ "UY", // Uruguay
+ "UZ", // Uzbekistan
+ "VU", // Vanuatu
+ "VE", // Venezuela
+ "VN", // Vietnam
+ "WF", // Wallis and Futuna
+ "EH", // Western Sahara
+ "YE", // Yemen
+ "ZM", // Zambia
+ "ZW", // Zimbabwe
+ ]
+ .into_iter()
+ .collect()
+});
@@ -1,3 +1,5 @@
+mod supported_countries;
+
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -7,6 +9,8 @@ use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;
+pub use supported_countries::*;
+
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
@@ -0,0 +1,207 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by OpenAI.
+///
+/// https://platform.openai.com/docs/supported-countries
+pub fn is_supported_country(country_code: &str) -> bool {
+ SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by OpenAI.
+///
+/// https://platform.openai.com/docs/supported-countries
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+ vec![
+ "AL", // Albania
+ "DZ", // Algeria
+ "AF", // Afghanistan
+ "AD", // Andorra
+ "AO", // Angola
+ "AG", // Antigua and Barbuda
+ "AR", // Argentina
+ "AM", // Armenia
+ "AU", // Australia
+ "AT", // Austria
+ "AZ", // Azerbaijan
+ "BS", // Bahamas
+ "BH", // Bahrain
+ "BD", // Bangladesh
+ "BB", // Barbados
+ "BE", // Belgium
+ "BZ", // Belize
+ "BJ", // Benin
+ "BT", // Bhutan
+ "BO", // Bolivia
+ "BA", // Bosnia and Herzegovina
+ "BW", // Botswana
+ "BR", // Brazil
+ "BN", // Brunei
+ "BG", // Bulgaria
+ "BF", // Burkina Faso
+ "BI", // Burundi
+ "CV", // Cabo Verde
+ "KH", // Cambodia
+ "CM", // Cameroon
+ "CA", // Canada
+ "CF", // Central African Republic
+ "TD", // Chad
+ "CL", // Chile
+ "CO", // Colombia
+ "KM", // Comoros
+ "CG", // Congo (Brazzaville)
+ "CD", // Congo (DRC)
+ "CR", // Costa Rica
+ "CI", // Côte d'Ivoire
+ "HR", // Croatia
+ "CY", // Cyprus
+ "CZ", // Czechia (Czech Republic)
+ "DK", // Denmark
+ "DJ", // Djibouti
+ "DM", // Dominica
+ "DO", // Dominican Republic
+ "EC", // Ecuador
+ "EG", // Egypt
+ "SV", // El Salvador
+ "GQ", // Equatorial Guinea
+ "ER", // Eritrea
+ "EE", // Estonia
+ "SZ", // Eswatini (Swaziland)
+ "ET", // Ethiopia
+ "FJ", // Fiji
+ "FI", // Finland
+ "FR", // France
+ "GA", // Gabon
+ "GM", // Gambia
+ "GE", // Georgia
+ "DE", // Germany
+ "GH", // Ghana
+ "GR", // Greece
+ "GD", // Grenada
+ "GT", // Guatemala
+ "GN", // Guinea
+ "GW", // Guinea-Bissau
+ "GY", // Guyana
+ "HT", // Haiti
+ "VA", // Holy See (Vatican City)
+ "HN", // Honduras
+ "HU", // Hungary
+ "IS", // Iceland
+ "IN", // India
+ "ID", // Indonesia
+ "IQ", // Iraq
+ "IE", // Ireland
+ "IL", // Israel
+ "IT", // Italy
+ "JM", // Jamaica
+ "JP", // Japan
+ "JO", // Jordan
+ "KZ", // Kazakhstan
+ "KE", // Kenya
+ "KI", // Kiribati
+ "KW", // Kuwait
+ "KG", // Kyrgyzstan
+ "LA", // Laos
+ "LV", // Latvia
+ "LB", // Lebanon
+ "LS", // Lesotho
+ "LR", // Liberia
+ "LY", // Libya
+ "LI", // Liechtenstein
+ "LT", // Lithuania
+ "LU", // Luxembourg
+ "MG", // Madagascar
+ "MW", // Malawi
+ "MY", // Malaysia
+ "MV", // Maldives
+ "ML", // Mali
+ "MT", // Malta
+ "MH", // Marshall Islands
+ "MR", // Mauritania
+ "MU", // Mauritius
+ "MX", // Mexico
+ "FM", // Micronesia
+ "MD", // Moldova
+ "MC", // Monaco
+ "MN", // Mongolia
+ "ME", // Montenegro
+ "MA", // Morocco
+ "MZ", // Mozambique
+ "MM", // Myanmar
+ "NA", // Namibia
+ "NR", // Nauru
+ "NP", // Nepal
+ "NL", // Netherlands
+ "NZ", // New Zealand
+ "NI", // Nicaragua
+ "NE", // Niger
+ "NG", // Nigeria
+ "MK", // North Macedonia
+ "NO", // Norway
+ "OM", // Oman
+ "PK", // Pakistan
+ "PW", // Palau
+ "PS", // Palestine
+ "PA", // Panama
+ "PG", // Papua New Guinea
+ "PY", // Paraguay
+ "PE", // Peru
+ "PH", // Philippines
+ "PL", // Poland
+ "PT", // Portugal
+ "QA", // Qatar
+ "RO", // Romania
+ "RW", // Rwanda
+ "KN", // Saint Kitts and Nevis
+ "LC", // Saint Lucia
+ "VC", // Saint Vincent and the Grenadines
+ "WS", // Samoa
+ "SM", // San Marino
+ "ST", // Sao Tome and Principe
+ "SA", // Saudi Arabia
+ "SN", // Senegal
+ "RS", // Serbia
+ "SC", // Seychelles
+ "SL", // Sierra Leone
+ "SG", // Singapore
+ "SK", // Slovakia
+ "SI", // Slovenia
+ "SB", // Solomon Islands
+ "SO", // Somalia
+ "ZA", // South Africa
+ "KR", // South Korea
+ "SS", // South Sudan
+ "ES", // Spain
+ "LK", // Sri Lanka
+ "SR", // Suriname
+ "SE", // Sweden
+ "CH", // Switzerland
+ "SD", // Sudan
+ "TW", // Taiwan
+ "TJ", // Tajikistan
+ "TZ", // Tanzania
+ "TH", // Thailand
+ "TL", // Timor-Leste (East Timor)
+ "TG", // Togo
+ "TO", // Tonga
+ "TT", // Trinidad and Tobago
+ "TN", // Tunisia
+ "TR", // Turkey
+ "TM", // Turkmenistan
+ "TV", // Tuvalu
+ "UG", // Uganda
+ "UA", // Ukraine (with certain exceptions)
+ "AE", // United Arab Emirates
+ "GB", // United Kingdom
+ "US", // United States of America
+ "UY", // Uruguay
+ "UZ", // Uzbekistan
+ "VU", // Vanuatu
+ "VN", // Vietnam
+ "YE", // Yemen
+ "ZM", // Zambia
+ "ZW", // Zimbabwe
+ ]
+ .into_iter()
+ .collect()
+});
@@ -6,6 +6,12 @@ extend-exclude = [
# File suffixes aren't typos
"assets/icons/file_icons/file_types.json",
"crates/extensions_ui/src/extension_suggest.rs",
+
+ # Some countries codes are flagged as typos.
+ "crates/anthropic/src/supported_countries.rs",
+ "crates/google_ai/src/supported_countries.rs",
+ "crates/open_ai/src/supported_countries.rs",
+
# Stripe IDs are flagged as typos.
"crates/collab/src/db/tests/processed_stripe_event_tests.rs",
# Not our typos