cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use client::{
   5    Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore,
   6    global_llm_token as global_llm_api_token, zed_urls,
   7};
   8use cloud_api_types::{OrganizationId, Plan};
   9use cloud_llm_client::{
  10    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
  11    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
  12    CountTokensBody, CountTokensResponse, ListModelsResponse,
  13    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  14};
  15use futures::{
  16    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  17    future::BoxFuture,
  18    stream::{self, BoxStream},
  19};
  20use google_ai::GoogleModelMode;
  21use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  22use http_client::http::{HeaderMap, HeaderValue};
  23use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  24use language_model::{
  25    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID,
  26    GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  27    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  28    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  29    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  30    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
  31    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
  32    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
  33};
  34use release_channel::AppVersion;
  35use schemars::JsonSchema;
  36use semver::Version;
  37use serde::{Deserialize, Serialize, de::DeserializeOwned};
  38use settings::SettingsStore;
  39pub use settings::ZedDotDevAvailableModel as AvailableModel;
  40pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  41use smol::io::{AsyncReadExt, BufReader};
  42use std::collections::VecDeque;
  43use std::pin::Pin;
  44use std::str::FromStr;
  45use std::sync::Arc;
  46use std::task::Poll;
  47use std::time::Duration;
  48use thiserror::Error;
  49use ui::{TintColor, prelude::*};
  50
  51use crate::provider::anthropic::{
  52    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  53};
  54use crate::provider::google::{GoogleEventMapper, into_google};
  55use crate::provider::open_ai::{
  56    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  57    into_open_ai_response,
  58};
  59use crate::provider::x_ai::count_xai_tokens;
  60
// This provider's identity, aliased from the shared Zed Cloud constants so the
// rest of this file has a single local name for them.
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
  63
/// User-configurable settings for the zed.dev (Zed Cloud) provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models made available via settings, in addition to the models
    /// returned by the cloud `/models` endpoint.
    pub available_models: Vec<AvailableModel>,
}
/// Reasoning mode for a model, as configured in settings.
///
/// Serialized with an internal `type` tag (`"default"` / `"thinking"`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    /// Extended-thinking mode (Anthropic-style reasoning tokens).
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  78
  79impl From<ModelMode> for AnthropicModelMode {
  80    fn from(value: ModelMode) -> Self {
  81        match value {
  82            ModelMode::Default => AnthropicModelMode::Default,
  83            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  84        }
  85    }
  86}
  87
/// Language-model provider backed by Zed's cloud (zed.dev) LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Observable state: model lists, auth status, and subscriptions.
    state: Entity<State>,
    // Background task mirroring client connection status into `state`;
    // held so it keeps running for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
  93
/// Observable provider state: connection status, fetched model lists, and the
/// subscriptions that keep them current.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    /// All models returned by the cloud `/models` endpoint.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default model, if it appears in `models`.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated fast/cheap default model, if it appears in `models`.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Models the server recommends surfacing prominently.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    // Held so the subscriptions stay alive as long as this state does.
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 107
 108impl State {
 109    fn new(
 110        client: Arc<Client>,
 111        user_store: Entity<UserStore>,
 112        status: client::Status,
 113        cx: &mut Context<Self>,
 114    ) -> Self {
 115        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 116        let llm_api_token = global_llm_api_token(cx);
 117        Self {
 118            client: client.clone(),
 119            llm_api_token,
 120            user_store: user_store.clone(),
 121            status,
 122            models: Vec::new(),
 123            default_model: None,
 124            default_fast_model: None,
 125            recommended_models: Vec::new(),
 126            _user_store_subscription: cx.subscribe(
 127                &user_store,
 128                move |this, _user_store, event, cx| match event {
 129                    client::user::Event::PrivateUserInfoUpdated => {
 130                        let status = *client.status().borrow();
 131                        if status.is_signed_out() {
 132                            return;
 133                        }
 134
 135                        let client = this.client.clone();
 136                        let llm_api_token = this.llm_api_token.clone();
 137                        let organization_id = this
 138                            .user_store
 139                            .read(cx)
 140                            .current_organization()
 141                            .map(|organization| organization.id.clone());
 142                        cx.spawn(async move |this, cx| {
 143                            let response =
 144                                Self::fetch_models(client, llm_api_token, organization_id).await?;
 145                            this.update(cx, |this, cx| this.update_models(response, cx))
 146                        })
 147                        .detach_and_log_err(cx);
 148                    }
 149                    _ => {}
 150                },
 151            ),
 152            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 153                cx.notify();
 154            }),
 155            _llm_token_subscription: cx.subscribe(
 156                &refresh_llm_token_listener,
 157                move |this, _listener, _event, cx| {
 158                    let client = this.client.clone();
 159                    let llm_api_token = this.llm_api_token.clone();
 160                    let organization_id = this
 161                        .user_store
 162                        .read(cx)
 163                        .current_organization()
 164                        .map(|organization| organization.id.clone());
 165                    cx.spawn(async move |this, cx| {
 166                        let response =
 167                            Self::fetch_models(client, llm_api_token, organization_id).await?;
 168                        this.update(cx, |this, cx| {
 169                            this.update_models(response, cx);
 170                        })
 171                    })
 172                    .detach_and_log_err(cx);
 173                },
 174            ),
 175        }
 176    }
 177
 178    fn is_signed_out(&self, cx: &App) -> bool {
 179        self.user_store.read(cx).current_user().is_none()
 180    }
 181
 182    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 183        let client = self.client.clone();
 184        cx.spawn(async move |state, cx| {
 185            client.sign_in_with_optional_connect(true, cx).await?;
 186            state.update(cx, |_, cx| cx.notify())
 187        })
 188    }
 189
 190    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 191        let mut models = Vec::new();
 192
 193        for model in response.models {
 194            models.push(Arc::new(model.clone()));
 195        }
 196
 197        self.default_model = models
 198            .iter()
 199            .find(|model| {
 200                response
 201                    .default_model
 202                    .as_ref()
 203                    .is_some_and(|default_model_id| &model.id == default_model_id)
 204            })
 205            .cloned();
 206        self.default_fast_model = models
 207            .iter()
 208            .find(|model| {
 209                response
 210                    .default_fast_model
 211                    .as_ref()
 212                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 213            })
 214            .cloned();
 215        self.recommended_models = response
 216            .recommended_models
 217            .iter()
 218            .filter_map(|id| models.iter().find(|model| &model.id == id))
 219            .cloned()
 220            .collect();
 221        self.models = models;
 222        cx.notify();
 223    }
 224
 225    async fn fetch_models(
 226        client: Arc<Client>,
 227        llm_api_token: LlmApiToken,
 228        organization_id: Option<OrganizationId>,
 229    ) -> Result<ListModelsResponse> {
 230        let http_client = &client.http_client();
 231        let token = client
 232            .acquire_llm_token(&llm_api_token, organization_id)
 233            .await?;
 234
 235        let request = http_client::Request::builder()
 236            .method(Method::GET)
 237            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 238            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 239            .header("Authorization", format!("Bearer {token}"))
 240            .body(AsyncBody::empty())?;
 241        let mut response = http_client
 242            .send(request)
 243            .await
 244            .context("failed to send list models request")?;
 245
 246        if response.status().is_success() {
 247            let mut body = String::new();
 248            response.body_mut().read_to_string(&mut body).await?;
 249            Ok(serde_json::from_str(&body)?)
 250        } else {
 251            let mut body = String::new();
 252            response.body_mut().read_to_string(&mut body).await?;
 253            anyhow::bail!(
 254                "error listing models.\nStatus: {:?}\nBody: {body}",
 255                response.status(),
 256            );
 257        }
 258    }
 259}
 260
 261impl CloudLanguageModelProvider {
 262    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 263        let mut status_rx = client.status();
 264        let status = *status_rx.borrow();
 265
 266        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 267
 268        let state_ref = state.downgrade();
 269        let maintain_client_status = cx.spawn(async move |cx| {
 270            while let Some(status) = status_rx.next().await {
 271                if let Some(this) = state_ref.upgrade() {
 272                    _ = this.update(cx, |this, cx| {
 273                        if this.status != status {
 274                            this.status = status;
 275                            cx.notify();
 276                        }
 277                    });
 278                } else {
 279                    break;
 280                }
 281            }
 282        });
 283
 284        Self {
 285            client,
 286            state,
 287            _maintain_client_status: maintain_client_status,
 288        }
 289    }
 290
 291    fn create_language_model(
 292        &self,
 293        model: Arc<cloud_llm_client::LanguageModel>,
 294        llm_api_token: LlmApiToken,
 295        user_store: Entity<UserStore>,
 296    ) -> Arc<dyn LanguageModel> {
 297        Arc::new(CloudLanguageModel {
 298            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 299            model,
 300            llm_api_token,
 301            user_store,
 302            client: self.client.clone(),
 303            request_limiter: RateLimiter::new(4),
 304        })
 305    }
 306}
 307
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes `State` so callers can observe model-list and status changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 315
 316impl LanguageModelProvider for CloudLanguageModelProvider {
 317    fn id(&self) -> LanguageModelProviderId {
 318        PROVIDER_ID
 319    }
 320
 321    fn name(&self) -> LanguageModelProviderName {
 322        PROVIDER_NAME
 323    }
 324
 325    fn icon(&self) -> IconOrSvg {
 326        IconOrSvg::Icon(IconName::AiZed)
 327    }
 328
 329    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 330        let state = self.state.read(cx);
 331        let default_model = state.default_model.clone()?;
 332        let llm_api_token = state.llm_api_token.clone();
 333        let user_store = state.user_store.clone();
 334        Some(self.create_language_model(default_model, llm_api_token, user_store))
 335    }
 336
 337    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 338        let state = self.state.read(cx);
 339        let default_fast_model = state.default_fast_model.clone()?;
 340        let llm_api_token = state.llm_api_token.clone();
 341        let user_store = state.user_store.clone();
 342        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
 343    }
 344
 345    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 346        let state = self.state.read(cx);
 347        let llm_api_token = state.llm_api_token.clone();
 348        let user_store = state.user_store.clone();
 349        state
 350            .recommended_models
 351            .iter()
 352            .cloned()
 353            .map(|model| {
 354                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 355            })
 356            .collect()
 357    }
 358
 359    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 360        let state = self.state.read(cx);
 361        let llm_api_token = state.llm_api_token.clone();
 362        let user_store = state.user_store.clone();
 363        state
 364            .models
 365            .iter()
 366            .cloned()
 367            .map(|model| {
 368                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 369            })
 370            .collect()
 371    }
 372
 373    fn is_authenticated(&self, cx: &App) -> bool {
 374        let state = self.state.read(cx);
 375        !state.is_signed_out(cx)
 376    }
 377
 378    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 379        Task::ready(Ok(()))
 380    }
 381
 382    fn configuration_view(
 383        &self,
 384        _target_agent: language_model::ConfigurationViewTargetAgent,
 385        _: &mut Window,
 386        cx: &mut App,
 387    ) -> AnyView {
 388        cx.new(|_| ConfigurationView::new(self.state.clone()))
 389            .into()
 390    }
 391
 392    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 393        Task::ready(Ok(()))
 394    }
 395}
 396
/// A language model served through Zed's cloud LLM proxy.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    /// Server-side description of the model (capabilities, limits, provider).
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    /// Limits this model to 4 concurrent in-flight requests.
    request_limiter: RateLimiter,
}
 405
/// A successful `/completions` response plus negotiated stream capabilities.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    /// True when the server indicated (via response header) that it will
    /// interleave status messages into the completion event stream.
    includes_status_messages: bool,
}
 410
impl CloudLanguageModel {
    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// Acquires an LLM token up front and retries exactly once with a freshly
    /// refreshed token if the server signals the token is stale. On success,
    /// returns the streaming response plus whether the server will interleave
    /// status messages into the event stream.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = client
            .acquire_llm_token(&llm_api_token, organization_id.clone())
            .await?;
        // Guards against refreshing the token more than once per call.
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise client capabilities so the server can use the
                // newer stream framing (status messages / stream-ended marker).
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Stale token: refresh once and retry the whole request.
            if !refreshed_token && response.needs_llm_token_refresh() {
                token = client
                    .refresh_llm_token(&llm_api_token, organization_id.clone())
                    .await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, headers, and raw body so the
            // `From<ApiError>` conversion can map it to a completion error.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 476
/// Non-success HTTP response from the cloud LLM service, preserving the raw
/// body and headers for downstream error mapping.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 484
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable error code, e.g. `"upstream_http_error"` or
    /// `"upstream_http_429"`.
    code: String,
    /// Human-readable description of the failure.
    message: String,
    /// HTTP status returned by the upstream provider, when present.
    /// Invalid status numbers deserialize to `None` rather than failing.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested retry delay in seconds, when the server provides one.
    #[serde(default)]
    retry_after: Option<f64>,
}
 505
 506fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 507where
 508    D: serde::Deserializer<'de>,
 509{
 510    let opt: Option<u16> = Option::deserialize(deserializer)?;
 511    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 512}
 513
impl From<ApiError> for LanguageModelCompletionError {
    /// Maps an HTTP-level `ApiError` to a completion error, unpacking the
    /// structured `CloudApiError` JSON body when present.
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                // Pick the most specific status available: the explicit
                // `upstream_status` field, then a status embedded in the code
                // string, falling back to our own response status.
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    // "upstream_http_error" carries no embedded status number.
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, see if the JSON contains a status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            // Structured body without an upstream code: keep our status but
            // prefer the server-provided message.
            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        // Unstructured body: report the raw text verbatim.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
 557
 558impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    /// This provider's own id (Zed Cloud), not the upstream vendor's.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// Id of the vendor the cloud proxies this model to.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    /// Human-readable name of the upstream vendor.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }
 594
    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    /// Converts the server-provided effort levels into the client-facing type.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                // A missing `is_default` flag is treated as "not the default".
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    /// All tool-choice modes are supported for every cloud model. The match
    /// is kept exhaustive so a new variant forces a decision here.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Only OpenAI and xAI models get the split token display.
    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
 647
    /// Anthropic and OpenAI accept full JSON Schema for tool inputs; Google
    /// and xAI only understand a subset.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    /// Prompt caching is only configured for Anthropic models.
    /// NOTE(review): the 2048-token minimum and 4-anchor limit appear to
    /// mirror Anthropic's prompt-caching constraints — confirm against the
    /// current Anthropic API documentation.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }
 683
    /// Counts the tokens `request` would consume for this model.
    ///
    /// Anthropic, OpenAI, and xAI counts are computed locally with tokenizer
    /// approximations; Google counts require a round-trip to the cloud
    /// `/count_tokens` endpoint, which forwards to the Google API.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            // Local tiktoken-based approximation, off the main thread.
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                // Fails fast if the id doesn't map to a known OpenAI model.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                // Fails fast if the id doesn't map to a known xAI model.
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                // No local tokenizer for Gemini: ask the cloud service, which
                // proxies Google's countTokens API.
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|organization| organization.id.clone());
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = client
                        .acquire_llm_token(&llm_api_token, organization_id)
                        .await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        // Surface the raw failure for `From<ApiError>` mapping.
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 767
    /// Streams a completion for `request` through Zed's cloud LLM proxy,
    /// translating the generic request into the wire format of the selected
    /// model's upstream provider (Anthropic, OpenAI, xAI, or Google) and
    /// mapping the proxy's line-delimited response back into
    /// `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Request metadata forwarded to the cloud proxy in the `CompletionBody`.
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        // NOTE(review): presumably yields the running app's version for the
        // completion request — confirm the wrapped value matches what
        // `perform_llm_completion` expects for `app_version`.
        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
        let user_store = self.user_store.clone();
        // The user's current organization, if any; forwarded so the proxy can
        // attribute the request (presumably for billing/usage — confirm).
        let organization_id = cx.update(|cx| {
            user_store
                .read(cx)
                .current_organization()
                .map(|organization| organization.id.clone())
        });
        let thinking_allowed = request.thinking_allowed;
        // Extended "thinking" is only enabled when both the request allows it
        // and the selected model supports it.
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                // Parse the requested thinking effort into Anthropic's enum,
                // silently dropping values Anthropic doesn't recognize.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        // Fixed thinking budget when no explicit effort is set.
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // An explicit effort level switches the request over to
                // adaptive thinking with an output-config effort hint.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Unlike the other provider arms, downcast `ApiError`s so
                    // upstream failures become rich completion errors.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                // Parse the requested effort into OpenAI's reasoning-effort
                // enum, dropping unrecognized values.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                // OpenAI requests go through the Responses-style conversion
                // (cf. `into_open_ai` used by the xAI arm below).
                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                // xAI uses the chat-completions-style conversion; thinking
                // effort is not applied on this path.
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                // Google requests always use the default model mode here.
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
 985}
 986
/// Adapts a raw cloud completion stream into `LanguageModelCompletionEvent`s.
///
/// `map_callback` translates each provider-specific event into zero or more
/// completion events; proxy status messages are handled in this function.
/// If the underlying stream finishes without the proxy having sent a
/// `StreamEnded` status, a `StreamEndedUnexpectedly` error is emitted so
/// callers can distinguish a dropped connection from a clean finish.
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the proxy signals a clean end of stream.
    let mut saw_stream_ended = false;

    // `done` latches after the inner stream is exhausted (and any trailing
    // error has been emitted); `pending` buffers fan-out events so each poll
    // yields at most one item.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain buffered events before polling the inner stream again.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // Clean-termination marker; yields no events itself.
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            // Other statuses may or may not map to an event.
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // The stream ended without `StreamEnded`: surface that as
                    // an explicit error rather than a silent finish.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
1056
1057fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1058    match provider {
1059        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
1060        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
1061        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
1062        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
1063    }
1064}
1065
1066fn response_lines<T: DeserializeOwned>(
1067    response: Response<AsyncBody>,
1068    includes_status_messages: bool,
1069) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1070    futures::stream::try_unfold(
1071        (String::new(), BufReader::new(response.into_body())),
1072        move |(mut line, mut body)| async move {
1073            match body.read_line(&mut line).await {
1074                Ok(0) => Ok(None),
1075                Ok(_) => {
1076                    let event = if includes_status_messages {
1077                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1078                    } else {
1079                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1080                    };
1081
1082                    line.clear();
1083                    Ok(Some((event, (line, body))))
1084                }
1085                Err(e) => Err(e.into()),
1086            }
1087        },
1088    )
1089}
1090
/// Render-once view summarizing the user's Zed AI subscription state and the
/// actions available to them (sign in, start a trial, upgrade, or manage an
/// existing subscription).
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is currently signed in.
    is_connected: bool,
    /// The user's current plan, if known.
    plan: Option<Plan>,
    /// Whether the user can still start a free Pro trial.
    eligible_for_trial: bool,
    /// Whether the account is too new for a trial (see `YoungAccountBanner`).
    account_too_young: bool,
    /// Invoked when the user clicks the sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1099
1100impl RenderOnce for ZedAiConfiguration {
1101    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1102        let (subscription_text, has_paid_plan) = match self.plan {
1103            Some(Plan::ZedPro) => (
1104                "You have access to Zed's hosted models through your Pro subscription.",
1105                true,
1106            ),
1107            Some(Plan::ZedProTrial) => (
1108                "You have access to Zed's hosted models through your Pro trial.",
1109                false,
1110            ),
1111            Some(Plan::ZedStudent) => (
1112                "You have access to Zed's hosted models through your Student subscription.",
1113                true,
1114            ),
1115            Some(Plan::ZedBusiness) => (
1116                "You have access to Zed's hosted models through your Organization.",
1117                true,
1118            ),
1119            Some(Plan::ZedFree) | None => (
1120                if self.eligible_for_trial {
1121                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1122                } else {
1123                    "Subscribe for access to Zed's hosted models."
1124                },
1125                false,
1126            ),
1127        };
1128
1129        let manage_subscription_buttons = if has_paid_plan {
1130            Button::new("manage_settings", "Manage Subscription")
1131                .full_width()
1132                .label_size(LabelSize::Small)
1133                .style(ButtonStyle::Tinted(TintColor::Accent))
1134                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1135                .into_any_element()
1136        } else if self.plan.is_none() || self.eligible_for_trial {
1137            Button::new("start_trial", "Start 14-day Free Pro Trial")
1138                .full_width()
1139                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1140                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1141                .into_any_element()
1142        } else {
1143            Button::new("upgrade", "Upgrade to Pro")
1144                .full_width()
1145                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1146                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1147                .into_any_element()
1148        };
1149
1150        if !self.is_connected {
1151            return v_flex()
1152                .gap_2()
1153                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1154                .child(
1155                    Button::new("sign_in", "Sign In to use Zed AI")
1156                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
1157                        .full_width()
1158                        .on_click({
1159                            let callback = self.sign_in_callback.clone();
1160                            move |_, window, cx| (callback)(window, cx)
1161                        }),
1162                );
1163        }
1164
1165        v_flex().gap_2().w_full().map(|this| {
1166            if self.account_too_young {
1167                this.child(YoungAccountBanner).child(
1168                    Button::new("upgrade", "Upgrade to Pro")
1169                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1170                        .full_width()
1171                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1172                )
1173            } else {
1174                this.text_sm()
1175                    .child(subscription_text)
1176                    .child(manage_subscription_buttons)
1177            }
1178        })
1179    }
1180}
1181
/// Entity-backed configuration view that renders a `ZedAiConfiguration` from
/// the provider's current state.
struct ConfigurationView {
    /// Provider state used to read sign-in status and the user store.
    state: Entity<State>,
    /// Shared sign-in handler passed down to `ZedAiConfiguration`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1186
1187impl ConfigurationView {
1188    fn new(state: Entity<State>) -> Self {
1189        let sign_in_callback = Arc::new({
1190            let state = state.clone();
1191            move |_window: &mut Window, cx: &mut App| {
1192                state.update(cx, |state, cx| {
1193                    state.authenticate(cx).detach_and_log_err(cx);
1194                });
1195            }
1196        });
1197
1198        Self {
1199            state,
1200            sign_in_callback,
1201        }
1202    }
1203}
1204
1205impl Render for ConfigurationView {
1206    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1207        let state = self.state.read(cx);
1208        let user_store = state.user_store.read(cx);
1209
1210        ZedAiConfiguration {
1211            is_connected: !state.is_signed_out(cx),
1212            plan: user_store.plan(),
1213            eligible_for_trial: user_store.trial_started_at().is_none(),
1214            account_too_young: user_store.account_too_young(),
1215            sign_in_callback: self.sign_in_callback.clone(),
1216        }
1217    }
1218}
1219
1220impl Component for ZedAiConfiguration {
1221    fn name() -> &'static str {
1222        "AI Configuration Content"
1223    }
1224
1225    fn sort_name() -> &'static str {
1226        "AI Configuration Content"
1227    }
1228
1229    fn scope() -> ComponentScope {
1230        ComponentScope::Onboarding
1231    }
1232
1233    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1234        fn configuration(
1235            is_connected: bool,
1236            plan: Option<Plan>,
1237            eligible_for_trial: bool,
1238            account_too_young: bool,
1239        ) -> AnyElement {
1240            ZedAiConfiguration {
1241                is_connected,
1242                plan,
1243                eligible_for_trial,
1244                account_too_young,
1245                sign_in_callback: Arc::new(|_, _| {}),
1246            }
1247            .into_any_element()
1248        }
1249
1250        Some(
1251            v_flex()
1252                .p_4()
1253                .gap_4()
1254                .children(vec![
1255                    single_example("Not connected", configuration(false, None, false, false)),
1256                    single_example(
1257                        "Accept Terms of Service",
1258                        configuration(true, None, true, false),
1259                    ),
1260                    single_example(
1261                        "No Plan - Not eligible for trial",
1262                        configuration(true, None, false, false),
1263                    ),
1264                    single_example(
1265                        "No Plan - Eligible for trial",
1266                        configuration(true, None, true, false),
1267                    ),
1268                    single_example(
1269                        "Free Plan",
1270                        configuration(true, Some(Plan::ZedFree), true, false),
1271                    ),
1272                    single_example(
1273                        "Zed Pro Trial Plan",
1274                        configuration(true, Some(Plan::ZedProTrial), true, false),
1275                    ),
1276                    single_example(
1277                        "Zed Pro Plan",
1278                        configuration(true, Some(Plan::ZedPro), true, false),
1279                    ),
1280                ])
1281                .into_any_element(),
1282        )
1283    }
1284}
1285
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds an `ApiError` with a 500 status and the given body, then runs
    /// it through the `LanguageModelCompletionError` conversion under test.
    /// Factors out the construction that was previously copy-pasted per case.
    fn completion_error_for(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn upstream_http_error_503_becomes_upstream_provider_error() {
        // An `upstream_http_error` payload should surface as an
        // UpstreamProviderError carrying the upstream message.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        match completion_error_for(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 503, got: {other:?}"),
        }
    }

    #[test]
    fn upstream_http_error_500_becomes_upstream_provider_error() {
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        match completion_error_for(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 500, got: {other:?}"),
        }
    }

    #[test]
    fn upstream_http_error_429_becomes_upstream_provider_error() {
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        match completion_error_for(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 429, got: {other:?}"),
        }
    }

    #[test]
    fn regular_500_becomes_api_internal_server_error() {
        // A plain 500 without the `upstream_http_error` envelope should be
        // attributed to Zed itself as an ApiInternalServerError.
        match completion_error_for("Regular internal server error") {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!("Expected ApiInternalServerError for regular 500, got: {other:?}"),
        }
    }

    #[test]
    fn upstream_http_429_format_becomes_upstream_provider_error() {
        // The `upstream_http_429` payload should carry both the 429 status
        // and the parsed retry-after duration through the conversion.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        match completion_error_for(error_body) {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!("Expected UpstreamProviderError for upstream_http_429, got: {other:?}"),
        }
    }

    #[test]
    fn invalid_json_body_falls_back_to_internal_server_error() {
        // A body that isn't the structured error JSON should fall back to the
        // regular status-code-based handling.
        match completion_error_for("Not JSON at all") {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!("Expected ApiInternalServerError for invalid JSON, got: {other:?}"),
        }
    }
}