// cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, ModelRequestUsage, UserStore, zed_urls};
   6use cloud_llm_client::{
   7    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
   8    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
   9    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
  10    PlanV1, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
  11    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
  12    ZED_VERSION_HEADER_NAME,
  13};
  14use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt};
  15use futures::{
  16    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{
  20    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
  21};
  22use http_client::http::{HeaderMap, HeaderValue};
  23use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
  24use language_model::{
  25    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  26    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  27    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  28    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  29    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
  30    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  31};
  32use release_channel::AppVersion;
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize, de::DeserializeOwned};
  35use settings::SettingsStore;
  36use smol::io::{AsyncReadExt, BufReader};
  37use std::pin::Pin;
  38use std::str::FromStr as _;
  39use std::sync::Arc;
  40use std::time::Duration;
  41use thiserror::Error;
  42use ui::{TintColor, prelude::*};
  43use util::{ResultExt as _, maybe};
  44
  45use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
  46use crate::provider::google::{GoogleEventMapper, into_google};
  47use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
  48
/// Provider identifier shared with the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Human-readable provider name shared with the `language_model` crate.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  51
/// User settings for the `zed.dev` language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Custom models configured by the user for this provider.
    // NOTE(review): not referenced in this chunk — presumably consumed elsewhere; confirm.
    pub available_models: Vec<AvailableModel>,
}
  56
/// The upstream provider backing a user-configured [`AvailableModel`].
///
/// Serialized in lowercase (e.g. `"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
  64
/// A custom language model entry, as configured in the user's settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}
  91
/// The reasoning mode of a model, serialized with a `type` tag
/// (e.g. `{"type": "thinking"}`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
 102
 103impl From<ModelMode> for AnthropicModelMode {
 104    fn from(value: ModelMode) -> Self {
 105        match value {
 106            ModelMode::Default => AnthropicModelMode::Default,
 107            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
 108        }
 109    }
 110}
 111
/// Language model provider backed by Zed's cloud LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    // Observable provider state (auth status and fetched model lists).
    state: gpui::Entity<State>,
    // Background task keeping `state.status` in sync with the client's
    // connection status; dropped together with the provider.
    _maintain_client_status: Task<()>,
}
 117
/// Observable state for the cloud provider: connection status plus the
/// model lists fetched from the server.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // All models, including synthesized "-thinking" variants.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    // One-shot task: waits for sign-in, then fetches the model list.
    _fetch_models_task: Task<()>,
    // Re-notifies observers whenever global settings change.
    _settings_subscription: Subscription,
    // Refreshes the LLM token and re-fetches models on token-refresh events.
    _llm_token_subscription: Subscription,
}
 131
impl State {
    /// Builds the provider state and kicks off the background tasks that
    /// keep the model list up to date.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            // Wait until a user is signed in, then fetch models once.
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Block until `watch_current_user` yields a signed-in user.
                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // On token-refresh events: refresh the token, then re-fetch models.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether no user is currently signed in.
    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    /// Signs the user in (connecting if necessary) and notifies observers.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }
    /// Rebuilds the model lists from a server response, synthesizing
    /// "-thinking" variants for models that support thinking.
    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        // Resolve the server's default/fast/recommended model ids against the
        // expanded model list.
        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    /// Fetches the available model list from the Zed LLM API's `/models`
    /// endpoint, authenticated with the LLM token.
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            // Include the response body in the error to aid debugging.
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}
 276
impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a task mirroring the client's
    /// connection status into [`State`].
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        // Hold only a weak handle so this task does not keep the state
        // entity alive after the provider is dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        // Notify observers only on actual transitions.
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    // State entity was dropped; stop watching.
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    /// Wraps a server-listed model in a `CloudLanguageModel` handle.
    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            // Allow up to 4 concurrent requests per model handle.
            request_limiter: RateLimiter::new(4),
        })
    }
}
 321
/// Exposes [`State`] so the UI can observe provider changes.
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 329
 330impl LanguageModelProvider for CloudLanguageModelProvider {
 331    fn id(&self) -> LanguageModelProviderId {
 332        PROVIDER_ID
 333    }
 334
 335    fn name(&self) -> LanguageModelProviderName {
 336        PROVIDER_NAME
 337    }
 338
 339    fn icon(&self) -> IconName {
 340        IconName::AiZed
 341    }
 342
 343    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 344        let default_model = self.state.read(cx).default_model.clone()?;
 345        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 346        Some(self.create_language_model(default_model, llm_api_token))
 347    }
 348
 349    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 350        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 351        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 352        Some(self.create_language_model(default_fast_model, llm_api_token))
 353    }
 354
 355    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 356        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 357        self.state
 358            .read(cx)
 359            .recommended_models
 360            .iter()
 361            .cloned()
 362            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 363            .collect()
 364    }
 365
 366    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 367        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 368        self.state
 369            .read(cx)
 370            .models
 371            .iter()
 372            .cloned()
 373            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 374            .collect()
 375    }
 376
 377    fn is_authenticated(&self, cx: &App) -> bool {
 378        let state = self.state.read(cx);
 379        !state.is_signed_out(cx)
 380    }
 381
 382    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 383        Task::ready(Ok(()))
 384    }
 385
 386    fn configuration_view(
 387        &self,
 388        _target_agent: language_model::ConfigurationViewTargetAgent,
 389        _: &mut Window,
 390        cx: &mut App,
 391    ) -> AnyView {
 392        cx.new(|_| ConfigurationView::new(self.state.clone()))
 393            .into()
 394    }
 395
 396    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 397        Task::ready(Ok(()))
 398    }
 399}
 400
/// A handle to a single cloud-hosted model, routing requests through
/// the Zed LLM API.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Bounds concurrent in-flight requests for this model handle.
    request_limiter: RateLimiter,
}
 408
/// Successful result of `perform_llm_completion`: the raw streaming
/// response plus metadata decoded from response headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages (usage then arrives in-stream) or parsing fails.
    usage: Option<ModelRequestUsage>,
    // Set when the server signals the tool-use limit was reached.
    tool_use_limit_reached: bool,
    // Set when the server will interleave status messages in the stream.
    includes_status_messages: bool,
}
 415
impl CloudLanguageModel {
    /// Sends a completion request to the Zed LLM API's `/completions`
    /// endpoint.
    ///
    /// Retries exactly once after refreshing the LLM token if the server
    /// marks it expired; maps plan-limit and payment failures to their
    /// dedicated error types, and any other failure to [`ApiError`].
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            // Only attach the version header when the app version is known.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives
                // in-stream rather than via headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Expired LLM token: refresh it and retry the request once.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // Const pattern: only matches when the limited resource is
                // model requests; also requires a parseable V1 plan header.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, headers, and body.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 515
/// Error for a non-success HTTP response from the cloud LLM API,
/// carrying the status, raw body, and response headers.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 523
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    // Machine-readable code, e.g. "upstream_http_error" or "upstream_http_429".
    code: String,
    // Human-readable description of the failure.
    message: String,
    // Status code reported by the upstream provider, when present.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    // Suggested retry delay in seconds, when present.
    #[serde(default)]
    retry_after: Option<f64>,
}
 544
 545fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 546where
 547    D: serde::Deserializer<'de>,
 548{
 549    let opt: Option<u16> = Option::deserialize(deserializer)?;
 550    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 551}
 552
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        // Prefer the structured error body when the server returned one.
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                // Resolve the upstream status, preferring the explicit JSON
                // field, then the code suffix, then our own response status.
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    // Generic "upstream_http_error" with no upstream status:
                    // fall back to the status we received.
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, see if the JSON contains a status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            // Structured but not an upstream error: classify by our status,
            // keeping the server-provided message.
            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        // Unstructured body: classify purely by HTTP status.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
 596
 597impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    // Display name shown in model pickers.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }
 613
    /// The id of the provider actually serving this model behind Zed's cloud.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }
 622
    /// The name of the provider actually serving this model behind Zed's cloud.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }
 631
    // Capability flags are reported by the server per model.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }
 639
 640    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 641        match choice {
 642            LanguageModelToolChoice::Auto
 643            | LanguageModelToolChoice::Any
 644            | LanguageModelToolChoice::None => true,
 645        }
 646    }
 647
    // "Burn mode" maps to the server's max-mode capability flag.
    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    /// Identifier used for telemetry, namespaced under `zed.dev/`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
 655
 656    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 657        match self.model.provider {
 658            cloud_llm_client::LanguageModelProvider::Anthropic
 659            | cloud_llm_client::LanguageModelProvider::OpenAi => {
 660                LanguageModelToolSchemaFormat::JsonSchema
 661            }
 662            cloud_llm_client::LanguageModelProvider::Google => {
 663                LanguageModelToolSchemaFormat::JsonSchemaSubset
 664            }
 665        }
 666    }
 667
    // Context-window size reported by the server for this model.
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }
 671
 672    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
 673        self.model
 674            .max_token_count_in_max_mode
 675            .filter(|_| self.model.supports_max_mode)
 676            .map(|max_token_count| max_token_count as u64)
 677    }
 678
 679    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 680        match &self.model.provider {
 681            cloud_llm_client::LanguageModelProvider::Anthropic => {
 682                Some(LanguageModelCacheConfiguration {
 683                    min_total_token: 2_048,
 684                    should_speculate: true,
 685                    max_cache_anchors: 4,
 686                })
 687            }
 688            cloud_llm_client::LanguageModelProvider::OpenAi
 689            | cloud_llm_client::LanguageModelProvider::Google => None,
 690        }
 691    }
 692
    /// Counts the tokens in `request`. Anthropic and OpenAI counts are
    /// computed client-side; Google counts go through the Zed LLM API's
    /// `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                // Unknown OpenAI model ids surface as an immediate error.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    // Wrap the Google request so the server can forward it.
                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 762
 763    fn stream_completion(
 764        &self,
 765        request: LanguageModelRequest,
 766        cx: &AsyncApp,
 767    ) -> BoxFuture<
 768        'static,
 769        Result<
 770            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 771            LanguageModelCompletionError,
 772        >,
 773    > {
 774        let thread_id = request.thread_id.clone();
 775        let prompt_id = request.prompt_id.clone();
 776        let intent = request.intent;
 777        let mode = request.mode;
 778        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
 779        let thinking_allowed = request.thinking_allowed;
 780        match self.model.provider {
 781            cloud_llm_client::LanguageModelProvider::Anthropic => {
 782                let request = into_anthropic(
 783                    request,
 784                    self.model.id.to_string(),
 785                    1.0,
 786                    self.model.max_output_tokens as u64,
 787                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
 788                        AnthropicModelMode::Thinking {
 789                            budget_tokens: Some(4_096),
 790                        }
 791                    } else {
 792                        AnthropicModelMode::Default
 793                    },
 794                );
 795                let client = self.client.clone();
 796                let llm_api_token = self.llm_api_token.clone();
 797                let future = self.request_limiter.stream(async move {
 798                    let PerformLlmCompletionResponse {
 799                        response,
 800                        usage,
 801                        includes_status_messages,
 802                        tool_use_limit_reached,
 803                    } = Self::perform_llm_completion(
 804                        client.clone(),
 805                        llm_api_token,
 806                        app_version,
 807                        CompletionBody {
 808                            thread_id,
 809                            prompt_id,
 810                            intent,
 811                            mode,
 812                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 813                            model: request.model.clone(),
 814                            provider_request: serde_json::to_value(&request)
 815                                .map_err(|e| anyhow!(e))?,
 816                        },
 817                    )
 818                    .await
 819                    .map_err(|err| match err.downcast::<ApiError>() {
 820                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 821                        Err(err) => anyhow!(err),
 822                    })?;
 823
 824                    let mut mapper = AnthropicEventMapper::new();
 825                    Ok(map_cloud_completion_events(
 826                        Box::pin(
 827                            response_lines(response, includes_status_messages)
 828                                .chain(usage_updated_event(usage))
 829                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 830                        ),
 831                        move |event| mapper.map_event(event),
 832                    ))
 833                });
 834                async move { Ok(future.await?.boxed()) }.boxed()
 835            }
 836            cloud_llm_client::LanguageModelProvider::OpenAi => {
 837                let client = self.client.clone();
 838                let model = match open_ai::Model::from_id(&self.model.id.0) {
 839                    Ok(model) => model,
 840                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
 841                };
 842                let request = into_open_ai(
 843                    request,
 844                    model.id(),
 845                    model.supports_parallel_tool_calls(),
 846                    model.supports_prompt_cache_key(),
 847                    None,
 848                    None,
 849                );
 850                let llm_api_token = self.llm_api_token.clone();
 851                let future = self.request_limiter.stream(async move {
 852                    let PerformLlmCompletionResponse {
 853                        response,
 854                        usage,
 855                        includes_status_messages,
 856                        tool_use_limit_reached,
 857                    } = Self::perform_llm_completion(
 858                        client.clone(),
 859                        llm_api_token,
 860                        app_version,
 861                        CompletionBody {
 862                            thread_id,
 863                            prompt_id,
 864                            intent,
 865                            mode,
 866                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 867                            model: request.model.clone(),
 868                            provider_request: serde_json::to_value(&request)
 869                                .map_err(|e| anyhow!(e))?,
 870                        },
 871                    )
 872                    .await?;
 873
 874                    let mut mapper = OpenAiEventMapper::new();
 875                    Ok(map_cloud_completion_events(
 876                        Box::pin(
 877                            response_lines(response, includes_status_messages)
 878                                .chain(usage_updated_event(usage))
 879                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 880                        ),
 881                        move |event| mapper.map_event(event),
 882                    ))
 883                });
 884                async move { Ok(future.await?.boxed()) }.boxed()
 885            }
 886            cloud_llm_client::LanguageModelProvider::Google => {
 887                let client = self.client.clone();
 888                let request =
 889                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 890                let llm_api_token = self.llm_api_token.clone();
 891                let future = self.request_limiter.stream(async move {
 892                    let PerformLlmCompletionResponse {
 893                        response,
 894                        usage,
 895                        includes_status_messages,
 896                        tool_use_limit_reached,
 897                    } = Self::perform_llm_completion(
 898                        client.clone(),
 899                        llm_api_token,
 900                        app_version,
 901                        CompletionBody {
 902                            thread_id,
 903                            prompt_id,
 904                            intent,
 905                            mode,
 906                            provider: cloud_llm_client::LanguageModelProvider::Google,
 907                            model: request.model.model_id.clone(),
 908                            provider_request: serde_json::to_value(&request)
 909                                .map_err(|e| anyhow!(e))?,
 910                        },
 911                    )
 912                    .await?;
 913
 914                    let mut mapper = GoogleEventMapper::new();
 915                    Ok(map_cloud_completion_events(
 916                        Box::pin(
 917                            response_lines(response, includes_status_messages)
 918                                .chain(usage_updated_event(usage))
 919                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 920                        ),
 921                        move |event| mapper.map_event(event),
 922                    ))
 923                });
 924                async move { Ok(future.await?.boxed()) }.boxed()
 925            }
 926        }
 927    }
 928}
 929
 930fn map_cloud_completion_events<T, F>(
 931    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
 932    mut map_callback: F,
 933) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 934where
 935    T: DeserializeOwned + 'static,
 936    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 937        + Send
 938        + 'static,
 939{
 940    stream
 941        .flat_map(move |event| {
 942            futures::stream::iter(match event {
 943                Err(error) => {
 944                    vec![Err(LanguageModelCompletionError::from(error))]
 945                }
 946                Ok(CompletionEvent::Status(event)) => {
 947                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
 948                }
 949                Ok(CompletionEvent::Event(event)) => map_callback(event),
 950            })
 951        })
 952        .boxed()
 953}
 954
 955fn usage_updated_event<T>(
 956    usage: Option<ModelRequestUsage>,
 957) -> impl Stream<Item = Result<CompletionEvent<T>>> {
 958    futures::stream::iter(usage.map(|usage| {
 959        Ok(CompletionEvent::Status(
 960            CompletionRequestStatus::UsageUpdated {
 961                amount: usage.amount as usize,
 962                limit: usage.limit,
 963            },
 964        ))
 965    }))
 966}
 967
 968fn tool_use_limit_reached_event<T>(
 969    tool_use_limit_reached: bool,
 970) -> impl Stream<Item = Result<CompletionEvent<T>>> {
 971    futures::stream::iter(tool_use_limit_reached.then(|| {
 972        Ok(CompletionEvent::Status(
 973            CompletionRequestStatus::ToolUseLimitReached,
 974        ))
 975    }))
 976}
 977
 978fn response_lines<T: DeserializeOwned>(
 979    response: Response<AsyncBody>,
 980    includes_status_messages: bool,
 981) -> impl Stream<Item = Result<CompletionEvent<T>>> {
 982    futures::stream::try_unfold(
 983        (String::new(), BufReader::new(response.into_body())),
 984        move |(mut line, mut body)| async move {
 985            match body.read_line(&mut line).await {
 986                Ok(0) => Ok(None),
 987                Ok(_) => {
 988                    let event = if includes_status_messages {
 989                        serde_json::from_str::<CompletionEvent<T>>(&line)?
 990                    } else {
 991                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
 992                    };
 993
 994                    line.clear();
 995                    Ok(Some((event, (line, body))))
 996                }
 997                Err(e) => Err(e.into()),
 998            }
 999        },
1000    )
1001}
1002
/// Configuration UI for Zed-hosted AI models: shows sign-in, plan status,
/// and subscription-management actions.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is currently signed in.
    is_connected: bool,
    // The user's billing plan, if known.
    plan: Option<Plan>,
    // Start/end of the current subscription period, if any.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // Whether the user can still start a Pro trial.
    eligible_for_trial: bool,
    // Whether the account is considered too new to use hosted models
    // (threshold presumably enforced server-side — confirm with caller).
    account_too_young: bool,
    // Invoked when the user clicks the sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1012
impl RenderOnce for ZedAiConfiguration {
    /// Renders the configuration panel: a sign-in prompt when disconnected,
    /// otherwise plan status text plus a subscription action button (or a
    /// young-account banner when the account is too new).
    fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
        // Pro status is recognized for both plan schema versions.
        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let is_free_v2 = self
            .plan
            .is_some_and(|plan| plan == Plan::V2(PlanV2::ZedFree));
        // Note the V1/V2 asymmetry: V1 Free describes existing access, while
        // V2 Free (and no plan at all) prompts the user to subscribe.
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        // Pro users manage their existing subscription; everyone else is
        // offered either a trial (if eligible or plan unknown) or an upgrade.
        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        // Disconnected users only see the sign-in call to action.
        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                // The banner variant depends on whether the user is on the V2
                // free plan or has the billing-v2 feature flag enabled.
                this.child(YoungAccountBanner::new(
                    is_free_v2 || cx.has_flag::<BillingV2FeatureFlag>(),
                ))
                .child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}
1104
/// Entity-backed view that renders `ZedAiConfiguration` from live user state.
struct ConfigurationView {
    // Shared provider state (holds the user store and drives authentication).
    state: Entity<State>,
    // Sign-in handler passed through to the rendered configuration panel.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1109
1110impl ConfigurationView {
1111    fn new(state: Entity<State>) -> Self {
1112        let sign_in_callback = Arc::new({
1113            let state = state.clone();
1114            move |_window: &mut Window, cx: &mut App| {
1115                state.update(cx, |state, cx| {
1116                    state.authenticate(cx).detach_and_log_err(cx);
1117                });
1118            }
1119        });
1120
1121        Self {
1122            state,
1123            sign_in_callback,
1124        }
1125    }
1126}
1127
1128impl Render for ConfigurationView {
1129    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1130        let state = self.state.read(cx);
1131        let user_store = state.user_store.read(cx);
1132
1133        ZedAiConfiguration {
1134            is_connected: !state.is_signed_out(cx),
1135            plan: user_store.plan(),
1136            subscription_period: user_store.subscription_period(),
1137            eligible_for_trial: user_store.trial_started_at().is_none(),
1138            account_too_young: user_store.account_too_young(),
1139            sign_in_callback: self.sign_in_callback.clone(),
1140        }
1141    }
1142}
1143
1144impl Component for ZedAiConfiguration {
1145    fn name() -> &'static str {
1146        "AI Configuration Content"
1147    }
1148
1149    fn sort_name() -> &'static str {
1150        "AI Configuration Content"
1151    }
1152
1153    fn scope() -> ComponentScope {
1154        ComponentScope::Onboarding
1155    }
1156
1157    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1158        fn configuration(
1159            is_connected: bool,
1160            plan: Option<Plan>,
1161            eligible_for_trial: bool,
1162            account_too_young: bool,
1163        ) -> AnyElement {
1164            ZedAiConfiguration {
1165                is_connected,
1166                plan,
1167                subscription_period: plan
1168                    .is_some()
1169                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1170                eligible_for_trial,
1171                account_too_young,
1172                sign_in_callback: Arc::new(|_, _| {}),
1173            }
1174            .into_any_element()
1175        }
1176
1177        Some(
1178            v_flex()
1179                .p_4()
1180                .gap_4()
1181                .children(vec![
1182                    single_example("Not connected", configuration(false, None, false, false)),
1183                    single_example(
1184                        "Accept Terms of Service",
1185                        configuration(true, None, true, false),
1186                    ),
1187                    single_example(
1188                        "No Plan - Not eligible for trial",
1189                        configuration(true, None, false, false),
1190                    ),
1191                    single_example(
1192                        "No Plan - Eligible for trial",
1193                        configuration(true, None, true, false),
1194                    ),
1195                    single_example(
1196                        "Free Plan",
1197                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
1198                    ),
1199                    single_example(
1200                        "Zed Pro Trial Plan",
1201                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
1202                    ),
1203                    single_example(
1204                        "Zed Pro Plan",
1205                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
1206                    ),
1207                ])
1208                .into_any_element(),
1209        )
1210    }
1211}
1212
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds a synthetic HTTP 500 `ApiError` with the given body and runs it
    /// through the `LanguageModelCompletionError` conversion under test.
    fn error_from_500_body(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with 503 status should become UpstreamProviderError
        let completion_error = error_from_500_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with 500 status should become UpstreamProviderError
        let completion_error = error_from_500_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with 429 status should become UpstreamProviderError
        let completion_error = error_from_500_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let completion_error = error_from_500_body("Regular internal server error");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let completion_error = error_from_500_body(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let completion_error = error_from_500_body("Not JSON at all");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}