cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, UserStore, zed_urls};
   6use cloud_api_types::{OrganizationId, Plan};
   7use cloud_llm_client::{
   8    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   9    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
  10    CountTokensBody, CountTokensResponse, ListModelsResponse,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  12};
  13use futures::{
  14    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  15    future::BoxFuture,
  16    stream::{self, BoxStream},
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  25    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  26    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  27    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::collections::VecDeque;
  39use std::pin::Pin;
  40use std::str::FromStr;
  41use std::sync::Arc;
  42use std::task::Poll;
  43use std::time::Duration;
  44use thiserror::Error;
  45use ui::{TintColor, prelude::*};
  46use util::{ResultExt as _, maybe};
  47
  48use crate::provider::anthropic::{
  49    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  50};
  51use crate::provider::google::{GoogleEventMapper, into_google};
  52use crate::provider::open_ai::{
  53    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  54    into_open_ai_response,
  55};
  56use crate::provider::x_ai::count_xai_tokens;
  57
// Canonical identity of the Zed cloud provider, re-exported from the
// `language_model` crate so every reference stays consistent.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  60
/// Settings for the `zed.dev` language-model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // NOTE(review): not consumed in this chunk — presumably user-configured
    // models merged with the server-provided list; confirm against callers.
    pub available_models: Vec<AvailableModel>,
}
/// Reasoning mode of a model, serialized with a lowercase `"type"` tag
/// (e.g. `{"type": "thinking", "budget_tokens": 4096}`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Plain completion mode (no extended reasoning).
    #[default]
    Default,
    /// Extended-thinking mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  75
  76impl From<ModelMode> for AnthropicModelMode {
  77    fn from(value: ModelMode) -> Self {
  78        match value {
  79            ModelMode::Default => AnthropicModelMode::Default,
  80            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  81        }
  82    }
  83}
  84
/// Language-model provider backed by Zed's cloud LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    // Background task mirroring the client's connection status into `state`;
    // dropping the provider drops (and thus stops) the task.
    _maintain_client_status: Task<()>,
}
  90
/// Observable provider state: the model list fetched from the cloud service,
/// plus the token and subscriptions that keep it fresh.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // Models as returned by the server's `/models` endpoint.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    // Server-designated defaults, resolved against `models`; `None` when the
    // server specified none or the id wasn't in the returned list.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    // Held so the background work/subscriptions are cancelled on drop.
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 104
 105impl State {
 106    fn new(
 107        client: Arc<Client>,
 108        user_store: Entity<UserStore>,
 109        status: client::Status,
 110        cx: &mut Context<Self>,
 111    ) -> Self {
 112        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 113        let mut current_user = user_store.read(cx).watch_current_user();
 114        Self {
 115            client: client.clone(),
 116            llm_api_token: LlmApiToken::default(),
 117            user_store,
 118            status,
 119            models: Vec::new(),
 120            default_model: None,
 121            default_fast_model: None,
 122            recommended_models: Vec::new(),
 123            _fetch_models_task: cx.spawn(async move |this, cx| {
 124                maybe!(async move {
 125                    let (client, llm_api_token, organization_id) =
 126                        this.read_with(cx, |this, cx| {
 127                            (
 128                                client.clone(),
 129                                this.llm_api_token.clone(),
 130                                this.user_store
 131                                    .read(cx)
 132                                    .current_organization()
 133                                    .map(|o| o.id.clone()),
 134                            )
 135                        })?;
 136
 137                    while current_user.borrow().is_none() {
 138                        current_user.next().await;
 139                    }
 140
 141                    let response =
 142                        Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id)
 143                            .await?;
 144                    this.update(cx, |this, cx| this.update_models(response, cx))?;
 145                    anyhow::Ok(())
 146                })
 147                .await
 148                .context("failed to fetch Zed models")
 149                .log_err();
 150            }),
 151            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 152                cx.notify();
 153            }),
 154            _llm_token_subscription: cx.subscribe(
 155                &refresh_llm_token_listener,
 156                move |this, _listener, _event, cx| {
 157                    let client = this.client.clone();
 158                    let llm_api_token = this.llm_api_token.clone();
 159                    let organization_id = this
 160                        .user_store
 161                        .read(cx)
 162                        .current_organization()
 163                        .map(|o| o.id.clone());
 164                    cx.spawn(async move |this, cx| {
 165                        llm_api_token
 166                            .refresh(&client, organization_id.clone())
 167                            .await?;
 168                        let response =
 169                            Self::fetch_models(client, llm_api_token, organization_id).await?;
 170                        this.update(cx, |this, cx| {
 171                            this.update_models(response, cx);
 172                        })
 173                    })
 174                    .detach_and_log_err(cx);
 175                },
 176            ),
 177        }
 178    }
 179
 180    fn is_signed_out(&self, cx: &App) -> bool {
 181        self.user_store.read(cx).current_user().is_none()
 182    }
 183
 184    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 185        let client = self.client.clone();
 186        cx.spawn(async move |state, cx| {
 187            client.sign_in_with_optional_connect(true, cx).await?;
 188            state.update(cx, |_, cx| cx.notify())
 189        })
 190    }
 191
 192    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 193        let mut models = Vec::new();
 194
 195        for model in response.models {
 196            models.push(Arc::new(model.clone()));
 197        }
 198
 199        self.default_model = models
 200            .iter()
 201            .find(|model| {
 202                response
 203                    .default_model
 204                    .as_ref()
 205                    .is_some_and(|default_model_id| &model.id == default_model_id)
 206            })
 207            .cloned();
 208        self.default_fast_model = models
 209            .iter()
 210            .find(|model| {
 211                response
 212                    .default_fast_model
 213                    .as_ref()
 214                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 215            })
 216            .cloned();
 217        self.recommended_models = response
 218            .recommended_models
 219            .iter()
 220            .filter_map(|id| models.iter().find(|model| &model.id == id))
 221            .cloned()
 222            .collect();
 223        self.models = models;
 224        cx.notify();
 225    }
 226
 227    async fn fetch_models(
 228        client: Arc<Client>,
 229        llm_api_token: LlmApiToken,
 230        organization_id: Option<OrganizationId>,
 231    ) -> Result<ListModelsResponse> {
 232        let http_client = &client.http_client();
 233        let token = llm_api_token.acquire(&client, organization_id).await?;
 234
 235        let request = http_client::Request::builder()
 236            .method(Method::GET)
 237            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 238            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 239            .header("Authorization", format!("Bearer {token}"))
 240            .body(AsyncBody::empty())?;
 241        let mut response = http_client
 242            .send(request)
 243            .await
 244            .context("failed to send list models request")?;
 245
 246        if response.status().is_success() {
 247            let mut body = String::new();
 248            response.body_mut().read_to_string(&mut body).await?;
 249            Ok(serde_json::from_str(&body)?)
 250        } else {
 251            let mut body = String::new();
 252            response.body_mut().read_to_string(&mut body).await?;
 253            anyhow::bail!(
 254                "error listing models.\nStatus: {:?}\nBody: {body}",
 255                response.status(),
 256            );
 257        }
 258    }
 259}
 260
 261impl CloudLanguageModelProvider {
 262    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 263        let mut status_rx = client.status();
 264        let status = *status_rx.borrow();
 265
 266        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 267
 268        let state_ref = state.downgrade();
 269        let maintain_client_status = cx.spawn(async move |cx| {
 270            while let Some(status) = status_rx.next().await {
 271                if let Some(this) = state_ref.upgrade() {
 272                    _ = this.update(cx, |this, cx| {
 273                        if this.status != status {
 274                            this.status = status;
 275                            cx.notify();
 276                        }
 277                    });
 278                } else {
 279                    break;
 280                }
 281            }
 282        });
 283
 284        Self {
 285            client,
 286            state,
 287            _maintain_client_status: maintain_client_status,
 288        }
 289    }
 290
 291    fn create_language_model(
 292        &self,
 293        model: Arc<cloud_llm_client::LanguageModel>,
 294        llm_api_token: LlmApiToken,
 295        user_store: Entity<UserStore>,
 296    ) -> Arc<dyn LanguageModel> {
 297        Arc::new(CloudLanguageModel {
 298            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 299            model,
 300            llm_api_token,
 301            user_store,
 302            client: self.client.clone(),
 303            request_limiter: RateLimiter::new(4),
 304        })
 305    }
 306}
 307
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Expose the internal state entity so callers can observe changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 315
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    /// The server-designated default model, if the list has been fetched and
    /// the default was resolvable.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let default_model = state.default_model.clone()?;
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        Some(self.create_language_model(default_model, llm_api_token, user_store))
    }

    /// The server-designated "fast" default model, if available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let default_fast_model = state.default_fast_model.clone()?;
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
    }

    /// Models the server recommends, in the server-provided order.
    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        state
            .recommended_models
            .iter()
            .cloned()
            .map(|model| {
                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
            })
            .collect()
    }

    /// Every model returned by the server's `/models` endpoint.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        state
            .models
            .iter()
            .cloned()
            .map(|model| {
                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
            })
            .collect()
    }

    /// Authenticated whenever a user is signed in to the Zed client.
    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    /// No provider-specific credentials: authentication is the Zed account
    /// sign-in, handled elsewhere, so this trivially succeeds.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    /// Nothing to reset — see `authenticate` above.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
 396
/// A single cloud-hosted model exposed through Zed's LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    // Server-provided model description (capabilities, limits, provider).
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    // Limits concurrent requests for this model (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
 405
/// Successful response from the `/completions` endpoint, plus whether the
/// server signalled (via response header) that it will interleave status
/// messages into the event stream.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}
 410
impl CloudLanguageModel {
    /// POST a completion request to the cloud `/completions` endpoint.
    ///
    /// On a response indicating the LLM token is stale, the token is
    /// refreshed once and the request retried; a 402 maps to
    /// `PaymentRequiredError`, and any other failure surfaces as `ApiError`
    /// carrying the status, body, and headers.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token
            .acquire(&client, organization_id.clone())
            .await?;
        // Guards against refreshing (and re-sending) more than once.
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise protocol capabilities so the server can opt in to
                // richer streaming behavior.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Stale token: refresh once and loop to retry the request.
            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token
                    .refresh(&client, organization_id.clone())
                    .await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: capture body + headers for diagnostics.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 476
/// Non-success HTTP response from the cloud LLM service, retaining the raw
/// body and headers for downstream error classification.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 484
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    // Machine-readable error code, e.g. "upstream_http_error" or
    // "upstream_http_429".
    code: String,
    message: String,
    // Upstream provider's status code, when the server relays one; invalid
    // or absent values deserialize to `None`.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    // Suggested retry delay in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
 505
 506fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 507where
 508    D: serde::Deserializer<'de>,
 509{
 510    let opt: Option<u16> = Option::deserialize(deserializer)?;
 511    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 512}
 513
 514impl From<ApiError> for LanguageModelCompletionError {
 515    fn from(error: ApiError) -> Self {
 516        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 517            if cloud_error.code.starts_with("upstream_http_") {
 518                let status = if let Some(status) = cloud_error.upstream_status {
 519                    status
 520                } else if cloud_error.code.ends_with("_error") {
 521                    error.status
 522                } else {
 523                    // If there's a status code in the code string (e.g. "upstream_http_429")
 524                    // then use that; otherwise, see if the JSON contains a status code.
 525                    cloud_error
 526                        .code
 527                        .strip_prefix("upstream_http_")
 528                        .and_then(|code_str| code_str.parse::<u16>().ok())
 529                        .and_then(|code| StatusCode::from_u16(code).ok())
 530                        .unwrap_or(error.status)
 531                };
 532
 533                return LanguageModelCompletionError::UpstreamProviderError {
 534                    message: cloud_error.message,
 535                    status,
 536                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 537                };
 538            }
 539
 540            return LanguageModelCompletionError::from_http_status(
 541                PROVIDER_NAME,
 542                error.status,
 543                cloud_error.message,
 544                None,
 545            );
 546        }
 547
 548        let retry_after = None;
 549        LanguageModelCompletionError::from_http_status(
 550            PROVIDER_NAME,
 551            error.status,
 552            error.body,
 553            retry_after,
 554        )
 555    }
 556}
 557
 558impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// The id of the provider actually serving this model behind the cloud
    /// proxy (Anthropic/OpenAI/Google/xAI).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    /// Display name counterpart of `upstream_provider_id`.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }
 598
    // Capability flags below are forwarded straight from the server-provided
    // model description.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    /// Effort levels advertised by the server, converted to the
    /// `language_model` crate's representation.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                // Treat a missing `is_default` as "not the default".
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    /// All tool-choice modes are supported regardless of upstream provider.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Only OpenAI-backed models display input/output tokens separately.
    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    /// Google models only accept a subset of JSON Schema for tool inputs;
    /// the other providers take full JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }
 660
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    /// Prompt-caching parameters: only Anthropic-backed models support
    /// prompt caching here.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }
 683
    /// Count the tokens `request` would consume on this model.
    ///
    /// Anthropic, OpenAI, and xAI models are counted locally; Google models
    /// require a round trip to the cloud `/count_tokens` endpoint, which
    /// proxies Google's own counting API.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            // Local approximation via tiktoken on a background thread.
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                // Fail fast if the model id isn't a known OpenAI model.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|o| o.id.clone());
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client, organization_id).await?;

                    // The Google count request is embedded verbatim inside
                    // the cloud endpoint's body.
                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        // Preserve status/body/headers for error classification.
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 765
    /// Streams a completion for `request` through Zed's cloud LLM service.
    ///
    /// Dispatches on the model's underlying provider (Anthropic, OpenAI, xAI,
    /// Google), converts the generic `LanguageModelRequest` into that
    /// provider's wire format, wraps it in a `CompletionBody`, performs the
    /// request through `perform_llm_completion`, and maps the streamed
    /// response lines back into `LanguageModelCompletionEvent`s via the
    /// provider-specific event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        // NOTE(review): `cx.update` returns a Result, so this is
        // `Some(Result<…>)`; `perform_llm_completion` apparently accepts that
        // shape (same for `organization_id` below) — confirm against its
        // signature.
        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
        let user_store = self.user_store.clone();
        let organization_id = cx.update(|cx| {
            user_store
                .read(cx)
                .current_organization()
                .map(|o| o.id.clone())
        });
        // Thinking is enabled only when the request allows it AND the model
        // supports it.
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                // Parse the requested thinking effort, silently ignoring
                // values Anthropic doesn't recognize.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                // Temperature is pinned to 1.0; thinking mode gets a fixed
                // 4096-token budget.
                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // An explicit effort upgrades the request to adaptive thinking
                // with an output-effort configuration.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Surface structured `ApiError`s as completion errors.
                    // NOTE(review): the OpenAI/xAI/Google arms below propagate
                    // errors without this downcast — confirm whether that
                    // asymmetry is intentional.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                // Parse the requested reasoning effort, ignoring unrecognized
                // values.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                // xAI goes through `into_open_ai` (rather than
                // `into_open_ai_response` above); no thinking configuration is
                // applied for this provider.
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                // Google always uses the default model mode here; thinking
                // settings are not forwarded.
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
 985}
 986
/// Adapts a raw cloud completion stream into `LanguageModelCompletionEvent`s.
///
/// Status messages are translated via
/// `LanguageModelCompletionEvent::from_completion_request_status`, provider
/// payload events are expanded through `map_callback` (one upstream event may
/// yield several output events), and if the underlying stream finishes without
/// ever producing a `StreamEnded` status, a final `StreamEndedUnexpectedly`
/// error is emitted.
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the server's explicit end-of-stream marker is observed.
    let mut saw_stream_ended = false;

    // `done` latches stream exhaustion so the fused inner stream is not
    // polled again; `pending` buffers mapped events so each poll yields at
    // most one item even when one upstream event maps to several.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain previously mapped events first.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        // The termination marker itself is not forwarded; it
                        // only records that the stream ended cleanly.
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        // Other statuses may or may not translate into an
                        // output event.
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // The connection dropped without the server's explicit
                    // end marker — report it as an error event.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
1056
1057fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1058    match provider {
1059        cloud_llm_client::LanguageModelProvider::Anthropic => {
1060            language_model::ANTHROPIC_PROVIDER_NAME
1061        }
1062        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1063        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1064        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1065    }
1066}
1067
1068fn response_lines<T: DeserializeOwned>(
1069    response: Response<AsyncBody>,
1070    includes_status_messages: bool,
1071) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1072    futures::stream::try_unfold(
1073        (String::new(), BufReader::new(response.into_body())),
1074        move |(mut line, mut body)| async move {
1075            match body.read_line(&mut line).await {
1076                Ok(0) => Ok(None),
1077                Ok(_) => {
1078                    let event = if includes_status_messages {
1079                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1080                    } else {
1081                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1082                    };
1083
1084                    line.clear();
1085                    Ok(Some((event, (line, body))))
1086                }
1087                Err(e) => Err(e.into()),
1088            }
1089        },
1090    )
1091}
1092
/// Props for the Zed AI configuration/onboarding panel rendered in provider
/// settings; stateless — all data is supplied by the caller.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is currently signed in.
    is_connected: bool,
    /// The user's current billing plan, if known.
    plan: Option<Plan>,
    /// Start/end of the active subscription period, if any.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the account may still start the free Pro trial.
    eligible_for_trial: bool,
    /// Whether the account is flagged as too new (see `YoungAccountBanner`).
    account_too_young: bool,
    /// Invoked when the user clicks the sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1102
1103impl RenderOnce for ZedAiConfiguration {
1104    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1105        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
1106        let subscription_text = match (self.plan, self.subscription_period) {
1107            (Some(Plan::ZedPro), Some(_)) => {
1108                "You have access to Zed's hosted models through your Pro subscription."
1109            }
1110            (Some(Plan::ZedProTrial), Some(_)) => {
1111                "You have access to Zed's hosted models through your Pro trial."
1112            }
1113            (Some(Plan::ZedFree), Some(_)) => {
1114                if self.eligible_for_trial {
1115                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1116                } else {
1117                    "Subscribe for access to Zed's hosted models."
1118                }
1119            }
1120            _ => {
1121                if self.eligible_for_trial {
1122                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1123                } else {
1124                    "Subscribe for access to Zed's hosted models."
1125                }
1126            }
1127        };
1128
1129        let manage_subscription_buttons = if is_pro {
1130            Button::new("manage_settings", "Manage Subscription")
1131                .full_width()
1132                .style(ButtonStyle::Tinted(TintColor::Accent))
1133                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1134                .into_any_element()
1135        } else if self.plan.is_none() || self.eligible_for_trial {
1136            Button::new("start_trial", "Start 14-day Free Pro Trial")
1137                .full_width()
1138                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1139                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1140                .into_any_element()
1141        } else {
1142            Button::new("upgrade", "Upgrade to Pro")
1143                .full_width()
1144                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1145                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1146                .into_any_element()
1147        };
1148
1149        if !self.is_connected {
1150            return v_flex()
1151                .gap_2()
1152                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1153                .child(
1154                    Button::new("sign_in", "Sign In to use Zed AI")
1155                        .icon_color(Color::Muted)
1156                        .icon(IconName::Github)
1157                        .icon_size(IconSize::Small)
1158                        .icon_position(IconPosition::Start)
1159                        .full_width()
1160                        .on_click({
1161                            let callback = self.sign_in_callback.clone();
1162                            move |_, window, cx| (callback)(window, cx)
1163                        }),
1164                );
1165        }
1166
1167        v_flex().gap_2().w_full().map(|this| {
1168            if self.account_too_young {
1169                this.child(YoungAccountBanner).child(
1170                    Button::new("upgrade", "Upgrade to Pro")
1171                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1172                        .full_width()
1173                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1174                )
1175            } else {
1176                this.text_sm()
1177                    .child(subscription_text)
1178                    .child(manage_subscription_buttons)
1179            }
1180        })
1181    }
1182}
1183
/// Settings view wrapping `ZedAiConfiguration` with live provider state.
struct ConfigurationView {
    /// Shared provider state (sign-in status, user store, authentication).
    state: Entity<State>,
    /// Callback handed to the rendered element's sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1188
1189impl ConfigurationView {
1190    fn new(state: Entity<State>) -> Self {
1191        let sign_in_callback = Arc::new({
1192            let state = state.clone();
1193            move |_window: &mut Window, cx: &mut App| {
1194                state.update(cx, |state, cx| {
1195                    state.authenticate(cx).detach_and_log_err(cx);
1196                });
1197            }
1198        });
1199
1200        Self {
1201            state,
1202            sign_in_callback,
1203        }
1204    }
1205}
1206
impl Render for ConfigurationView {
    /// Re-reads sign-in and billing state each frame and feeds it into the
    /// stateless `ZedAiConfiguration` element.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            // Eligible only while no trial has ever been started.
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}
1222
1223impl Component for ZedAiConfiguration {
1224    fn name() -> &'static str {
1225        "AI Configuration Content"
1226    }
1227
1228    fn sort_name() -> &'static str {
1229        "AI Configuration Content"
1230    }
1231
1232    fn scope() -> ComponentScope {
1233        ComponentScope::Onboarding
1234    }
1235
1236    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1237        fn configuration(
1238            is_connected: bool,
1239            plan: Option<Plan>,
1240            eligible_for_trial: bool,
1241            account_too_young: bool,
1242        ) -> AnyElement {
1243            ZedAiConfiguration {
1244                is_connected,
1245                plan,
1246                subscription_period: plan
1247                    .is_some()
1248                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1249                eligible_for_trial,
1250                account_too_young,
1251                sign_in_callback: Arc::new(|_, _| {}),
1252            }
1253            .into_any_element()
1254        }
1255
1256        Some(
1257            v_flex()
1258                .p_4()
1259                .gap_4()
1260                .children(vec![
1261                    single_example("Not connected", configuration(false, None, false, false)),
1262                    single_example(
1263                        "Accept Terms of Service",
1264                        configuration(true, None, true, false),
1265                    ),
1266                    single_example(
1267                        "No Plan - Not eligible for trial",
1268                        configuration(true, None, false, false),
1269                    ),
1270                    single_example(
1271                        "No Plan - Eligible for trial",
1272                        configuration(true, None, true, false),
1273                    ),
1274                    single_example(
1275                        "Free Plan",
1276                        configuration(true, Some(Plan::ZedFree), true, false),
1277                    ),
1278                    single_example(
1279                        "Zed Pro Trial Plan",
1280                        configuration(true, Some(Plan::ZedProTrial), true, false),
1281                    ),
1282                    single_example(
1283                        "Zed Pro Plan",
1284                        configuration(true, Some(Plan::ZedPro), true, false),
1285                    ),
1286                ])
1287                .into_any_element(),
1288        )
1289    }
1290}
1291
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds an `ApiError` with HTTP 500, the given body, and empty headers,
    /// then converts it into a `LanguageModelCompletionError`.
    fn error_from_body(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with 503 status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        match error_from_body(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            other => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                other
            ),
        }

        // upstream_http_error with 500 status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        match error_from_body(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            other => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                other
            ),
        }

        // upstream_http_error with 429 status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        match error_from_body(error_body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            other => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                other
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        match error_from_body("Regular internal server error") {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                other
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        match error_from_body(error_body) {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                other
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        match error_from_body("Not JSON at all") {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                other
            ),
        }
    }
}