Authorize access to language model providers based on country (#15859)

Marshall Bowers created

This PR updates the LLM service to authorize access to language model
providers based on the requester's country.

We detect the country using Cloudflare's
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

The country code is then checked against the list of supported countries
for the given LLM provider. Countries that are not supported will
receive an `HTTP 451: Unavailable For Legal Reasons` response.

Release Notes:

- N/A

Change summary

Cargo.lock                                  |   1 
Cargo.toml                                  |   1 
crates/anthropic/src/anthropic.rs           |   4 
crates/anthropic/src/supported_countries.rs | 194 +++++++++++++++++++
crates/collab/Cargo.toml                    |   1 
crates/collab/src/lib.rs                    |  40 +++
crates/collab/src/llm.rs                    |  15 +
crates/collab/src/llm/authorization.rs      | 213 +++++++++++++++++++++
crates/google_ai/src/google_ai.rs           |   4 
crates/google_ai/src/supported_countries.rs | 232 +++++++++++++++++++++++
crates/open_ai/src/open_ai.rs               |   4 
crates/open_ai/src/supported_countries.rs   | 207 ++++++++++++++++++++
typos.toml                                  |   6 
13 files changed, 921 insertions(+), 1 deletion(-)

Detailed changes

Cargo.lock 🔗

@@ -2464,6 +2464,7 @@ dependencies = [
  "headless",
  "hex",
  "http_client",
+ "hyper",
  "indoc",
  "jsonwebtoken",
  "language",

Cargo.toml 🔗

@@ -340,6 +340,7 @@ git2 = { version = "0.19", default-features = false }
 globset = "0.4"
 heed = { version = "0.20.1", features = ["read-txn-no-tls"] }
 hex = "0.4.3"
+hyper = "0.14"
 html5ever = "0.27.0"
 ignore = "0.4.22"
 image = "0.25.1"

crates/anthropic/src/anthropic.rs 🔗

@@ -1,3 +1,5 @@
+mod supported_countries;
+
 use anyhow::{anyhow, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -6,6 +8,8 @@ use serde::{Deserialize, Serialize};
 use std::time::Duration;
 use strum::EnumIter;
 
+pub use supported_countries::*;
+
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]

crates/anthropic/src/supported_countries.rs 🔗

@@ -0,0 +1,194 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by Anthropic.
+///
+/// https://www.anthropic.com/supported-countries
+pub fn is_supported_country(country_code: &str) -> bool {
+    SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by Anthropic.
+///
+/// https://www.anthropic.com/supported-countries
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+    vec![
+        "AL", // Albania
+        "DZ", // Algeria
+        "AD", // Andorra
+        "AO", // Angola
+        "AG", // Antigua and Barbuda
+        "AR", // Argentina
+        "AM", // Armenia
+        "AU", // Australia
+        "AT", // Austria
+        "AZ", // Azerbaijan
+        "BS", // Bahamas
+        "BH", // Bahrain
+        "BD", // Bangladesh
+        "BB", // Barbados
+        "BE", // Belgium
+        "BZ", // Belize
+        "BJ", // Benin
+        "BT", // Bhutan
+        "BO", // Bolivia
+        "BA", // Bosnia and Herzegovina
+        "BW", // Botswana
+        "BR", // Brazil
+        "BN", // Brunei
+        "BG", // Bulgaria
+        "BF", // Burkina Faso
+        "BI", // Burundi
+        "CV", // Cabo Verde
+        "KH", // Cambodia
+        "CM", // Cameroon
+        "CA", // Canada
+        "TD", // Chad
+        "CL", // Chile
+        "CO", // Colombia
+        "KM", // Comoros
+        "CG", // Congo (Brazzaville)
+        "CR", // Costa Rica
+        "CI", // Côte d'Ivoire
+        "HR", // Croatia
+        "CY", // Cyprus
+        "CZ", // Czechia (Czech Republic)
+        "DK", // Denmark
+        "DJ", // Djibouti
+        "DM", // Dominica
+        "DO", // Dominican Republic
+        "EC", // Ecuador
+        "EG", // Egypt
+        "SV", // El Salvador
+        "GQ", // Equatorial Guinea
+        "EE", // Estonia
+        "SZ", // Eswatini
+        "FJ", // Fiji
+        "FI", // Finland
+        "FR", // France
+        "GA", // Gabon
+        "GM", // Gambia
+        "GE", // Georgia
+        "DE", // Germany
+        "GH", // Ghana
+        "GR", // Greece
+        "GD", // Grenada
+        "GT", // Guatemala
+        "GN", // Guinea
+        "GW", // Guinea-Bissau
+        "GY", // Guyana
+        "HT", // Haiti
+        "HN", // Honduras
+        "HU", // Hungary
+        "IS", // Iceland
+        "IN", // India
+        "ID", // Indonesia
+        "IQ", // Iraq
+        "IE", // Ireland
+        "IL", // Israel
+        "IT", // Italy
+        "JM", // Jamaica
+        "JP", // Japan
+        "JO", // Jordan
+        "KZ", // Kazakhstan
+        "KE", // Kenya
+        "KI", // Kiribati
+        "KW", // Kuwait
+        "KG", // Kyrgyzstan
+        "LA", // Laos
+        "LV", // Latvia
+        "LB", // Lebanon
+        "LS", // Lesotho
+        "LR", // Liberia
+        "LI", // Liechtenstein
+        "LT", // Lithuania
+        "LU", // Luxembourg
+        "MG", // Madagascar
+        "MW", // Malawi
+        "MY", // Malaysia
+        "MV", // Maldives
+        "MT", // Malta
+        "MH", // Marshall Islands
+        "MR", // Mauritania
+        "MU", // Mauritius
+        "MX", // Mexico
+        "FM", // Micronesia
+        "MD", // Moldova
+        "MC", // Monaco
+        "MN", // Mongolia
+        "ME", // Montenegro
+        "MA", // Morocco
+        "MZ", // Mozambique
+        "NA", // Namibia
+        "NR", // Nauru
+        "NP", // Nepal
+        "NL", // Netherlands
+        "NZ", // New Zealand
+        "NE", // Niger
+        "NG", // Nigeria
+        "MK", // North Macedonia
+        "NO", // Norway
+        "OM", // Oman
+        "PK", // Pakistan
+        "PW", // Palau
+        "PS", // Palestine
+        "PA", // Panama
+        "PG", // Papua New Guinea
+        "PY", // Paraguay
+        "PE", // Peru
+        "PH", // Philippines
+        "PL", // Poland
+        "PT", // Portugal
+        "QA", // Qatar
+        "RO", // Romania
+        "RW", // Rwanda
+        "KN", // Saint Kitts and Nevis
+        "LC", // Saint Lucia
+        "VC", // Saint Vincent and the Grenadines
+        "WS", // Samoa
+        "SM", // San Marino
+        "ST", // São Tomé and Príncipe
+        "SA", // Saudi Arabia
+        "SN", // Senegal
+        "RS", // Serbia
+        "SC", // Seychelles
+        "SL", // Sierra Leone
+        "SG", // Singapore
+        "SK", // Slovakia
+        "SI", // Slovenia
+        "SB", // Solomon Islands
+        "ZA", // South Africa
+        "KR", // South Korea
+        "ES", // Spain
+        "LK", // Sri Lanka
+        "SR", // Suriname
+        "SE", // Sweden
+        "CH", // Switzerland
+        "TW", // Taiwan
+        "TJ", // Tajikistan
+        "TZ", // Tanzania
+        "TH", // Thailand
+        "TL", // Timor-Leste
+        "TG", // Togo
+        "TO", // Tonga
+        "TT", // Trinidad and Tobago
+        "TN", // Tunisia
+        "TR", // Türkiye (Turkey)
+        "TM", // Turkmenistan
+        "TV", // Tuvalu
+        "UG", // Uganda
+        "UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
+        "AE", // United Arab Emirates
+        "GB", // United Kingdom
+        "US", // United States of America
+        "UY", // Uruguay
+        "UZ", // Uzbekistan
+        "VU", // Vanuatu
+        "VA", // Vatican City
+        "VN", // Vietnam
+        "ZM", // Zambia
+        "ZW", // Zimbabwe
+    ]
+    .into_iter()
+    .collect()
+});

crates/collab/Cargo.toml 🔗

@@ -90,6 +90,7 @@ fs = { workspace = true, features = ["test-support"] }
 git = { workspace = true, features = ["test-support"] }
 git_hosting_providers.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
+hyper.workspace = true
 indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
 language_model = { workspace = true, features = ["test-support"] }

crates/collab/src/lib.rs 🔗

@@ -185,6 +185,46 @@ impl Config {
             _ => "https://zed.dev",
         }
     }
+
+    #[cfg(test)]
+    pub fn test() -> Self {
+        Self {
+            http_port: 0,
+            database_url: "".into(),
+            database_max_connections: 0,
+            api_token: "".into(),
+            invite_link_prefix: "".into(),
+            live_kit_server: None,
+            live_kit_key: None,
+            live_kit_secret: None,
+            llm_api_secret: None,
+            rust_log: None,
+            log_json: None,
+            zed_environment: "test".into(),
+            blob_store_url: None,
+            blob_store_region: None,
+            blob_store_access_key: None,
+            blob_store_secret_key: None,
+            blob_store_bucket: None,
+            openai_api_key: None,
+            google_ai_api_key: None,
+            anthropic_api_key: None,
+            clickhouse_url: None,
+            clickhouse_user: None,
+            clickhouse_password: None,
+            clickhouse_database: None,
+            zed_client_checksum_seed: None,
+            slack_panics_webhook: None,
+            auto_join_channel_id: None,
+            migrations_path: None,
+            seed_path: None,
+            stripe_api_key: None,
+            stripe_price_id: None,
+            supermaven_admin_api_key: None,
+            qwen2_7b_api_key: None,
+            qwen2_7b_api_url: None,
+        }
+    }
 }
 
 /// The service mode that collab should run in.

crates/collab/src/llm.rs 🔗

@@ -1,7 +1,11 @@
+mod authorization;
 mod token;
 
+use crate::api::CloudflareIpCountryHeader;
+use crate::llm::authorization::authorize_access_to_language_model;
 use crate::{executor::Executor, Config, Error, Result};
 use anyhow::Context as _;
+use axum::TypedHeader;
 use axum::{
     body::Body,
     http::{self, HeaderName, HeaderValue, Request, StatusCode},
@@ -91,9 +95,18 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
 
 async fn perform_completion(
     Extension(state): Extension<Arc<LlmState>>,
-    Extension(_claims): Extension<LlmTokenClaims>,
+    Extension(claims): Extension<LlmTokenClaims>,
+    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
+    authorize_access_to_language_model(
+        &state.config,
+        &claims,
+        country_code_header.map(|header| header.to_string()),
+        params.provider,
+        &params.model,
+    )?;
+
     match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state

crates/collab/src/llm/authorization.rs 🔗

@@ -0,0 +1,213 @@
+use reqwest::StatusCode;
+use rpc::LanguageModelProvider;
+
+use crate::llm::LlmTokenClaims;
+use crate::{Config, Error, Result};
+
+pub fn authorize_access_to_language_model(
+    config: &Config,
+    _claims: &LlmTokenClaims,
+    country_code: Option<String>,
+    provider: LanguageModelProvider,
+    model: &str,
+) -> Result<()> {
+    authorize_access_for_country(config, country_code, provider, model)?;
+
+    Ok(())
+}
+
+fn authorize_access_for_country(
+    config: &Config,
+    country_code: Option<String>,
+    provider: LanguageModelProvider,
+    _model: &str,
+) -> Result<()> {
+    // In development we won't have the `CF-IPCountry` header, so we can't check
+    // the country code.
+    //
+    // This shouldn't be necessary, as anyone running in development will need to provide
+    // their own API credentials in order to use an LLM provider.
+    if config.is_development() {
+        return Ok(());
+    }
+
+    // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
+    let country_code = match country_code.as_deref() {
+        // `XX` - Used for clients without country code data.
+        None | Some("XX") => Err(Error::http(
+            StatusCode::BAD_REQUEST,
+            "no country code".to_string(),
+        ))?,
+        // `T1` - Used for clients using the Tor network.
+        Some("T1") => Err(Error::http(
+            StatusCode::FORBIDDEN,
+            format!("access to {provider:?} models is not available over Tor"),
+        ))?,
+        Some(country_code) => country_code,
+    };
+
+    let is_country_supported_by_provider = match provider {
+        LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
+        LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
+        LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
+        LanguageModelProvider::Zed => true,
+    };
+    if !is_country_supported_by_provider {
+        Err(Error::http(
+            StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
+            format!("access to {provider:?} models is not available in your region"),
+        ))?
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use axum::response::IntoResponse;
+    use pretty_assertions::assert_eq;
+    use rpc::proto::Plan;
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_authorize_access_to_language_model_with_supported_country(
+        _cx: &mut gpui::TestAppContext,
+    ) {
+        let config = Config::test();
+
+        let claims = LlmTokenClaims {
+            user_id: 99,
+            plan: Plan::ZedPro,
+            ..Default::default()
+        };
+
+        let cases = vec![
+            (LanguageModelProvider::Anthropic, "US"), // United States
+            (LanguageModelProvider::Anthropic, "GB"), // United Kingdom
+            (LanguageModelProvider::OpenAi, "US"),    // United States
+            (LanguageModelProvider::OpenAi, "GB"),    // United Kingdom
+            (LanguageModelProvider::Google, "US"),    // United States
+            (LanguageModelProvider::Google, "GB"),    // United Kingdom
+        ];
+
+        for (provider, country_code) in cases {
+            authorize_access_to_language_model(
+                &config,
+                &claims,
+                Some(country_code.into()),
+                provider,
+                "the-model",
+            )
+            .unwrap_or_else(|_| {
+                panic!("expected authorization to return Ok for {provider:?}: {country_code}")
+            })
+        }
+    }
+
+    #[gpui::test]
+    async fn test_authorize_access_to_language_model_with_unsupported_country(
+        _cx: &mut gpui::TestAppContext,
+    ) {
+        let config = Config::test();
+
+        let claims = LlmTokenClaims {
+            user_id: 99,
+            plan: Plan::ZedPro,
+            ..Default::default()
+        };
+
+        let cases = vec![
+            (LanguageModelProvider::Anthropic, "AF"), // Afghanistan
+            (LanguageModelProvider::Anthropic, "BY"), // Belarus
+            (LanguageModelProvider::Anthropic, "CF"), // Central African Republic
+            (LanguageModelProvider::Anthropic, "CN"), // China
+            (LanguageModelProvider::Anthropic, "CU"), // Cuba
+            (LanguageModelProvider::Anthropic, "ER"), // Eritrea
+            (LanguageModelProvider::Anthropic, "ET"), // Ethiopia
+            (LanguageModelProvider::Anthropic, "IR"), // Iran
+            (LanguageModelProvider::Anthropic, "KP"), // North Korea
+            (LanguageModelProvider::Anthropic, "XK"), // Kosovo
+            (LanguageModelProvider::Anthropic, "LY"), // Libya
+            (LanguageModelProvider::Anthropic, "MM"), // Myanmar
+            (LanguageModelProvider::Anthropic, "RU"), // Russia
+            (LanguageModelProvider::Anthropic, "SO"), // Somalia
+            (LanguageModelProvider::Anthropic, "SS"), // South Sudan
+            (LanguageModelProvider::Anthropic, "SD"), // Sudan
+            (LanguageModelProvider::Anthropic, "SY"), // Syria
+            (LanguageModelProvider::Anthropic, "VE"), // Venezuela
+            (LanguageModelProvider::Anthropic, "YE"), // Yemen
+            (LanguageModelProvider::OpenAi, "KP"),    // North Korea
+            (LanguageModelProvider::Google, "KP"),    // North Korea
+        ];
+
+        for (provider, country_code) in cases {
+            let error_response = authorize_access_to_language_model(
+                &config,
+                &claims,
+                Some(country_code.into()),
+                provider,
+                "the-model",
+            )
+            .expect_err(&format!(
+                "expected authorization to return an error for {provider:?}: {country_code}"
+            ))
+            .into_response();
+
+            assert_eq!(
+                error_response.status(),
+                StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
+            );
+            let response_body = hyper::body::to_bytes(error_response.into_body())
+                .await
+                .unwrap()
+                .to_vec();
+            assert_eq!(
+                String::from_utf8(response_body).unwrap(),
+                format!("access to {provider:?} models is not available in your region")
+            );
+        }
+    }
+
+    #[gpui::test]
+    async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
+        let config = Config::test();
+
+        let claims = LlmTokenClaims {
+            user_id: 99,
+            plan: Plan::ZedPro,
+            ..Default::default()
+        };
+
+        let cases = vec![
+            (LanguageModelProvider::Anthropic, "T1"), // Tor
+            (LanguageModelProvider::OpenAi, "T1"),    // Tor
+            (LanguageModelProvider::Google, "T1"),    // Tor
+            (LanguageModelProvider::Zed, "T1"),       // Tor
+        ];
+
+        for (provider, country_code) in cases {
+            let error_response = authorize_access_to_language_model(
+                &config,
+                &claims,
+                Some(country_code.into()),
+                provider,
+                "the-model",
+            )
+            .expect_err(&format!(
+                "expected authorization to return an error for {provider:?}: {country_code}"
+            ))
+            .into_response();
+
+            assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
+            let response_body = hyper::body::to_bytes(error_response.into_body())
+                .await
+                .unwrap()
+                .to_vec();
+            assert_eq!(
+                String::from_utf8(response_body).unwrap(),
+                format!("access to {provider:?} models is not available over Tor")
+            );
+        }
+    }
+}

crates/google_ai/src/google_ai.rs 🔗

@@ -1,8 +1,12 @@
+mod supported_countries;
+
 use anyhow::{anyhow, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::HttpClient;
 use serde::{Deserialize, Serialize};
 
+pub use supported_countries::*;
+
 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 
 pub async fn stream_generate_content(

crates/google_ai/src/supported_countries.rs 🔗

@@ -0,0 +1,232 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by Google Gemini.
+///
+/// https://ai.google.dev/gemini-api/docs/available-regions
+pub fn is_supported_country(country_code: &str) -> bool {
+    SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by Google Gemini.
+///
+/// https://ai.google.dev/gemini-api/docs/available-regions
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+    vec![
+        "DZ", // Algeria
+        "AS", // American Samoa
+        "AO", // Angola
+        "AI", // Anguilla
+        "AQ", // Antarctica
+        "AG", // Antigua and Barbuda
+        "AR", // Argentina
+        "AM", // Armenia
+        "AW", // Aruba
+        "AU", // Australia
+        "AT", // Austria
+        "AZ", // Azerbaijan
+        "BS", // The Bahamas
+        "BH", // Bahrain
+        "BD", // Bangladesh
+        "BB", // Barbados
+        "BE", // Belgium
+        "BZ", // Belize
+        "BJ", // Benin
+        "BM", // Bermuda
+        "BT", // Bhutan
+        "BO", // Bolivia
+        "BW", // Botswana
+        "BR", // Brazil
+        "IO", // British Indian Ocean Territory
+        "VG", // British Virgin Islands
+        "BN", // Brunei
+        "BG", // Bulgaria
+        "BF", // Burkina Faso
+        "BI", // Burundi
+        "CV", // Cabo Verde
+        "KH", // Cambodia
+        "CM", // Cameroon
+        "CA", // Canada
+        "BQ", // Caribbean Netherlands
+        "KY", // Cayman Islands
+        "CF", // Central African Republic
+        "TD", // Chad
+        "CL", // Chile
+        "CX", // Christmas Island
+        "CC", // Cocos (Keeling) Islands
+        "CO", // Colombia
+        "KM", // Comoros
+        "CK", // Cook Islands
+        "CI", // Côte d'Ivoire
+        "CR", // Costa Rica
+        "HR", // Croatia
+        "CW", // Curaçao
+        "CZ", // Czech Republic
+        "CD", // Democratic Republic of the Congo
+        "DK", // Denmark
+        "DJ", // Djibouti
+        "DM", // Dominica
+        "DO", // Dominican Republic
+        "EC", // Ecuador
+        "EG", // Egypt
+        "SV", // El Salvador
+        "GQ", // Equatorial Guinea
+        "ER", // Eritrea
+        "EE", // Estonia
+        "SZ", // Eswatini
+        "ET", // Ethiopia
+        "FK", // Falkland Islands (Islas Malvinas)
+        "FJ", // Fiji
+        "FI", // Finland
+        "FR", // France
+        "GA", // Gabon
+        "GM", // The Gambia
+        "GE", // Georgia
+        "DE", // Germany
+        "GH", // Ghana
+        "GI", // Gibraltar
+        "GR", // Greece
+        "GD", // Grenada
+        "GU", // Guam
+        "GT", // Guatemala
+        "GG", // Guernsey
+        "GN", // Guinea
+        "GW", // Guinea-Bissau
+        "GY", // Guyana
+        "HT", // Haiti
+        "HM", // Heard Island and McDonald Islands
+        "HN", // Honduras
+        "HU", // Hungary
+        "IS", // Iceland
+        "IN", // India
+        "ID", // Indonesia
+        "IQ", // Iraq
+        "IE", // Ireland
+        "IM", // Isle of Man
+        "IL", // Israel
+        "IT", // Italy
+        "JM", // Jamaica
+        "JP", // Japan
+        "JE", // Jersey
+        "JO", // Jordan
+        "KZ", // Kazakhstan
+        "KE", // Kenya
+        "KI", // Kiribati
+        "KG", // Kyrgyzstan
+        "KW", // Kuwait
+        "LA", // Laos
+        "LV", // Latvia
+        "LB", // Lebanon
+        "LS", // Lesotho
+        "LR", // Liberia
+        "LY", // Libya
+        "LI", // Liechtenstein
+        "LT", // Lithuania
+        "LU", // Luxembourg
+        "MG", // Madagascar
+        "MW", // Malawi
+        "MY", // Malaysia
+        "MV", // Maldives
+        "ML", // Mali
+        "MT", // Malta
+        "MH", // Marshall Islands
+        "MR", // Mauritania
+        "MU", // Mauritius
+        "MX", // Mexico
+        "FM", // Micronesia
+        "MN", // Mongolia
+        "MS", // Montserrat
+        "MA", // Morocco
+        "MZ", // Mozambique
+        "NA", // Namibia
+        "NR", // Nauru
+        "NP", // Nepal
+        "NL", // Netherlands
+        "NC", // New Caledonia
+        "NZ", // New Zealand
+        "NI", // Nicaragua
+        "NE", // Niger
+        "NG", // Nigeria
+        "NU", // Niue
+        "NF", // Norfolk Island
+        "MP", // Northern Mariana Islands
+        "NO", // Norway
+        "OM", // Oman
+        "PK", // Pakistan
+        "PW", // Palau
+        "PS", // Palestine
+        "PA", // Panama
+        "PG", // Papua New Guinea
+        "PY", // Paraguay
+        "PE", // Peru
+        "PH", // Philippines
+        "PN", // Pitcairn Islands
+        "PL", // Poland
+        "PT", // Portugal
+        "PR", // Puerto Rico
+        "QA", // Qatar
+        "CY", // Republic of Cyprus
+        "CG", // Republic of the Congo
+        "RO", // Romania
+        "RW", // Rwanda
+        "BL", // Saint Barthélemy
+        "KN", // Saint Kitts and Nevis
+        "LC", // Saint Lucia
+        "PM", // Saint Pierre and Miquelon
+        "VC", // Saint Vincent and the Grenadines
+        "SH", // Saint Helena, Ascension and Tristan da Cunha
+        "WS", // Samoa
+        "ST", // São Tomé and Príncipe
+        "SA", // Saudi Arabia
+        "SN", // Senegal
+        "SC", // Seychelles
+        "SL", // Sierra Leone
+        "SG", // Singapore
+        "SK", // Slovakia
+        "SI", // Slovenia
+        "SB", // Solomon Islands
+        "SO", // Somalia
+        "ZA", // South Africa
+        "GS", // South Georgia and the South Sandwich Islands
+        "KR", // South Korea
+        "SS", // South Sudan
+        "ES", // Spain
+        "LK", // Sri Lanka
+        "SD", // Sudan
+        "SR", // Suriname
+        "SE", // Sweden
+        "CH", // Switzerland
+        "TW", // Taiwan
+        "TJ", // Tajikistan
+        "TZ", // Tanzania
+        "TH", // Thailand
+        "TL", // Timor-Leste
+        "TG", // Togo
+        "TK", // Tokelau
+        "TO", // Tonga
+        "TT", // Trinidad and Tobago
+        "TN", // Tunisia
+        "TR", // Türkiye
+        "TM", // Turkmenistan
+        "TC", // Turks and Caicos Islands
+        "TV", // Tuvalu
+        "UG", // Uganda
+        "GB", // United Kingdom
+        "AE", // United Arab Emirates
+        "US", // United States
+        "UM", // United States Minor Outlying Islands
+        "VI", // U.S. Virgin Islands
+        "UY", // Uruguay
+        "UZ", // Uzbekistan
+        "VU", // Vanuatu
+        "VE", // Venezuela
+        "VN", // Vietnam
+        "WF", // Wallis and Futuna
+        "EH", // Western Sahara
+        "YE", // Yemen
+        "ZM", // Zambia
+        "ZW", // Zimbabwe
+    ]
+    .into_iter()
+    .collect()
+});

crates/open_ai/src/open_ai.rs 🔗

@@ -1,3 +1,5 @@
+mod supported_countries;
+
 use anyhow::{anyhow, Context, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -7,6 +9,8 @@ use serde_json::Value;
 use std::{convert::TryFrom, future::Future, time::Duration};
 use strum::EnumIter;
 
+pub use supported_countries::*;
+
 pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 
 fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {

crates/open_ai/src/supported_countries.rs 🔗

@@ -0,0 +1,207 @@
+use std::collections::HashSet;
+use std::sync::LazyLock;
+
+/// Returns whether the given country code is supported by OpenAI.
+///
+/// https://platform.openai.com/docs/supported-countries
+pub fn is_supported_country(country_code: &str) -> bool {
+    SUPPORTED_COUNTRIES.contains(&country_code)
+}
+
+/// The list of country codes supported by OpenAI.
+///
+/// https://platform.openai.com/docs/supported-countries
+static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+    vec![
+        "AL", // Albania
+        "DZ", // Algeria
+        "AF", // Afghanistan
+        "AD", // Andorra
+        "AO", // Angola
+        "AG", // Antigua and Barbuda
+        "AR", // Argentina
+        "AM", // Armenia
+        "AU", // Australia
+        "AT", // Austria
+        "AZ", // Azerbaijan
+        "BS", // Bahamas
+        "BH", // Bahrain
+        "BD", // Bangladesh
+        "BB", // Barbados
+        "BE", // Belgium
+        "BZ", // Belize
+        "BJ", // Benin
+        "BT", // Bhutan
+        "BO", // Bolivia
+        "BA", // Bosnia and Herzegovina
+        "BW", // Botswana
+        "BR", // Brazil
+        "BN", // Brunei
+        "BG", // Bulgaria
+        "BF", // Burkina Faso
+        "BI", // Burundi
+        "CV", // Cabo Verde
+        "KH", // Cambodia
+        "CM", // Cameroon
+        "CA", // Canada
+        "CF", // Central African Republic
+        "TD", // Chad
+        "CL", // Chile
+        "CO", // Colombia
+        "KM", // Comoros
+        "CG", // Congo (Brazzaville)
+        "CD", // Congo (DRC)
+        "CR", // Costa Rica
+        "CI", // Côte d'Ivoire
+        "HR", // Croatia
+        "CY", // Cyprus
+        "CZ", // Czechia (Czech Republic)
+        "DK", // Denmark
+        "DJ", // Djibouti
+        "DM", // Dominica
+        "DO", // Dominican Republic
+        "EC", // Ecuador
+        "EG", // Egypt
+        "SV", // El Salvador
+        "GQ", // Equatorial Guinea
+        "ER", // Eritrea
+        "EE", // Estonia
+        "SZ", // Eswatini (Swaziland)
+        "ET", // Ethiopia
+        "FJ", // Fiji
+        "FI", // Finland
+        "FR", // France
+        "GA", // Gabon
+        "GM", // Gambia
+        "GE", // Georgia
+        "DE", // Germany
+        "GH", // Ghana
+        "GR", // Greece
+        "GD", // Grenada
+        "GT", // Guatemala
+        "GN", // Guinea
+        "GW", // Guinea-Bissau
+        "GY", // Guyana
+        "HT", // Haiti
+        "VA", // Holy See (Vatican City)
+        "HN", // Honduras
+        "HU", // Hungary
+        "IS", // Iceland
+        "IN", // India
+        "ID", // Indonesia
+        "IQ", // Iraq
+        "IE", // Ireland
+        "IL", // Israel
+        "IT", // Italy
+        "JM", // Jamaica
+        "JP", // Japan
+        "JO", // Jordan
+        "KZ", // Kazakhstan
+        "KE", // Kenya
+        "KI", // Kiribati
+        "KW", // Kuwait
+        "KG", // Kyrgyzstan
+        "LA", // Laos
+        "LV", // Latvia
+        "LB", // Lebanon
+        "LS", // Lesotho
+        "LR", // Liberia
+        "LY", // Libya
+        "LI", // Liechtenstein
+        "LT", // Lithuania
+        "LU", // Luxembourg
+        "MG", // Madagascar
+        "MW", // Malawi
+        "MY", // Malaysia
+        "MV", // Maldives
+        "ML", // Mali
+        "MT", // Malta
+        "MH", // Marshall Islands
+        "MR", // Mauritania
+        "MU", // Mauritius
+        "MX", // Mexico
+        "FM", // Micronesia
+        "MD", // Moldova
+        "MC", // Monaco
+        "MN", // Mongolia
+        "ME", // Montenegro
+        "MA", // Morocco
+        "MZ", // Mozambique
+        "MM", // Myanmar
+        "NA", // Namibia
+        "NR", // Nauru
+        "NP", // Nepal
+        "NL", // Netherlands
+        "NZ", // New Zealand
+        "NI", // Nicaragua
+        "NE", // Niger
+        "NG", // Nigeria
+        "MK", // North Macedonia
+        "NO", // Norway
+        "OM", // Oman
+        "PK", // Pakistan
+        "PW", // Palau
+        "PS", // Palestine
+        "PA", // Panama
+        "PG", // Papua New Guinea
+        "PY", // Paraguay
+        "PE", // Peru
+        "PH", // Philippines
+        "PL", // Poland
+        "PT", // Portugal
+        "QA", // Qatar
+        "RO", // Romania
+        "RW", // Rwanda
+        "KN", // Saint Kitts and Nevis
+        "LC", // Saint Lucia
+        "VC", // Saint Vincent and the Grenadines
+        "WS", // Samoa
+        "SM", // San Marino
+        "ST", // Sao Tome and Principe
+        "SA", // Saudi Arabia
+        "SN", // Senegal
+        "RS", // Serbia
+        "SC", // Seychelles
+        "SL", // Sierra Leone
+        "SG", // Singapore
+        "SK", // Slovakia
+        "SI", // Slovenia
+        "SB", // Solomon Islands
+        "SO", // Somalia
+        "ZA", // South Africa
+        "KR", // South Korea
+        "SS", // South Sudan
+        "ES", // Spain
+        "LK", // Sri Lanka
+        "SR", // Suriname
+        "SE", // Sweden
+        "CH", // Switzerland
+        "SD", // Sudan
+        "TW", // Taiwan
+        "TJ", // Tajikistan
+        "TZ", // Tanzania
+        "TH", // Thailand
+        "TL", // Timor-Leste (East Timor)
+        "TG", // Togo
+        "TO", // Tonga
+        "TT", // Trinidad and Tobago
+        "TN", // Tunisia
+        "TR", // Turkey
+        "TM", // Turkmenistan
+        "TV", // Tuvalu
+        "UG", // Uganda
+        "UA", // Ukraine (with certain exceptions)
+        "AE", // United Arab Emirates
+        "GB", // United Kingdom
+        "US", // United States of America
+        "UY", // Uruguay
+        "UZ", // Uzbekistan
+        "VU", // Vanuatu
+        "VN", // Vietnam
+        "YE", // Yemen
+        "ZM", // Zambia
+        "ZW", // Zimbabwe
+    ]
+    .into_iter()
+    .collect()
+});

typos.toml 🔗

@@ -6,6 +6,12 @@ extend-exclude = [
     # File suffixes aren't typos
     "assets/icons/file_icons/file_types.json",
     "crates/extensions_ui/src/extension_suggest.rs",
+
+    # Some countries codes are flagged as typos.
+    "crates/anthropic/src/supported_countries.rs",
+    "crates/google_ai/src/supported_countries.rs",
+    "crates/open_ai/src/supported_countries.rs",
+
     # Stripe IDs are flagged as typos.
     "crates/collab/src/db/tests/processed_stripe_event_tests.rs",
     # Not our typos