cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
use cloud_api_types::Plan;
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, ListModelsResponse,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
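
// A sketch of the settings JSON this enum deserializes from, as implied by the
// serde attributes above (`tag = "type"`, lowercase variant names); the
// surrounding settings context is assumed:
//
//     { "type": "default" }
//     { "type": "thinking", "budget_tokens": 4096 }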

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

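        // The request is retried at most once: only if the server signals that
        // the LLM token needs refreshing do we fetch a new token and loop;
        // every other non-success response returns an error below.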
        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
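///
/// Example JSON for an error whose status code is embedded in the `code`
/// field itself (this shape, with `retry_after`, is exercised by the tests
/// at the bottom of this file):
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```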
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

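/// Deserializes an optional numeric status code, mapping values outside the
/// valid HTTP range to `None` instead of failing, so a malformed
/// `upstream_status` field never aborts parsing of the whole error body.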
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If a status code is embedded in the code string (e.g.
                    // "upstream_http_429"), use that; otherwise fall back to
                    // the status of the response itself.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

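    // Drain `pending` before polling the inner stream again: each upstream
    // event can fan out into zero or more output items. If the stream closes
    // without a `StreamEnded` status, emit one final error so callers can
    // distinguish truncated responses from clean completions.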
    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

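/// Reads the response body as newline-delimited JSON, one event per line.
/// When the server advertised status-message support, each line is parsed as
/// a full `CompletionEvent<T>` envelope; otherwise each line is a bare `T`
/// wrapped in `CompletionEvent::Event`.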
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::ZedFree), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
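
    // Sketch test for the `deserialize_optional_status_code` helper, driven
    // through `CloudApiError`: status codes outside the valid HTTP range are
    // dropped to `None` rather than failing deserialization of the error body.
    #[test]
    fn test_invalid_upstream_status_is_dropped() {
        let error_body = r#"{"code":"upstream_http_error","message":"boom","upstream_status":1000}"#;

        let cloud_error: CloudApiError = serde_json::from_str(error_body).unwrap();

        assert_eq!(cloud_error.code, "upstream_http_error");
        assert_eq!(cloud_error.message, "boom");
        // 1000 is not a valid HTTP status code, so the custom deserializer
        // maps it to `None` instead of returning an error.
        assert!(cloud_error.upstream_status.is_none());
        assert_eq!(cloud_error.retry_after, None);
    }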
}