cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
    PlanV1, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}
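
// For reference, a hypothetical settings entry that would deserialize into `AvailableModel`
// with a thinking-mode override. Field names follow the serde attributes above; the model
// name and numbers are illustrative only, not a recommendation:
//
//     {
//       "provider": "anthropic",
//       "name": "claude-3-7-sonnet-latest",
//       "display_name": "Claude 3.7 Sonnet Thinking",
//       "max_tokens": 200000,
//       "default_temperature": 1.0,
//       "mode": { "type": "thinking", "budget_tokens": 4096 }
//     }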

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
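    /// Performs a completion request against Zed's LLM endpoint, returning the streaming
    /// response along with metadata parsed from the response headers.
    ///
    /// A short summary of the error handling below: an expired LLM token is refreshed once
    /// and the request retried; a 403 whose subscription-limit header points at model
    /// requests is reported as [`ModelRequestLimitReachedError`]; a 402 is reported as
    /// [`PaymentRequiredError`]; any other failure is returned as an [`ApiError`].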
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
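///
/// Example JSON for a rate-limit error, where the status is carried in the error code itself
/// and a retry hint is included (this mirrors the fixture used in the tests below):
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```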
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

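/// Maps an [`ApiError`] to a [`LanguageModelCompletionError`].
///
/// A summary of the conversion below: if the body parses as a [`CloudApiError`] with an
/// `upstream_http_*` code, the error is surfaced as an upstream-provider error, taking the
/// status from `upstream_status` when present, otherwise from the numeric suffix of the code
/// (e.g. `upstream_http_429`), and otherwise from the outer response status. Anything else
/// is mapped directly from the outer status and body.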
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429"),
                    // use that; otherwise, fall back to the status of the response itself.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    model.supports_prompt_cache_key(),
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

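/// Streams the newline-delimited JSON body of a completion response.
///
/// When the server indicated support for status messages, each line is parsed as a full
/// [`CompletionEvent`] envelope; otherwise each line is a bare provider event and is wrapped
/// in [`CompletionEvent::Event`].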
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let young_account_banner = YoungAccountBanner;

        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to access Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(young_account_banner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should surface as UpstreamProviderError as well
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
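
    // A sketch of an additional test covering `deserialize_optional_status_code` via
    // `CloudApiError`, using the JSON shapes documented above: a valid upstream status is
    // preserved, while a missing or out-of-range status becomes `None`.
    #[test]
    fn test_deserialize_optional_status_code() {
        // A present, valid upstream status is kept.
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"unavailable","upstream_status":503}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, Some(StatusCode::SERVICE_UNAVAILABLE));

        // A missing upstream status falls back to `None` via `#[serde(default)]`.
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_429","message":"rate limited","retry_after":30.5}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, None);
        assert_eq!(error.retry_after, Some(30.5));

        // An out-of-range code is dropped rather than failing deserialization.
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"bad status","upstream_status":42}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, None);
    }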
}