cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
use cloud_api_types::Plan;
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
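
/// Selects between a model's default behavior and extended "thinking", where
/// the provider supports reasoning with a separate token budget.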
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
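
// With the internally tagged representation above, `ModelMode::Thinking {
// budget_tokens: Some(4096) }` serializes as
// `{"type":"thinking","budget_tokens":4096}`; the round-trip test in `tests`
// below exercises this shape.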

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

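/// Observable state shared by the provider and its models: the signed-in
/// client, the cached LLM API token, and the model lists fetched from Zed's
/// cloud API.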
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let models: Vec<Arc<cloud_llm_client::LanguageModel>> =
            response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

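    /// Fetches the latest model list from the cloud API's `/models` endpoint,
    /// authenticating with the current LLM token.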
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

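/// A single Zed-hosted model. Completions are proxied through Zed's cloud API
/// rather than sent to the upstream provider directly.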
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
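    /// Sends a completion request to the cloud API's `/completions` endpoint,
    /// refreshing the LLM token and retrying once if the server reports that
    /// the token has expired.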
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers,
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

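/// Deserializes an optional numeric status code, mapping values that are not
/// valid HTTP status codes to `None` rather than failing deserialization.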
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The status code is embedded in the code string itself
                    // (e.g. "upstream_http_429"); parse it, falling back to
                    // the response status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
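
// A sketch of the mapping above: an `ApiError` whose body is
// `{"code":"upstream_http_429","message":"...","retry_after":30.5}` surfaces
// as `UpstreamProviderError { status: 429, retry_after: Some(30.5s), .. }`,
// letting callers retry against the upstream provider's status rather than
// the status Zed's proxy returned. The tests at the bottom of this file
// exercise these cases.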

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers,
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

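/// Adapts a stream of raw cloud completion events into
/// `LanguageModelCompletionEvent`s, applying `map_callback` to provider
/// events and translating status messages. Mapped events are buffered in a
/// queue so that each poll of the returned stream yields exactly one item.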
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
    // let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                            // saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                    //
                    // if !saw_stream_ended {
                    //     return Poll::Ready(Some(Err(
                    //         LanguageModelCompletionError::StreamEndedUnexpectedly {
                    //             provider: provider.clone(),
                    //         },
                    //     )));
                    // }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

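/// Decodes a response body as newline-delimited JSON. When the server
/// advertised status-message support, each line is a full `CompletionEvent`;
/// otherwise each line is a bare provider event wrapped in
/// `CompletionEvent::Event`.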
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

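/// Contents of the Zed AI configuration panel, driven by the user's
/// connection state and subscription plan. Also registered as a component so
/// its variants can be previewed (see `preview` below).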
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_button = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to access Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_button)
            }
        })
    }
}

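/// Entity-backed view that renders `ZedAiConfiguration` from live `State` and
/// `UserStore` data.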
struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should likewise surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
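
    // Sketch: `ModelMode`'s internally tagged serde representation, implied
    // by the `#[serde(tag = "type", rename_all = "lowercase")]` attributes on
    // the enum.
    #[test]
    fn test_model_mode_round_trips_through_serde() {
        let mode = ModelMode::Thinking {
            budget_tokens: Some(4_096),
        };
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, r#"{"type":"thinking","budget_tokens":4096}"#);
        assert_eq!(serde_json::from_str::<ModelMode>(&json).unwrap(), mode);
    }

    // Sketch: `deserialize_optional_status_code` should map numbers that fit
    // in a u16 but are not valid HTTP status codes to `None`, and a missing
    // field to `None` via `#[serde(default)]`.
    #[test]
    fn test_invalid_upstream_status_deserializes_to_none() {
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"oops","upstream_status":1000}"#,
        )
        .unwrap();
        assert!(error.upstream_status.is_none());

        let error: CloudApiError =
            serde_json::from_str(r#"{"code":"upstream_http_error","message":"oops"}"#).unwrap();
        assert!(error.upstream_status.is_none());
    }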
}