cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, UserStore, zed_urls};
   6use cloud_api_types::Plan;
   7use cloud_llm_client::{
   8    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   9    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
  10    CountTokensBody, CountTokensResponse, ListModelsResponse,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  12};
  13use futures::{
  14    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  15    future::BoxFuture,
  16    stream::{self, BoxStream},
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  25    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  26    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  27    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::collections::VecDeque;
  39use std::pin::Pin;
  40use std::str::FromStr;
  41use std::sync::Arc;
  42use std::task::Poll;
  43use std::time::Duration;
  44use thiserror::Error;
  45use ui::{TintColor, prelude::*};
  46use util::{ResultExt as _, maybe};
  47
  48use crate::provider::anthropic::{
  49    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  50};
  51use crate::provider::google::{GoogleEventMapper, into_google};
  52use crate::provider::open_ai::{
  53    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  54    into_open_ai_response,
  55};
  56use crate::provider::x_ai::count_xai_tokens;
  57
/// Provider id under which Zed cloud models are registered.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Provider name reported for Zed cloud models; also passed to
/// `LanguageModelCompletionError::from_http_status` when mapping HTTP errors.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  60
/// Settings for the Zed cloud (zed.dev) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Models listed in this provider's settings
    /// (`AvailableModel` is `settings::ZedDotDevAvailableModel`).
    pub available_models: Vec<AvailableModel>,
}
// Reasoning mode for a cloud model. Serialized as an internally tagged object whose
// `type` is the lowercased variant name, e.g. `{"type": "default"}` or
// `{"type": "thinking", "budget_tokens": 4096}`.
// NOTE: plain `//` comments are used on this type on purpose; `///` docs would be
// picked up by the `JsonSchema` derive and change the generated schema.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  75
  76impl From<ModelMode> for AnthropicModelMode {
  77    fn from(value: ModelMode) -> Self {
  78        match value {
  79            ModelMode::Default => AnthropicModelMode::Default,
  80            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  81        }
  82    }
  83}
  84
/// Language model provider for models served through Zed's cloud.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Model list and connection state; exposed via `observable_entity`.
    state: Entity<State>,
    /// Background task that copies the client's connection status into `state.status`.
    _maintain_client_status: Task<()>,
}
  90
/// Observable state backing [`CloudLanguageModelProvider`].
pub struct State {
    client: Arc<Client>,
    /// Token for requests to Zed's LLM service; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits.
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    /// Last observed connection status of `client`.
    status: client::Status,
    /// All models returned by the most recent `/models` fetch.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// The server's default model, if it was among `models`.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// The server's default fast model, if it was among `models`.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-recommended models in server order; ids not found in `models` are skipped.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    // Held only to keep the background work alive as long as the state exists.
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 104
 105impl State {
 106    fn new(
 107        client: Arc<Client>,
 108        user_store: Entity<UserStore>,
 109        status: client::Status,
 110        cx: &mut Context<Self>,
 111    ) -> Self {
 112        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 113        let mut current_user = user_store.read(cx).watch_current_user();
 114        Self {
 115            client: client.clone(),
 116            llm_api_token: LlmApiToken::default(),
 117            user_store,
 118            status,
 119            models: Vec::new(),
 120            default_model: None,
 121            default_fast_model: None,
 122            recommended_models: Vec::new(),
 123            _fetch_models_task: cx.spawn(async move |this, cx| {
 124                maybe!(async move {
 125                    let (client, llm_api_token) = this
 126                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
 127
 128                    while current_user.borrow().is_none() {
 129                        current_user.next().await;
 130                    }
 131
 132                    let response =
 133                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
 134                    this.update(cx, |this, cx| this.update_models(response, cx))?;
 135                    anyhow::Ok(())
 136                })
 137                .await
 138                .context("failed to fetch Zed models")
 139                .log_err();
 140            }),
 141            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 142                cx.notify();
 143            }),
 144            _llm_token_subscription: cx.subscribe(
 145                &refresh_llm_token_listener,
 146                move |this, _listener, _event, cx| {
 147                    let client = this.client.clone();
 148                    let llm_api_token = this.llm_api_token.clone();
 149                    cx.spawn(async move |this, cx| {
 150                        llm_api_token.refresh(&client).await?;
 151                        let response = Self::fetch_models(client, llm_api_token).await?;
 152                        this.update(cx, |this, cx| {
 153                            this.update_models(response, cx);
 154                        })
 155                    })
 156                    .detach_and_log_err(cx);
 157                },
 158            ),
 159        }
 160    }
 161
 162    fn is_signed_out(&self, cx: &App) -> bool {
 163        self.user_store.read(cx).current_user().is_none()
 164    }
 165
 166    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 167        let client = self.client.clone();
 168        cx.spawn(async move |state, cx| {
 169            client.sign_in_with_optional_connect(true, cx).await?;
 170            state.update(cx, |_, cx| cx.notify())
 171        })
 172    }
 173
 174    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 175        let mut models = Vec::new();
 176
 177        for model in response.models {
 178            models.push(Arc::new(model.clone()));
 179        }
 180
 181        self.default_model = models
 182            .iter()
 183            .find(|model| {
 184                response
 185                    .default_model
 186                    .as_ref()
 187                    .is_some_and(|default_model_id| &model.id == default_model_id)
 188            })
 189            .cloned();
 190        self.default_fast_model = models
 191            .iter()
 192            .find(|model| {
 193                response
 194                    .default_fast_model
 195                    .as_ref()
 196                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 197            })
 198            .cloned();
 199        self.recommended_models = response
 200            .recommended_models
 201            .iter()
 202            .filter_map(|id| models.iter().find(|model| &model.id == id))
 203            .cloned()
 204            .collect();
 205        self.models = models;
 206        cx.notify();
 207    }
 208
 209    async fn fetch_models(
 210        client: Arc<Client>,
 211        llm_api_token: LlmApiToken,
 212    ) -> Result<ListModelsResponse> {
 213        let http_client = &client.http_client();
 214        let token = llm_api_token.acquire(&client).await?;
 215
 216        let request = http_client::Request::builder()
 217            .method(Method::GET)
 218            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 219            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 220            .header("Authorization", format!("Bearer {token}"))
 221            .body(AsyncBody::empty())?;
 222        let mut response = http_client
 223            .send(request)
 224            .await
 225            .context("failed to send list models request")?;
 226
 227        if response.status().is_success() {
 228            let mut body = String::new();
 229            response.body_mut().read_to_string(&mut body).await?;
 230            Ok(serde_json::from_str(&body)?)
 231        } else {
 232            let mut body = String::new();
 233            response.body_mut().read_to_string(&mut body).await?;
 234            anyhow::bail!(
 235                "error listing models.\nStatus: {:?}\nBody: {body}",
 236                response.status(),
 237            );
 238        }
 239    }
 240}
 241
 242impl CloudLanguageModelProvider {
 243    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 244        let mut status_rx = client.status();
 245        let status = *status_rx.borrow();
 246
 247        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 248
 249        let state_ref = state.downgrade();
 250        let maintain_client_status = cx.spawn(async move |cx| {
 251            while let Some(status) = status_rx.next().await {
 252                if let Some(this) = state_ref.upgrade() {
 253                    _ = this.update(cx, |this, cx| {
 254                        if this.status != status {
 255                            this.status = status;
 256                            cx.notify();
 257                        }
 258                    });
 259                } else {
 260                    break;
 261                }
 262            }
 263        });
 264
 265        Self {
 266            client,
 267            state,
 268            _maintain_client_status: maintain_client_status,
 269        }
 270    }
 271
 272    fn create_language_model(
 273        &self,
 274        model: Arc<cloud_llm_client::LanguageModel>,
 275        llm_api_token: LlmApiToken,
 276    ) -> Arc<dyn LanguageModel> {
 277        Arc::new(CloudLanguageModel {
 278            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 279            model,
 280            llm_api_token,
 281            client: self.client.clone(),
 282            request_limiter: RateLimiter::new(4),
 283        })
 284    }
 285}
 286
 287impl LanguageModelProviderState for CloudLanguageModelProvider {
 288    type ObservableEntity = State;
 289
 290    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 291        Some(self.state.clone())
 292    }
 293}
 294
 295impl LanguageModelProvider for CloudLanguageModelProvider {
 296    fn id(&self) -> LanguageModelProviderId {
 297        PROVIDER_ID
 298    }
 299
 300    fn name(&self) -> LanguageModelProviderName {
 301        PROVIDER_NAME
 302    }
 303
 304    fn icon(&self) -> IconOrSvg {
 305        IconOrSvg::Icon(IconName::AiZed)
 306    }
 307
 308    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 309        let default_model = self.state.read(cx).default_model.clone()?;
 310        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 311        Some(self.create_language_model(default_model, llm_api_token))
 312    }
 313
 314    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 315        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 316        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 317        Some(self.create_language_model(default_fast_model, llm_api_token))
 318    }
 319
 320    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 321        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 322        self.state
 323            .read(cx)
 324            .recommended_models
 325            .iter()
 326            .cloned()
 327            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 328            .collect()
 329    }
 330
 331    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 332        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 333        self.state
 334            .read(cx)
 335            .models
 336            .iter()
 337            .cloned()
 338            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 339            .collect()
 340    }
 341
 342    fn is_authenticated(&self, cx: &App) -> bool {
 343        let state = self.state.read(cx);
 344        !state.is_signed_out(cx)
 345    }
 346
 347    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 348        Task::ready(Ok(()))
 349    }
 350
 351    fn configuration_view(
 352        &self,
 353        _target_agent: language_model::ConfigurationViewTargetAgent,
 354        _: &mut Window,
 355        cx: &mut App,
 356    ) -> AnyView {
 357        cx.new(|_| ConfigurationView::new(self.state.clone()))
 358            .into()
 359    }
 360
 361    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 362        Task::ready(Ok(()))
 363    }
 364}
 365
/// A single model served through Zed's cloud, exposed as a [`LanguageModel`].
pub struct CloudLanguageModel {
    /// Identifier built from the cloud model's id.
    id: LanguageModelId,
    /// Model description as returned by the cloud `/models` endpoint.
    model: Arc<cloud_llm_client::LanguageModel>,
    /// Token used to authenticate requests; a clone of the provider state's token.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Wraps this model's request futures (`request_limiter.stream(..)`); the
    /// provider constructs it with `RateLimiter::new(4)`.
    request_limiter: RateLimiter,
}
 373
/// Successful result of [`CloudLanguageModel::perform_llm_completion`].
struct PerformLlmCompletionResponse {
    /// The HTTP response; its body is later consumed line by line as completion events.
    response: Response<AsyncBody>,
    /// Whether the server sent `SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME`.
    includes_status_messages: bool,
}
 378
 379impl CloudLanguageModel {
 380    async fn perform_llm_completion(
 381        client: Arc<Client>,
 382        llm_api_token: LlmApiToken,
 383        app_version: Option<Version>,
 384        body: CompletionBody,
 385    ) -> Result<PerformLlmCompletionResponse> {
 386        let http_client = &client.http_client();
 387
 388        let mut token = llm_api_token.acquire(&client).await?;
 389        let mut refreshed_token = false;
 390
 391        loop {
 392            let request = http_client::Request::builder()
 393                .method(Method::POST)
 394                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
 395                .when_some(app_version.as_ref(), |builder, app_version| {
 396                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 397                })
 398                .header("Content-Type", "application/json")
 399                .header("Authorization", format!("Bearer {token}"))
 400                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
 401                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
 402                .body(serde_json::to_string(&body)?.into())?;
 403
 404            let mut response = http_client.send(request).await?;
 405            let status = response.status();
 406            if status.is_success() {
 407                let includes_status_messages = response
 408                    .headers()
 409                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
 410                    .is_some();
 411
 412                return Ok(PerformLlmCompletionResponse {
 413                    response,
 414                    includes_status_messages,
 415                });
 416            }
 417
 418            if !refreshed_token && response.needs_llm_token_refresh() {
 419                token = llm_api_token.refresh(&client).await?;
 420                refreshed_token = true;
 421                continue;
 422            }
 423
 424            if status == StatusCode::PAYMENT_REQUIRED {
 425                return Err(anyhow!(PaymentRequiredError));
 426            }
 427
 428            let mut body = String::new();
 429            let headers = response.headers().clone();
 430            response.body_mut().read_to_string(&mut body).await?;
 431            return Err(anyhow!(ApiError {
 432                status,
 433                body,
 434                headers
 435            }));
 436        }
 437    }
 438}
 439
/// A non-success response from Zed's cloud LLM service.
///
/// Keeps the raw response parts so the failure can be classified later. The body
/// may hold a JSON `CloudApiError`.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    /// HTTP status of the cloud response.
    status: StatusCode,
    /// Raw response body.
    body: String,
    /// Response headers.
    headers: HeaderMap<HeaderValue>,
}
 447
 448/// Represents error responses from Zed's cloud API.
 449///
 450/// Example JSON for an upstream HTTP error:
 451/// ```json
 452/// {
 453///   "code": "upstream_http_error",
 454///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
 455///   "upstream_status": 503
 456/// }
 457/// ```
 458#[derive(Debug, serde::Deserialize)]
 459struct CloudApiError {
 460    code: String,
 461    message: String,
 462    #[serde(default)]
 463    #[serde(deserialize_with = "deserialize_optional_status_code")]
 464    upstream_status: Option<StatusCode>,
 465    #[serde(default)]
 466    retry_after: Option<f64>,
 467}
 468
 469fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 470where
 471    D: serde::Deserializer<'de>,
 472{
 473    let opt: Option<u16> = Option::deserialize(deserializer)?;
 474    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 475}
 476
 477impl From<ApiError> for LanguageModelCompletionError {
 478    fn from(error: ApiError) -> Self {
 479        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 480            if cloud_error.code.starts_with("upstream_http_") {
 481                let status = if let Some(status) = cloud_error.upstream_status {
 482                    status
 483                } else if cloud_error.code.ends_with("_error") {
 484                    error.status
 485                } else {
 486                    // If there's a status code in the code string (e.g. "upstream_http_429")
 487                    // then use that; otherwise, see if the JSON contains a status code.
 488                    cloud_error
 489                        .code
 490                        .strip_prefix("upstream_http_")
 491                        .and_then(|code_str| code_str.parse::<u16>().ok())
 492                        .and_then(|code| StatusCode::from_u16(code).ok())
 493                        .unwrap_or(error.status)
 494                };
 495
 496                return LanguageModelCompletionError::UpstreamProviderError {
 497                    message: cloud_error.message,
 498                    status,
 499                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 500                };
 501            }
 502
 503            return LanguageModelCompletionError::from_http_status(
 504                PROVIDER_NAME,
 505                error.status,
 506                cloud_error.message,
 507                None,
 508            );
 509        }
 510
 511        let retry_after = None;
 512        LanguageModelCompletionError::from_http_status(
 513            PROVIDER_NAME,
 514            error.status,
 515            error.body,
 516            retry_after,
 517        )
 518    }
 519}
 520
 521impl LanguageModel for CloudLanguageModel {
 522    fn id(&self) -> LanguageModelId {
 523        self.id.clone()
 524    }
 525
 526    fn name(&self) -> LanguageModelName {
 527        LanguageModelName::from(self.model.display_name.clone())
 528    }
 529
 530    fn provider_id(&self) -> LanguageModelProviderId {
 531        PROVIDER_ID
 532    }
 533
 534    fn provider_name(&self) -> LanguageModelProviderName {
 535        PROVIDER_NAME
 536    }
 537
 538    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 539        use cloud_llm_client::LanguageModelProvider::*;
 540        match self.model.provider {
 541            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
 542            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
 543            Google => language_model::GOOGLE_PROVIDER_ID,
 544            XAi => language_model::X_AI_PROVIDER_ID,
 545        }
 546    }
 547
 548    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 549        use cloud_llm_client::LanguageModelProvider::*;
 550        match self.model.provider {
 551            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
 552            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
 553            Google => language_model::GOOGLE_PROVIDER_NAME,
 554            XAi => language_model::X_AI_PROVIDER_NAME,
 555        }
 556    }
 557
 558    fn is_latest(&self) -> bool {
 559        self.model.is_latest
 560    }
 561
 562    fn supports_tools(&self) -> bool {
 563        self.model.supports_tools
 564    }
 565
 566    fn supports_images(&self) -> bool {
 567        self.model.supports_images
 568    }
 569
 570    fn supports_thinking(&self) -> bool {
 571        self.model.supports_thinking
 572    }
 573
 574    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 575        self.model
 576            .supported_effort_levels
 577            .iter()
 578            .map(|effort_level| LanguageModelEffortLevel {
 579                name: effort_level.name.clone().into(),
 580                value: effort_level.value.clone().into(),
 581                is_default: effort_level.is_default.unwrap_or(false),
 582            })
 583            .collect()
 584    }
 585
 586    fn supports_streaming_tools(&self) -> bool {
 587        self.model.supports_streaming_tools
 588    }
 589
 590    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 591        match choice {
 592            LanguageModelToolChoice::Auto
 593            | LanguageModelToolChoice::Any
 594            | LanguageModelToolChoice::None => true,
 595        }
 596    }
 597
 598    fn supports_split_token_display(&self) -> bool {
 599        use cloud_llm_client::LanguageModelProvider::*;
 600        matches!(self.model.provider, OpenAi)
 601    }
 602
 603    fn telemetry_id(&self) -> String {
 604        format!("zed.dev/{}", self.model.id)
 605    }
 606
 607    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 608        match self.model.provider {
 609            cloud_llm_client::LanguageModelProvider::Anthropic
 610            | cloud_llm_client::LanguageModelProvider::OpenAi
 611            | cloud_llm_client::LanguageModelProvider::XAi => {
 612                LanguageModelToolSchemaFormat::JsonSchema
 613            }
 614            cloud_llm_client::LanguageModelProvider::Google => {
 615                LanguageModelToolSchemaFormat::JsonSchemaSubset
 616            }
 617        }
 618    }
 619
 620    fn max_token_count(&self) -> u64 {
 621        self.model.max_token_count as u64
 622    }
 623
 624    fn max_output_tokens(&self) -> Option<u64> {
 625        Some(self.model.max_output_tokens as u64)
 626    }
 627
 628    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 629        match &self.model.provider {
 630            cloud_llm_client::LanguageModelProvider::Anthropic => {
 631                Some(LanguageModelCacheConfiguration {
 632                    min_total_token: 2_048,
 633                    should_speculate: true,
 634                    max_cache_anchors: 4,
 635                })
 636            }
 637            cloud_llm_client::LanguageModelProvider::OpenAi
 638            | cloud_llm_client::LanguageModelProvider::XAi
 639            | cloud_llm_client::LanguageModelProvider::Google => None,
 640        }
 641    }
 642
 643    fn count_tokens(
 644        &self,
 645        request: LanguageModelRequest,
 646        cx: &App,
 647    ) -> BoxFuture<'static, Result<u64>> {
 648        match self.model.provider {
 649            cloud_llm_client::LanguageModelProvider::Anthropic => cx
 650                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
 651                .boxed(),
 652            cloud_llm_client::LanguageModelProvider::OpenAi => {
 653                let model = match open_ai::Model::from_id(&self.model.id.0) {
 654                    Ok(model) => model,
 655                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 656                };
 657                count_open_ai_tokens(request, model, cx)
 658            }
 659            cloud_llm_client::LanguageModelProvider::XAi => {
 660                let model = match x_ai::Model::from_id(&self.model.id.0) {
 661                    Ok(model) => model,
 662                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 663                };
 664                count_xai_tokens(request, model, cx)
 665            }
 666            cloud_llm_client::LanguageModelProvider::Google => {
 667                let client = self.client.clone();
 668                let llm_api_token = self.llm_api_token.clone();
 669                let model_id = self.model.id.to_string();
 670                let generate_content_request =
 671                    into_google(request, model_id.clone(), GoogleModelMode::Default);
 672                async move {
 673                    let http_client = &client.http_client();
 674                    let token = llm_api_token.acquire(&client).await?;
 675
 676                    let request_body = CountTokensBody {
 677                        provider: cloud_llm_client::LanguageModelProvider::Google,
 678                        model: model_id,
 679                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 680                            generate_content_request,
 681                        })?,
 682                    };
 683                    let request = http_client::Request::builder()
 684                        .method(Method::POST)
 685                        .uri(
 686                            http_client
 687                                .build_zed_llm_url("/count_tokens", &[])?
 688                                .as_ref(),
 689                        )
 690                        .header("Content-Type", "application/json")
 691                        .header("Authorization", format!("Bearer {token}"))
 692                        .body(serde_json::to_string(&request_body)?.into())?;
 693                    let mut response = http_client.send(request).await?;
 694                    let status = response.status();
 695                    let headers = response.headers().clone();
 696                    let mut response_body = String::new();
 697                    response
 698                        .body_mut()
 699                        .read_to_string(&mut response_body)
 700                        .await?;
 701
 702                    if status.is_success() {
 703                        let response_body: CountTokensResponse =
 704                            serde_json::from_str(&response_body)?;
 705
 706                        Ok(response_body.tokens as u64)
 707                    } else {
 708                        Err(anyhow!(ApiError {
 709                            status,
 710                            body: response_body,
 711                            headers
 712                        }))
 713                    }
 714                }
 715                .boxed()
 716            }
 717        }
 718    }
 719
 720    fn stream_completion(
 721        &self,
 722        request: LanguageModelRequest,
 723        cx: &AsyncApp,
 724    ) -> BoxFuture<
 725        'static,
 726        Result<
 727            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 728            LanguageModelCompletionError,
 729        >,
 730    > {
 731        let thread_id = request.thread_id.clone();
 732        let prompt_id = request.prompt_id.clone();
 733        let intent = request.intent;
 734        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
 735        let thinking_allowed = request.thinking_allowed;
 736        let enable_thinking = thinking_allowed && self.model.supports_thinking;
 737        let provider_name = provider_name(&self.model.provider);
 738        match self.model.provider {
 739            cloud_llm_client::LanguageModelProvider::Anthropic => {
 740                let effort = request
 741                    .thinking_effort
 742                    .as_ref()
 743                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());
 744
 745                let mut request = into_anthropic(
 746                    request,
 747                    self.model.id.to_string(),
 748                    1.0,
 749                    self.model.max_output_tokens as u64,
 750                    if enable_thinking {
 751                        AnthropicModelMode::Thinking {
 752                            budget_tokens: Some(4_096),
 753                        }
 754                    } else {
 755                        AnthropicModelMode::Default
 756                    },
 757                );
 758
 759                if enable_thinking && effort.is_some() {
 760                    request.thinking = Some(anthropic::Thinking::Adaptive);
 761                    request.output_config = Some(anthropic::OutputConfig { effort });
 762                }
 763
 764                let client = self.client.clone();
 765                let llm_api_token = self.llm_api_token.clone();
 766                let future = self.request_limiter.stream(async move {
 767                    let PerformLlmCompletionResponse {
 768                        response,
 769                        includes_status_messages,
 770                    } = Self::perform_llm_completion(
 771                        client.clone(),
 772                        llm_api_token,
 773                        app_version,
 774                        CompletionBody {
 775                            thread_id,
 776                            prompt_id,
 777                            intent,
 778                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 779                            model: request.model.clone(),
 780                            provider_request: serde_json::to_value(&request)
 781                                .map_err(|e| anyhow!(e))?,
 782                        },
 783                    )
 784                    .await
 785                    .map_err(|err| match err.downcast::<ApiError>() {
 786                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 787                        Err(err) => anyhow!(err),
 788                    })?;
 789
 790                    let mut mapper = AnthropicEventMapper::new();
 791                    Ok(map_cloud_completion_events(
 792                        Box::pin(response_lines(response, includes_status_messages)),
 793                        &provider_name,
 794                        move |event| mapper.map_event(event),
 795                    ))
 796                });
 797                async move { Ok(future.await?.boxed()) }.boxed()
 798            }
 799            cloud_llm_client::LanguageModelProvider::OpenAi => {
 800                let client = self.client.clone();
 801                let llm_api_token = self.llm_api_token.clone();
 802                let effort = request
 803                    .thinking_effort
 804                    .as_ref()
 805                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
 806
 807                let mut request = into_open_ai_response(
 808                    request,
 809                    &self.model.id.0,
 810                    self.model.supports_parallel_tool_calls,
 811                    true,
 812                    None,
 813                    None,
 814                );
 815
 816                if enable_thinking && let Some(effort) = effort {
 817                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
 818                }
 819
 820                let future = self.request_limiter.stream(async move {
 821                    let PerformLlmCompletionResponse {
 822                        response,
 823                        includes_status_messages,
 824                    } = Self::perform_llm_completion(
 825                        client.clone(),
 826                        llm_api_token,
 827                        app_version,
 828                        CompletionBody {
 829                            thread_id,
 830                            prompt_id,
 831                            intent,
 832                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 833                            model: request.model.clone(),
 834                            provider_request: serde_json::to_value(&request)
 835                                .map_err(|e| anyhow!(e))?,
 836                        },
 837                    )
 838                    .await?;
 839
 840                    let mut mapper = OpenAiResponseEventMapper::new();
 841                    Ok(map_cloud_completion_events(
 842                        Box::pin(response_lines(response, includes_status_messages)),
 843                        &provider_name,
 844                        move |event| mapper.map_event(event),
 845                    ))
 846                });
 847                async move { Ok(future.await?.boxed()) }.boxed()
 848            }
 849            cloud_llm_client::LanguageModelProvider::XAi => {
 850                let client = self.client.clone();
 851                let request = into_open_ai(
 852                    request,
 853                    &self.model.id.0,
 854                    self.model.supports_parallel_tool_calls,
 855                    false,
 856                    None,
 857                    None,
 858                );
 859                let llm_api_token = self.llm_api_token.clone();
 860                let future = self.request_limiter.stream(async move {
 861                    let PerformLlmCompletionResponse {
 862                        response,
 863                        includes_status_messages,
 864                    } = Self::perform_llm_completion(
 865                        client.clone(),
 866                        llm_api_token,
 867                        app_version,
 868                        CompletionBody {
 869                            thread_id,
 870                            prompt_id,
 871                            intent,
 872                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 873                            model: request.model.clone(),
 874                            provider_request: serde_json::to_value(&request)
 875                                .map_err(|e| anyhow!(e))?,
 876                        },
 877                    )
 878                    .await?;
 879
 880                    let mut mapper = OpenAiEventMapper::new();
 881                    Ok(map_cloud_completion_events(
 882                        Box::pin(response_lines(response, includes_status_messages)),
 883                        &provider_name,
 884                        move |event| mapper.map_event(event),
 885                    ))
 886                });
 887                async move { Ok(future.await?.boxed()) }.boxed()
 888            }
 889            cloud_llm_client::LanguageModelProvider::Google => {
 890                let client = self.client.clone();
 891                let request =
 892                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 893                let llm_api_token = self.llm_api_token.clone();
 894                let future = self.request_limiter.stream(async move {
 895                    let PerformLlmCompletionResponse {
 896                        response,
 897                        includes_status_messages,
 898                    } = Self::perform_llm_completion(
 899                        client.clone(),
 900                        llm_api_token,
 901                        app_version,
 902                        CompletionBody {
 903                            thread_id,
 904                            prompt_id,
 905                            intent,
 906                            provider: cloud_llm_client::LanguageModelProvider::Google,
 907                            model: request.model.model_id.clone(),
 908                            provider_request: serde_json::to_value(&request)
 909                                .map_err(|e| anyhow!(e))?,
 910                        },
 911                    )
 912                    .await?;
 913
 914                    let mut mapper = GoogleEventMapper::new();
 915                    Ok(map_cloud_completion_events(
 916                        Box::pin(response_lines(response, includes_status_messages)),
 917                        &provider_name,
 918                        move |event| mapper.map_event(event),
 919                    ))
 920                });
 921                async move { Ok(future.await?.boxed()) }.boxed()
 922            }
 923        }
 924    }
 925}
 926
 927fn map_cloud_completion_events<T, F>(
 928    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
 929    provider: &LanguageModelProviderName,
 930    mut map_callback: F,
 931) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 932where
 933    T: DeserializeOwned + 'static,
 934    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 935        + Send
 936        + 'static,
 937{
 938    let provider = provider.clone();
 939    let mut stream = stream.fuse();
 940
 941    let mut saw_stream_ended = false;
 942
 943    let mut done = false;
 944    let mut pending = VecDeque::new();
 945
 946    stream::poll_fn(move |cx| {
 947        loop {
 948            if let Some(item) = pending.pop_front() {
 949                return Poll::Ready(Some(item));
 950            }
 951
 952            if done {
 953                return Poll::Ready(None);
 954            }
 955
 956            match stream.poll_next_unpin(cx) {
 957                Poll::Ready(Some(event)) => {
 958                    let items = match event {
 959                        Err(error) => {
 960                            vec![Err(LanguageModelCompletionError::from(error))]
 961                        }
 962                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
 963                            saw_stream_ended = true;
 964                            vec![]
 965                        }
 966                        Ok(CompletionEvent::Status(status)) => {
 967                            LanguageModelCompletionEvent::from_completion_request_status(
 968                                status,
 969                                provider.clone(),
 970                            )
 971                            .transpose()
 972                            .map(|event| vec![event])
 973                            .unwrap_or_default()
 974                        }
 975                        Ok(CompletionEvent::Event(event)) => map_callback(event),
 976                    };
 977                    pending.extend(items);
 978                }
 979                Poll::Ready(None) => {
 980                    done = true;
 981
 982                    if !saw_stream_ended {
 983                        return Poll::Ready(Some(Err(
 984                            LanguageModelCompletionError::StreamEndedUnexpectedly {
 985                                provider: provider.clone(),
 986                            },
 987                        )));
 988                    }
 989                }
 990                Poll::Pending => return Poll::Pending,
 991            }
 992        }
 993    })
 994    .boxed()
 995}
 996
 997fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
 998    match provider {
 999        cloud_llm_client::LanguageModelProvider::Anthropic => {
1000            language_model::ANTHROPIC_PROVIDER_NAME
1001        }
1002        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1003        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1004        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1005    }
1006}
1007
1008fn response_lines<T: DeserializeOwned>(
1009    response: Response<AsyncBody>,
1010    includes_status_messages: bool,
1011) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1012    futures::stream::try_unfold(
1013        (String::new(), BufReader::new(response.into_body())),
1014        move |(mut line, mut body)| async move {
1015            match body.read_line(&mut line).await {
1016                Ok(0) => Ok(None),
1017                Ok(_) => {
1018                    let event = if includes_status_messages {
1019                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1020                    } else {
1021                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1022                    };
1023
1024                    line.clear();
1025                    Ok(Some((event, (line, body))))
1026                }
1027                Err(e) => Err(e.into()),
1028            }
1029        },
1030    )
1031}
1032
/// Configuration panel content for the Zed AI provider, rendered from a snapshot of
/// the user's sign-in, plan and trial state (see `ConfigurationView::render`).
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when `false` only a sign-in prompt is rendered.
    is_connected: bool,
    /// The user's current plan, if any.
    plan: Option<Plan>,
    /// Current subscription period as a pair of timestamps (the preview builds it as
    /// `(now, now + 7 days)`); only its presence is consulted when rendering.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the "Start 14-day Free Pro Trial" offer applies; `ConfigurationView`
    /// derives it from `trial_started_at().is_none()`.
    eligible_for_trial: bool,
    /// When set, a `YoungAccountBanner` and an "Upgrade to Pro" button are shown
    /// instead of the subscription details.
    account_too_young: bool,
    /// Invoked when the "Sign In to use Zed AI" button is clicked.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1042
1043impl RenderOnce for ZedAiConfiguration {
1044    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1045        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
1046        let subscription_text = match (self.plan, self.subscription_period) {
1047            (Some(Plan::ZedPro), Some(_)) => {
1048                "You have access to Zed's hosted models through your Pro subscription."
1049            }
1050            (Some(Plan::ZedProTrial), Some(_)) => {
1051                "You have access to Zed's hosted models through your Pro trial."
1052            }
1053            (Some(Plan::ZedFree), Some(_)) => {
1054                if self.eligible_for_trial {
1055                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1056                } else {
1057                    "Subscribe for access to Zed's hosted models."
1058                }
1059            }
1060            _ => {
1061                if self.eligible_for_trial {
1062                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1063                } else {
1064                    "Subscribe for access to Zed's hosted models."
1065                }
1066            }
1067        };
1068
1069        let manage_subscription_buttons = if is_pro {
1070            Button::new("manage_settings", "Manage Subscription")
1071                .full_width()
1072                .style(ButtonStyle::Tinted(TintColor::Accent))
1073                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1074                .into_any_element()
1075        } else if self.plan.is_none() || self.eligible_for_trial {
1076            Button::new("start_trial", "Start 14-day Free Pro Trial")
1077                .full_width()
1078                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1079                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1080                .into_any_element()
1081        } else {
1082            Button::new("upgrade", "Upgrade to Pro")
1083                .full_width()
1084                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1085                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1086                .into_any_element()
1087        };
1088
1089        if !self.is_connected {
1090            return v_flex()
1091                .gap_2()
1092                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1093                .child(
1094                    Button::new("sign_in", "Sign In to use Zed AI")
1095                        .icon_color(Color::Muted)
1096                        .icon(IconName::Github)
1097                        .icon_size(IconSize::Small)
1098                        .icon_position(IconPosition::Start)
1099                        .full_width()
1100                        .on_click({
1101                            let callback = self.sign_in_callback.clone();
1102                            move |_, window, cx| (callback)(window, cx)
1103                        }),
1104                );
1105        }
1106
1107        v_flex().gap_2().w_full().map(|this| {
1108            if self.account_too_young {
1109                this.child(YoungAccountBanner).child(
1110                    Button::new("upgrade", "Upgrade to Pro")
1111                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1112                        .full_width()
1113                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1114                )
1115            } else {
1116                this.text_sm()
1117                    .child(subscription_text)
1118                    .child(manage_subscription_buttons)
1119            }
1120        })
1121    }
1122}
1123
/// GPUI view that renders [`ZedAiConfiguration`] from the live provider `State`.
struct ConfigurationView {
    /// Provider state; read on every render to obtain sign-in status and the user store.
    state: Entity<State>,
    /// Built once in `ConfigurationView::new`; calls `State::authenticate` and
    /// detaches the resulting task via `detach_and_log_err`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1128
1129impl ConfigurationView {
1130    fn new(state: Entity<State>) -> Self {
1131        let sign_in_callback = Arc::new({
1132            let state = state.clone();
1133            move |_window: &mut Window, cx: &mut App| {
1134                state.update(cx, |state, cx| {
1135                    state.authenticate(cx).detach_and_log_err(cx);
1136                });
1137            }
1138        });
1139
1140        Self {
1141            state,
1142            sign_in_callback,
1143        }
1144    }
1145}
1146
1147impl Render for ConfigurationView {
1148    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1149        let state = self.state.read(cx);
1150        let user_store = state.user_store.read(cx);
1151
1152        ZedAiConfiguration {
1153            is_connected: !state.is_signed_out(cx),
1154            plan: user_store.plan(),
1155            subscription_period: user_store.subscription_period(),
1156            eligible_for_trial: user_store.trial_started_at().is_none(),
1157            account_too_young: user_store.account_too_young(),
1158            sign_in_callback: self.sign_in_callback.clone(),
1159        }
1160    }
1161}
1162
1163impl Component for ZedAiConfiguration {
1164    fn name() -> &'static str {
1165        "AI Configuration Content"
1166    }
1167
1168    fn sort_name() -> &'static str {
1169        "AI Configuration Content"
1170    }
1171
1172    fn scope() -> ComponentScope {
1173        ComponentScope::Onboarding
1174    }
1175
1176    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1177        fn configuration(
1178            is_connected: bool,
1179            plan: Option<Plan>,
1180            eligible_for_trial: bool,
1181            account_too_young: bool,
1182        ) -> AnyElement {
1183            ZedAiConfiguration {
1184                is_connected,
1185                plan,
1186                subscription_period: plan
1187                    .is_some()
1188                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1189                eligible_for_trial,
1190                account_too_young,
1191                sign_in_callback: Arc::new(|_, _| {}),
1192            }
1193            .into_any_element()
1194        }
1195
1196        Some(
1197            v_flex()
1198                .p_4()
1199                .gap_4()
1200                .children(vec![
1201                    single_example("Not connected", configuration(false, None, false, false)),
1202                    single_example(
1203                        "Accept Terms of Service",
1204                        configuration(true, None, true, false),
1205                    ),
1206                    single_example(
1207                        "No Plan - Not eligible for trial",
1208                        configuration(true, None, false, false),
1209                    ),
1210                    single_example(
1211                        "No Plan - Eligible for trial",
1212                        configuration(true, None, true, false),
1213                    ),
1214                    single_example(
1215                        "Free Plan",
1216                        configuration(true, Some(Plan::ZedFree), true, false),
1217                    ),
1218                    single_example(
1219                        "Zed Pro Trial Plan",
1220                        configuration(true, Some(Plan::ZedProTrial), true, false),
1221                    ),
1222                    single_example(
1223                        "Zed Pro Plan",
1224                        configuration(true, Some(Plan::ZedPro), true, false),
1225                    ),
1226                ])
1227                .into_any_element(),
1228        )
1229    }
1230}
1231
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Converts `body` as though Zed's cloud API had returned it with a
    /// `500 Internal Server Error` status and no headers.
    fn convert_internal_server_error(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // `upstream_http_error` with upstream status 503 is surfaced as
        // UpstreamProviderError carrying the upstream message verbatim.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // `upstream_http_error` with upstream status 500 is also surfaced as
        // UpstreamProviderError (not ApiInternalServerError).
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // `upstream_http_error` with upstream status 429 is also surfaced as
        // UpstreamProviderError (not RateLimitExceeded).
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain-text 500 body (no cloud error payload) stays an
        // ApiInternalServerError attributed to Zed's own provider.
        let completion_error = convert_internal_server_error("Regular internal server error");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The `upstream_http_429` code becomes UpstreamProviderError with a 429 status
        // and the `retry_after` seconds carried over as a Duration.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // A body that is not JSON falls back to status-based handling:
        // ApiInternalServerError attributed to Zed's own provider.
        let completion_error = convert_internal_server_error("Not JSON at all");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}