cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, UserStore, zed_urls};
   6use cloud_api_types::{OrganizationId, Plan};
   7use cloud_llm_client::{
   8    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   9    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
  10    CountTokensBody, CountTokensResponse, ListModelsResponse,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  12};
  13use futures::{
  14    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  15    future::BoxFuture,
  16    stream::{self, BoxStream},
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  25    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  26    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  27    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::collections::VecDeque;
  39use std::pin::Pin;
  40use std::str::FromStr;
  41use std::sync::Arc;
  42use std::task::Poll;
  43use std::time::Duration;
  44use thiserror::Error;
  45use ui::{TintColor, prelude::*};
  46
  47use crate::provider::anthropic::{
  48    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  49};
  50use crate::provider::google::{GoogleEventMapper, into_google};
  51use crate::provider::open_ai::{
  52    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  53    into_open_ai_response,
  54};
  55use crate::provider::x_ai::count_xai_tokens;
  56
/// Stable identifier and display name for the Zed cloud provider,
/// shared with the rest of the app via the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  59
/// User settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Models configured by the user for this provider.
    pub available_models: Vec<AvailableModel>,
}
/// Whether a model runs in its default mode or with extended thinking enabled.
///
/// Serialized with a lowercase `type` tag (`"default"` / `"thinking"`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  74
  75impl From<ModelMode> for AnthropicModelMode {
  76    fn from(value: ModelMode) -> Self {
  77        match value {
  78            ModelMode::Default => AnthropicModelMode::Default,
  79            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  80        }
  81    }
  82}
  83
/// Language-model provider backed by Zed's cloud (zed.dev) LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Shared, observable provider state (model list, auth status, token).
    state: Entity<State>,
    /// Keeps the status-mirroring background task alive for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
  89
/// Observable state for the cloud provider: the fetched model catalog,
/// connection status, and the subscriptions that keep them fresh.
pub struct State {
    client: Arc<Client>,
    /// Token used to authenticate against the Zed LLM API.
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    /// Last observed client connection status.
    status: client::Status,
    /// All models reported by the server.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default model, if any.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default "fast" model, if any.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Subset of `models` the server recommends.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 103
 104impl State {
 105    fn new(
 106        client: Arc<Client>,
 107        user_store: Entity<UserStore>,
 108        status: client::Status,
 109        cx: &mut Context<Self>,
 110    ) -> Self {
 111        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 112        Self {
 113            client: client.clone(),
 114            llm_api_token: LlmApiToken::default(),
 115            user_store: user_store.clone(),
 116            status,
 117            models: Vec::new(),
 118            default_model: None,
 119            default_fast_model: None,
 120            recommended_models: Vec::new(),
 121            _user_store_subscription: cx.subscribe(
 122                &user_store,
 123                move |this, _user_store, event, cx| match event {
 124                    client::user::Event::PrivateUserInfoUpdated => {
 125                        let status = *client.status().borrow();
 126                        if status.is_signed_out() {
 127                            return;
 128                        }
 129
 130                        let client = this.client.clone();
 131                        let llm_api_token = this.llm_api_token.clone();
 132                        let organization_id = this
 133                            .user_store
 134                            .read(cx)
 135                            .current_organization()
 136                            .map(|organization| organization.id.clone());
 137                        cx.spawn(async move |this, cx| {
 138                            let response =
 139                                Self::fetch_models(client, llm_api_token, organization_id).await?;
 140                            this.update(cx, |this, cx| this.update_models(response, cx))
 141                        })
 142                        .detach_and_log_err(cx);
 143                    }
 144                    _ => {}
 145                },
 146            ),
 147            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 148                cx.notify();
 149            }),
 150            _llm_token_subscription: cx.subscribe(
 151                &refresh_llm_token_listener,
 152                move |this, _listener, _event, cx| {
 153                    let client = this.client.clone();
 154                    let llm_api_token = this.llm_api_token.clone();
 155                    let organization_id = this
 156                        .user_store
 157                        .read(cx)
 158                        .current_organization()
 159                        .map(|o| o.id.clone());
 160                    cx.spawn(async move |this, cx| {
 161                        llm_api_token
 162                            .refresh(&client, organization_id.clone())
 163                            .await?;
 164                        let response =
 165                            Self::fetch_models(client, llm_api_token, organization_id).await?;
 166                        this.update(cx, |this, cx| {
 167                            this.update_models(response, cx);
 168                        })
 169                    })
 170                    .detach_and_log_err(cx);
 171                },
 172            ),
 173        }
 174    }
 175
 176    fn is_signed_out(&self, cx: &App) -> bool {
 177        self.user_store.read(cx).current_user().is_none()
 178    }
 179
 180    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 181        let client = self.client.clone();
 182        cx.spawn(async move |state, cx| {
 183            client.sign_in_with_optional_connect(true, cx).await?;
 184            state.update(cx, |_, cx| cx.notify())
 185        })
 186    }
 187
 188    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 189        let mut models = Vec::new();
 190
 191        for model in response.models {
 192            models.push(Arc::new(model.clone()));
 193        }
 194
 195        self.default_model = models
 196            .iter()
 197            .find(|model| {
 198                response
 199                    .default_model
 200                    .as_ref()
 201                    .is_some_and(|default_model_id| &model.id == default_model_id)
 202            })
 203            .cloned();
 204        self.default_fast_model = models
 205            .iter()
 206            .find(|model| {
 207                response
 208                    .default_fast_model
 209                    .as_ref()
 210                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 211            })
 212            .cloned();
 213        self.recommended_models = response
 214            .recommended_models
 215            .iter()
 216            .filter_map(|id| models.iter().find(|model| &model.id == id))
 217            .cloned()
 218            .collect();
 219        self.models = models;
 220        cx.notify();
 221    }
 222
 223    async fn fetch_models(
 224        client: Arc<Client>,
 225        llm_api_token: LlmApiToken,
 226        organization_id: Option<OrganizationId>,
 227    ) -> Result<ListModelsResponse> {
 228        let http_client = &client.http_client();
 229        let token = llm_api_token.acquire(&client, organization_id).await?;
 230
 231        let request = http_client::Request::builder()
 232            .method(Method::GET)
 233            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 234            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 235            .header("Authorization", format!("Bearer {token}"))
 236            .body(AsyncBody::empty())?;
 237        let mut response = http_client
 238            .send(request)
 239            .await
 240            .context("failed to send list models request")?;
 241
 242        if response.status().is_success() {
 243            let mut body = String::new();
 244            response.body_mut().read_to_string(&mut body).await?;
 245            Ok(serde_json::from_str(&body)?)
 246        } else {
 247            let mut body = String::new();
 248            response.body_mut().read_to_string(&mut body).await?;
 249            anyhow::bail!(
 250                "error listing models.\nStatus: {:?}\nBody: {body}",
 251                response.status(),
 252            );
 253        }
 254    }
 255}
 256
 257impl CloudLanguageModelProvider {
 258    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 259        let mut status_rx = client.status();
 260        let status = *status_rx.borrow();
 261
 262        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 263
 264        let state_ref = state.downgrade();
 265        let maintain_client_status = cx.spawn(async move |cx| {
 266            while let Some(status) = status_rx.next().await {
 267                if let Some(this) = state_ref.upgrade() {
 268                    _ = this.update(cx, |this, cx| {
 269                        if this.status != status {
 270                            this.status = status;
 271                            cx.notify();
 272                        }
 273                    });
 274                } else {
 275                    break;
 276                }
 277            }
 278        });
 279
 280        Self {
 281            client,
 282            state,
 283            _maintain_client_status: maintain_client_status,
 284        }
 285    }
 286
 287    fn create_language_model(
 288        &self,
 289        model: Arc<cloud_llm_client::LanguageModel>,
 290        llm_api_token: LlmApiToken,
 291        user_store: Entity<UserStore>,
 292    ) -> Arc<dyn LanguageModel> {
 293        Arc::new(CloudLanguageModel {
 294            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 295            model,
 296            llm_api_token,
 297            user_store,
 298            client: self.client.clone(),
 299            request_limiter: RateLimiter::new(4),
 300        })
 301    }
 302}
 303
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's [`State`] entity so the UI can observe changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 311
 312impl LanguageModelProvider for CloudLanguageModelProvider {
 313    fn id(&self) -> LanguageModelProviderId {
 314        PROVIDER_ID
 315    }
 316
 317    fn name(&self) -> LanguageModelProviderName {
 318        PROVIDER_NAME
 319    }
 320
 321    fn icon(&self) -> IconOrSvg {
 322        IconOrSvg::Icon(IconName::AiZed)
 323    }
 324
 325    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 326        let state = self.state.read(cx);
 327        let default_model = state.default_model.clone()?;
 328        let llm_api_token = state.llm_api_token.clone();
 329        let user_store = state.user_store.clone();
 330        Some(self.create_language_model(default_model, llm_api_token, user_store))
 331    }
 332
 333    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 334        let state = self.state.read(cx);
 335        let default_fast_model = state.default_fast_model.clone()?;
 336        let llm_api_token = state.llm_api_token.clone();
 337        let user_store = state.user_store.clone();
 338        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
 339    }
 340
 341    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 342        let state = self.state.read(cx);
 343        let llm_api_token = state.llm_api_token.clone();
 344        let user_store = state.user_store.clone();
 345        state
 346            .recommended_models
 347            .iter()
 348            .cloned()
 349            .map(|model| {
 350                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 351            })
 352            .collect()
 353    }
 354
 355    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 356        let state = self.state.read(cx);
 357        let llm_api_token = state.llm_api_token.clone();
 358        let user_store = state.user_store.clone();
 359        state
 360            .models
 361            .iter()
 362            .cloned()
 363            .map(|model| {
 364                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 365            })
 366            .collect()
 367    }
 368
 369    fn is_authenticated(&self, cx: &App) -> bool {
 370        let state = self.state.read(cx);
 371        !state.is_signed_out(cx)
 372    }
 373
 374    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 375        Task::ready(Ok(()))
 376    }
 377
 378    fn configuration_view(
 379        &self,
 380        _target_agent: language_model::ConfigurationViewTargetAgent,
 381        _: &mut Window,
 382        cx: &mut App,
 383    ) -> AnyView {
 384        cx.new(|_| ConfigurationView::new(self.state.clone()))
 385            .into()
 386    }
 387
 388    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 389        Task::ready(Ok(()))
 390    }
 391}
 392
/// A single model served through Zed's cloud LLM API.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    /// Server-provided metadata describing this model's capabilities.
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    /// Limits the number of concurrent requests issued through this handle.
    request_limiter: RateLimiter,
}
 401
/// Successful result of `perform_llm_completion`: the raw streaming response
/// plus whether the server indicated it will interleave status messages.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}
 406
impl CloudLanguageModel {
    /// Sends a completion request to the Zed LLM API.
    ///
    /// On an auth failure that indicates a stale token, refreshes the LLM
    /// token once and retries; all other failures are returned as errors
    /// (`PaymentRequiredError` for HTTP 402, `ApiError` otherwise).
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token
            .acquire(&client, organization_id.clone())
            .await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise client capabilities so the server can opt in.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                // The server echoes this header when it will send status messages.
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Retry exactly once with a freshly refreshed token.
            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token
                    .refresh(&client, organization_id.clone())
                    .await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture headers before consuming the body for the error report.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 472
/// Non-success HTTP response from the Zed LLM API, carrying the raw body
/// and headers for downstream error classification.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 480
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable error code, e.g. `"upstream_http_error"`.
    code: String,
    /// Human-readable description of the failure.
    message: String,
    /// HTTP status returned by the upstream provider, when known.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested retry delay in seconds, when provided by the server.
    #[serde(default)]
    retry_after: Option<f64>,
}
 501
 502fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 503where
 504    D: serde::Deserializer<'de>,
 505{
 506    let opt: Option<u16> = Option::deserialize(deserializer)?;
 507    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 508}
 509
 510impl From<ApiError> for LanguageModelCompletionError {
 511    fn from(error: ApiError) -> Self {
 512        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 513            if cloud_error.code.starts_with("upstream_http_") {
 514                let status = if let Some(status) = cloud_error.upstream_status {
 515                    status
 516                } else if cloud_error.code.ends_with("_error") {
 517                    error.status
 518                } else {
 519                    // If there's a status code in the code string (e.g. "upstream_http_429")
 520                    // then use that; otherwise, see if the JSON contains a status code.
 521                    cloud_error
 522                        .code
 523                        .strip_prefix("upstream_http_")
 524                        .and_then(|code_str| code_str.parse::<u16>().ok())
 525                        .and_then(|code| StatusCode::from_u16(code).ok())
 526                        .unwrap_or(error.status)
 527                };
 528
 529                return LanguageModelCompletionError::UpstreamProviderError {
 530                    message: cloud_error.message,
 531                    status,
 532                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 533                };
 534            }
 535
 536            return LanguageModelCompletionError::from_http_status(
 537                PROVIDER_NAME,
 538                error.status,
 539                cloud_error.message,
 540                None,
 541            );
 542        }
 543
 544        let retry_after = None;
 545        LanguageModelCompletionError::from_http_status(
 546            PROVIDER_NAME,
 547            error.status,
 548            error.body,
 549            retry_after,
 550        )
 551    }
 552}
 553
 554impl LanguageModel for CloudLanguageModel {
    /// The unique identifier of this model instance.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
 558
    /// Human-readable display name reported by the server.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }
 562
    /// Always the Zed cloud provider id.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }
 566
    /// Always the Zed cloud provider name.
    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }
 570
    /// Maps the cloud model's backing provider to the corresponding
    /// first-party provider id (Anthropic, OpenAI, Google, or xAI).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }
 580
    /// Maps the cloud model's backing provider to the corresponding
    /// first-party provider name.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }
 590
    /// Whether the server marks this as the latest version of the model.
    fn is_latest(&self) -> bool {
        self.model.is_latest
    }
 594
    /// Tool-use capability, as reported in the server's model metadata.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }
 598
    /// Image-input capability, as reported in the server's model metadata.
    fn supports_images(&self) -> bool {
        self.model.supports_images
    }
 602
    /// Extended-thinking capability, as reported in the server's model metadata.
    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }
 606
    /// Fast-mode capability, as reported in the server's model metadata.
    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }
 610
 611    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 612        self.model
 613            .supported_effort_levels
 614            .iter()
 615            .map(|effort_level| LanguageModelEffortLevel {
 616                name: effort_level.name.clone().into(),
 617                value: effort_level.value.clone().into(),
 618                is_default: effort_level.is_default.unwrap_or(false),
 619            })
 620            .collect()
 621    }
 622
    /// Streaming tool-call capability, as reported in the server's model metadata.
    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }
 626
 627    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 628        match choice {
 629            LanguageModelToolChoice::Auto
 630            | LanguageModelToolChoice::Any
 631            | LanguageModelToolChoice::None => true,
 632        }
 633    }
 634
 635    fn supports_split_token_display(&self) -> bool {
 636        use cloud_llm_client::LanguageModelProvider::*;
 637        matches!(self.model.provider, OpenAi)
 638    }
 639
    /// Telemetry identifier, namespaced under `zed.dev/`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
 643
 644    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 645        match self.model.provider {
 646            cloud_llm_client::LanguageModelProvider::Anthropic
 647            | cloud_llm_client::LanguageModelProvider::OpenAi
 648            | cloud_llm_client::LanguageModelProvider::XAi => {
 649                LanguageModelToolSchemaFormat::JsonSchema
 650            }
 651            cloud_llm_client::LanguageModelProvider::Google => {
 652                LanguageModelToolSchemaFormat::JsonSchemaSubset
 653            }
 654        }
 655    }
 656
    /// Context-window size from the server's model metadata.
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }
 660
    /// Output-token limit from the server's model metadata.
    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }
 664
 665    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 666        match &self.model.provider {
 667            cloud_llm_client::LanguageModelProvider::Anthropic => {
 668                Some(LanguageModelCacheConfiguration {
 669                    min_total_token: 2_048,
 670                    should_speculate: true,
 671                    max_cache_anchors: 4,
 672                })
 673            }
 674            cloud_llm_client::LanguageModelProvider::OpenAi
 675            | cloud_llm_client::LanguageModelProvider::XAi
 676            | cloud_llm_client::LanguageModelProvider::Google => None,
 677        }
 678    }
 679
    /// Estimates the token count of `request` for this model.
    ///
    /// Anthropic, OpenAI, and xAI are counted locally; Google models require
    /// a round-trip to the Zed LLM API's `/count_tokens` endpoint, since
    /// counting is delegated to Google's API server-side.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            // Local approximation via tiktoken.
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    // Unknown model id: surface the error without issuing a request.
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|o| o.id.clone());
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client, organization_id).await?;

                    // Wrap the Google-format request in the cloud API envelope.
                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    // Capture headers before consuming the body, for error reporting.
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 761
    /// Streams a completion for `request` through Zed's cloud LLM endpoint.
    ///
    /// The request is translated into the wire format of the selected
    /// provider (Anthropic, OpenAI Responses, xAI, or Google) and embedded in
    /// a [`CompletionBody`] as `provider_request` JSON. All requests pass
    /// through `self.request_limiter`, and the raw response lines are adapted
    /// into `LanguageModelCompletionEvent`s via the provider's event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Pull out the pieces of the request that are shared by every
        // provider branch before `request` is consumed by the converters.
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
        let user_store = self.user_store.clone();
        let organization_id = cx.update(|cx| {
            user_store
                .read(cx)
                .current_organization()
                .map(|o| o.id.clone())
        });
        let thinking_allowed = request.thinking_allowed;
        // Thinking is enabled only when the request allows it AND the model
        // supports it.
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                // Parse the requested thinking effort into Anthropic's enum;
                // unrecognized values are silently dropped.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    // Fixed 4,096-token thinking budget when thinking is on.
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // An explicit effort switches the request to adaptive
                // thinking with an effort-based output config.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // NOTE(review): only this branch downcasts `ApiError` into
                    // `LanguageModelCompletionError`; the OpenAI/xAI/Google
                    // branches propagate the raw error — confirm the
                    // asymmetry is intentional.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                // Parse the requested effort into OpenAI's reasoning-effort
                // enum; unrecognized values are silently dropped.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                // Attach a reasoning config (with auto summaries) only when
                // thinking is enabled and an effort level was recognized.
                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                // xAI uses the OpenAI-compatible chat format (note the
                // `false` here vs. `true` in the OpenAI Responses branch).
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                // Google requests never enable a thinking mode here,
                // regardless of `enable_thinking`.
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
 984}
 985
/// Adapts a stream of raw cloud [`CompletionEvent`]s into the
/// [`LanguageModelCompletionEvent`] stream handed back to callers.
///
/// `map_callback` expands one provider event into zero or more output events,
/// which are buffered in a queue and drained one at a time. Status events are
/// translated separately; a `StreamEnded` status yields nothing but is
/// remembered, so that if the underlying stream finishes without one, a
/// `StreamEndedUnexpectedly` error is emitted — letting callers distinguish a
/// clean shutdown from a dropped connection.
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the server signals a clean end of stream.
    let mut saw_stream_ended = false;

    // `done` latches after the inner stream finishes; `pending` holds mapped
    // events that are yielded one per poll.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain buffered output before polling for more input.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // Not forwarded — only recorded so the end-of-
                            // stream check below can tell clean EOF from an
                            // unexpected one.
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            // Other statuses map to at most one output event.
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // The server never sent `StreamEnded`: surface that as an
                    // error; the next poll returns `None` via `done`.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
1055
1056fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1057    match provider {
1058        cloud_llm_client::LanguageModelProvider::Anthropic => {
1059            language_model::ANTHROPIC_PROVIDER_NAME
1060        }
1061        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1062        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1063        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1064    }
1065}
1066
1067fn response_lines<T: DeserializeOwned>(
1068    response: Response<AsyncBody>,
1069    includes_status_messages: bool,
1070) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1071    futures::stream::try_unfold(
1072        (String::new(), BufReader::new(response.into_body())),
1073        move |(mut line, mut body)| async move {
1074            match body.read_line(&mut line).await {
1075                Ok(0) => Ok(None),
1076                Ok(_) => {
1077                    let event = if includes_status_messages {
1078                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1079                    } else {
1080                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1081                    };
1082
1083                    line.clear();
1084                    Ok(Some((event, (line, body))))
1085                }
1086                Err(e) => Err(e.into()),
1087            }
1088        },
1089    )
1090}
1091
/// Render-once view of the Zed AI account/subscription state, including the
/// appropriate call-to-action buttons.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is currently signed in.
    is_connected: bool,
    /// The user's current plan, if any.
    plan: Option<Plan>,
    /// Start and end of the current subscription period, when one is active.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the user may still start a free Pro trial.
    eligible_for_trial: bool,
    /// When set, the young-account banner is shown instead of the
    /// subscription copy.
    account_too_young: bool,
    /// Invoked when the user clicks the sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1101
1102impl RenderOnce for ZedAiConfiguration {
1103    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1104        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
1105        let subscription_text = match (self.plan, self.subscription_period) {
1106            (Some(Plan::ZedPro), Some(_)) => {
1107                "You have access to Zed's hosted models through your Pro subscription."
1108            }
1109            (Some(Plan::ZedProTrial), Some(_)) => {
1110                "You have access to Zed's hosted models through your Pro trial."
1111            }
1112            (Some(Plan::ZedFree), Some(_)) => {
1113                if self.eligible_for_trial {
1114                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1115                } else {
1116                    "Subscribe for access to Zed's hosted models."
1117                }
1118            }
1119            _ => {
1120                if self.eligible_for_trial {
1121                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1122                } else {
1123                    "Subscribe for access to Zed's hosted models."
1124                }
1125            }
1126        };
1127
1128        let manage_subscription_buttons = if is_pro {
1129            Button::new("manage_settings", "Manage Subscription")
1130                .full_width()
1131                .style(ButtonStyle::Tinted(TintColor::Accent))
1132                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1133                .into_any_element()
1134        } else if self.plan.is_none() || self.eligible_for_trial {
1135            Button::new("start_trial", "Start 14-day Free Pro Trial")
1136                .full_width()
1137                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1138                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1139                .into_any_element()
1140        } else {
1141            Button::new("upgrade", "Upgrade to Pro")
1142                .full_width()
1143                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1144                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1145                .into_any_element()
1146        };
1147
1148        if !self.is_connected {
1149            return v_flex()
1150                .gap_2()
1151                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1152                .child(
1153                    Button::new("sign_in", "Sign In to use Zed AI")
1154                        .icon_color(Color::Muted)
1155                        .icon(IconName::Github)
1156                        .icon_size(IconSize::Small)
1157                        .icon_position(IconPosition::Start)
1158                        .full_width()
1159                        .on_click({
1160                            let callback = self.sign_in_callback.clone();
1161                            move |_, window, cx| (callback)(window, cx)
1162                        }),
1163                );
1164        }
1165
1166        v_flex().gap_2().w_full().map(|this| {
1167            if self.account_too_young {
1168                this.child(YoungAccountBanner).child(
1169                    Button::new("upgrade", "Upgrade to Pro")
1170                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1171                        .full_width()
1172                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1173                )
1174            } else {
1175                this.text_sm()
1176                    .child(subscription_text)
1177                    .child(manage_subscription_buttons)
1178            }
1179        })
1180    }
1181}
1182
/// Entity-backed wrapper that renders [`ZedAiConfiguration`] from the live
/// provider state.
struct ConfigurationView {
    /// Provider state this view reads plan/account information from.
    state: Entity<State>,
    /// Shared sign-in handler passed down to the rendered configuration.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1187
1188impl ConfigurationView {
1189    fn new(state: Entity<State>) -> Self {
1190        let sign_in_callback = Arc::new({
1191            let state = state.clone();
1192            move |_window: &mut Window, cx: &mut App| {
1193                state.update(cx, |state, cx| {
1194                    state.authenticate(cx).detach_and_log_err(cx);
1195                });
1196            }
1197        });
1198
1199        Self {
1200            state,
1201            sign_in_callback,
1202        }
1203    }
1204}
1205
1206impl Render for ConfigurationView {
1207    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1208        let state = self.state.read(cx);
1209        let user_store = state.user_store.read(cx);
1210
1211        ZedAiConfiguration {
1212            is_connected: !state.is_signed_out(cx),
1213            plan: user_store.plan(),
1214            subscription_period: user_store.subscription_period(),
1215            eligible_for_trial: user_store.trial_started_at().is_none(),
1216            account_too_young: user_store.account_too_young(),
1217            sign_in_callback: self.sign_in_callback.clone(),
1218        }
1219    }
1220}
1221
1222impl Component for ZedAiConfiguration {
1223    fn name() -> &'static str {
1224        "AI Configuration Content"
1225    }
1226
1227    fn sort_name() -> &'static str {
1228        "AI Configuration Content"
1229    }
1230
1231    fn scope() -> ComponentScope {
1232        ComponentScope::Onboarding
1233    }
1234
1235    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1236        fn configuration(
1237            is_connected: bool,
1238            plan: Option<Plan>,
1239            eligible_for_trial: bool,
1240            account_too_young: bool,
1241        ) -> AnyElement {
1242            ZedAiConfiguration {
1243                is_connected,
1244                plan,
1245                subscription_period: plan
1246                    .is_some()
1247                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1248                eligible_for_trial,
1249                account_too_young,
1250                sign_in_callback: Arc::new(|_, _| {}),
1251            }
1252            .into_any_element()
1253        }
1254
1255        Some(
1256            v_flex()
1257                .p_4()
1258                .gap_4()
1259                .children(vec![
1260                    single_example("Not connected", configuration(false, None, false, false)),
1261                    single_example(
1262                        "Accept Terms of Service",
1263                        configuration(true, None, true, false),
1264                    ),
1265                    single_example(
1266                        "No Plan - Not eligible for trial",
1267                        configuration(true, None, false, false),
1268                    ),
1269                    single_example(
1270                        "No Plan - Eligible for trial",
1271                        configuration(true, None, true, false),
1272                    ),
1273                    single_example(
1274                        "Free Plan",
1275                        configuration(true, Some(Plan::ZedFree), true, false),
1276                    ),
1277                    single_example(
1278                        "Zed Pro Trial Plan",
1279                        configuration(true, Some(Plan::ZedProTrial), true, false),
1280                    ),
1281                    single_example(
1282                        "Zed Pro Plan",
1283                        configuration(true, Some(Plan::ZedPro), true, false),
1284                    ),
1285                ])
1286                .into_any_element(),
1287        )
1288    }
1289}
1290
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds an `ApiError` carrying an HTTP 500 status and the given body,
    /// then runs it through the `From<ApiError>` conversion under test.
    fn convert(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status becomes
        // UpstreamProviderError carrying the upstream message.
        let completion_error = convert(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
        );
        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 503, got: {other:?}"),
        }

        // upstream_http_error with a 500 upstream status also becomes
        // UpstreamProviderError.
        let completion_error = convert(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
        );
        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 500, got: {other:?}"),
        }

        // upstream_http_error with a 429 upstream status also becomes
        // UpstreamProviderError.
        let completion_error = convert(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
        );
        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            other => panic!("Expected UpstreamProviderError for upstream 429, got: {other:?}"),
        }

        // A regular 500 error without an upstream_http_error payload remains
        // an ApiInternalServerError attributed to Zed itself.
        let completion_error = convert("Regular internal server error");
        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!("Expected ApiInternalServerError for regular 500, got: {other:?}"),
        }

        // The upstream_http_429 format is converted to UpstreamProviderError
        // with a 429 status and the parsed retry_after duration.
        let completion_error = convert(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        );
        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!("Expected UpstreamProviderError for upstream_http_429, got: {other:?}"),
        }

        // Invalid JSON in the error body falls back to regular error handling.
        let completion_error = convert("Not JSON at all");
        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!("Expected ApiInternalServerError for invalid JSON, got: {other:?}"),
        }
    }
}