cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, ModelRequestUsage, UserStore, zed_urls};
   6use cloud_llm_client::{
   7    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
   8    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
   9    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
  10    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
  12    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  13};
  14use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
  15use futures::{
  16    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  25    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  26    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  27    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::pin::Pin;
  39use std::str::FromStr as _;
  40use std::sync::Arc;
  41use std::time::Duration;
  42use thiserror::Error;
  43use ui::{TintColor, prelude::*};
  44use util::{ResultExt as _, maybe};
  45
  46use crate::provider::anthropic::{
  47    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  48};
  49use crate::provider::google::{GoogleEventMapper, into_google};
  50use crate::provider::open_ai::{
  51    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  52    into_open_ai_response,
  53};
  54use crate::provider::x_ai::count_xai_tokens;
  55
  56const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
  57const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  58
  59#[derive(Default, Clone, Debug, PartialEq)]
  60pub struct ZedDotDevSettings {
  61    pub available_models: Vec<AvailableModel>,
  62}
  63#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  64#[serde(tag = "type", rename_all = "lowercase")]
  65pub enum ModelMode {
  66    #[default]
  67    Default,
  68    Thinking {
  69        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
  70        budget_tokens: Option<u32>,
  71    },
  72}
  73
  74impl From<ModelMode> for AnthropicModelMode {
  75    fn from(value: ModelMode) -> Self {
  76        match value {
  77            ModelMode::Default => AnthropicModelMode::Default,
  78            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  79        }
  80    }
  81}
  82
  83pub struct CloudLanguageModelProvider {
  84    client: Arc<Client>,
  85    state: Entity<State>,
  86    _maintain_client_status: Task<()>,
  87}
  88
  89pub struct State {
  90    client: Arc<Client>,
  91    llm_api_token: LlmApiToken,
  92    user_store: Entity<UserStore>,
  93    status: client::Status,
  94    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
  95    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
  96    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
  97    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
  98    _fetch_models_task: Task<()>,
  99    _settings_subscription: Subscription,
 100    _llm_token_subscription: Subscription,
 101}
 102
 103impl State {
 104    fn new(
 105        client: Arc<Client>,
 106        user_store: Entity<UserStore>,
 107        status: client::Status,
 108        cx: &mut Context<Self>,
 109    ) -> Self {
 110        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 111        let mut current_user = user_store.read(cx).watch_current_user();
 112        Self {
 113            client: client.clone(),
 114            llm_api_token: LlmApiToken::default(),
 115            user_store,
 116            status,
 117            models: Vec::new(),
 118            default_model: None,
 119            default_fast_model: None,
 120            recommended_models: Vec::new(),
 121            _fetch_models_task: cx.spawn(async move |this, cx| {
 122                maybe!(async move {
 123                    let (client, llm_api_token) = this
 124                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
 125
 126                    while current_user.borrow().is_none() {
 127                        current_user.next().await;
 128                    }
 129
 130                    let response =
 131                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
 132                    this.update(cx, |this, cx| this.update_models(response, cx))?;
 133                    anyhow::Ok(())
 134                })
 135                .await
 136                .context("failed to fetch Zed models")
 137                .log_err();
 138            }),
 139            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 140                cx.notify();
 141            }),
 142            _llm_token_subscription: cx.subscribe(
 143                &refresh_llm_token_listener,
 144                move |this, _listener, _event, cx| {
 145                    let client = this.client.clone();
 146                    let llm_api_token = this.llm_api_token.clone();
 147                    cx.spawn(async move |this, cx| {
 148                        llm_api_token.refresh(&client).await?;
 149                        let response = Self::fetch_models(client, llm_api_token).await?;
 150                        this.update(cx, |this, cx| {
 151                            this.update_models(response, cx);
 152                        })
 153                    })
 154                    .detach_and_log_err(cx);
 155                },
 156            ),
 157        }
 158    }
 159
 160    fn is_signed_out(&self, cx: &App) -> bool {
 161        self.user_store.read(cx).current_user().is_none()
 162    }
 163
 164    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 165        let client = self.client.clone();
 166        cx.spawn(async move |state, cx| {
 167            client.sign_in_with_optional_connect(true, cx).await?;
 168            state.update(cx, |_, cx| cx.notify())
 169        })
 170    }
 171    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 172        let mut models = Vec::new();
 173
 174        for model in response.models {
 175            models.push(Arc::new(model.clone()));
 176
 177            // Right now we represent thinking variants of models as separate models on the client,
 178            // so we need to insert variants for any model that supports thinking.
 179            if model.supports_thinking {
 180                models.push(Arc::new(cloud_llm_client::LanguageModel {
 181                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
 182                    display_name: format!("{} Thinking", model.display_name),
 183                    ..model
 184                }));
 185            }
 186        }
 187
 188        self.default_model = models
 189            .iter()
 190            .find(|model| {
 191                response
 192                    .default_model
 193                    .as_ref()
 194                    .is_some_and(|default_model_id| &model.id == default_model_id)
 195            })
 196            .cloned();
 197        self.default_fast_model = models
 198            .iter()
 199            .find(|model| {
 200                response
 201                    .default_fast_model
 202                    .as_ref()
 203                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 204            })
 205            .cloned();
 206        self.recommended_models = response
 207            .recommended_models
 208            .iter()
 209            .filter_map(|id| models.iter().find(|model| &model.id == id))
 210            .cloned()
 211            .collect();
 212        self.models = models;
 213        cx.notify();
 214    }
 215
 216    async fn fetch_models(
 217        client: Arc<Client>,
 218        llm_api_token: LlmApiToken,
 219    ) -> Result<ListModelsResponse> {
 220        let http_client = &client.http_client();
 221        let token = llm_api_token.acquire(&client).await?;
 222
 223        let request = http_client::Request::builder()
 224            .method(Method::GET)
 225            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 226            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 227            .header("Authorization", format!("Bearer {token}"))
 228            .body(AsyncBody::empty())?;
 229        let mut response = http_client
 230            .send(request)
 231            .await
 232            .context("failed to send list models request")?;
 233
 234        if response.status().is_success() {
 235            let mut body = String::new();
 236            response.body_mut().read_to_string(&mut body).await?;
 237            Ok(serde_json::from_str(&body)?)
 238        } else {
 239            let mut body = String::new();
 240            response.body_mut().read_to_string(&mut body).await?;
 241            anyhow::bail!(
 242                "error listing models.\nStatus: {:?}\nBody: {body}",
 243                response.status(),
 244            );
 245        }
 246    }
 247}
 248
 249impl CloudLanguageModelProvider {
 250    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 251        let mut status_rx = client.status();
 252        let status = *status_rx.borrow();
 253
 254        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 255
 256        let state_ref = state.downgrade();
 257        let maintain_client_status = cx.spawn(async move |cx| {
 258            while let Some(status) = status_rx.next().await {
 259                if let Some(this) = state_ref.upgrade() {
 260                    _ = this.update(cx, |this, cx| {
 261                        if this.status != status {
 262                            this.status = status;
 263                            cx.notify();
 264                        }
 265                    });
 266                } else {
 267                    break;
 268                }
 269            }
 270        });
 271
 272        Self {
 273            client,
 274            state,
 275            _maintain_client_status: maintain_client_status,
 276        }
 277    }
 278
 279    fn create_language_model(
 280        &self,
 281        model: Arc<cloud_llm_client::LanguageModel>,
 282        llm_api_token: LlmApiToken,
 283    ) -> Arc<dyn LanguageModel> {
 284        Arc::new(CloudLanguageModel {
 285            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 286            model,
 287            llm_api_token,
 288            client: self.client.clone(),
 289            request_limiter: RateLimiter::new(4),
 290        })
 291    }
 292}
 293
 294impl LanguageModelProviderState for CloudLanguageModelProvider {
 295    type ObservableEntity = State;
 296
 297    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 298        Some(self.state.clone())
 299    }
 300}
 301
 302impl LanguageModelProvider for CloudLanguageModelProvider {
 303    fn id(&self) -> LanguageModelProviderId {
 304        PROVIDER_ID
 305    }
 306
 307    fn name(&self) -> LanguageModelProviderName {
 308        PROVIDER_NAME
 309    }
 310
 311    fn icon(&self) -> IconOrSvg {
 312        IconOrSvg::Icon(IconName::AiZed)
 313    }
 314
 315    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 316        let default_model = self.state.read(cx).default_model.clone()?;
 317        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 318        Some(self.create_language_model(default_model, llm_api_token))
 319    }
 320
 321    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 322        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 323        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 324        Some(self.create_language_model(default_fast_model, llm_api_token))
 325    }
 326
 327    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 328        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 329        self.state
 330            .read(cx)
 331            .recommended_models
 332            .iter()
 333            .cloned()
 334            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 335            .collect()
 336    }
 337
 338    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 339        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 340        self.state
 341            .read(cx)
 342            .models
 343            .iter()
 344            .cloned()
 345            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 346            .collect()
 347    }
 348
 349    fn is_authenticated(&self, cx: &App) -> bool {
 350        let state = self.state.read(cx);
 351        !state.is_signed_out(cx)
 352    }
 353
 354    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 355        Task::ready(Ok(()))
 356    }
 357
 358    fn configuration_view(
 359        &self,
 360        _target_agent: language_model::ConfigurationViewTargetAgent,
 361        _: &mut Window,
 362        cx: &mut App,
 363    ) -> AnyView {
 364        cx.new(|_| ConfigurationView::new(self.state.clone()))
 365            .into()
 366    }
 367
 368    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 369        Task::ready(Ok(()))
 370    }
 371}
 372
 373pub struct CloudLanguageModel {
 374    id: LanguageModelId,
 375    model: Arc<cloud_llm_client::LanguageModel>,
 376    llm_api_token: LlmApiToken,
 377    client: Arc<Client>,
 378    request_limiter: RateLimiter,
 379}
 380
 381struct PerformLlmCompletionResponse {
 382    response: Response<AsyncBody>,
 383    usage: Option<ModelRequestUsage>,
 384    tool_use_limit_reached: bool,
 385    includes_status_messages: bool,
 386}
 387
 388impl CloudLanguageModel {
 389    async fn perform_llm_completion(
 390        client: Arc<Client>,
 391        llm_api_token: LlmApiToken,
 392        app_version: Option<Version>,
 393        body: CompletionBody,
 394    ) -> Result<PerformLlmCompletionResponse> {
 395        let http_client = &client.http_client();
 396
 397        let mut token = llm_api_token.acquire(&client).await?;
 398        let mut refreshed_token = false;
 399
 400        loop {
 401            let request = http_client::Request::builder()
 402                .method(Method::POST)
 403                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
 404                .when_some(app_version.as_ref(), |builder, app_version| {
 405                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 406                })
 407                .header("Content-Type", "application/json")
 408                .header("Authorization", format!("Bearer {token}"))
 409                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
 410                .body(serde_json::to_string(&body)?.into())?;
 411
 412            let mut response = http_client.send(request).await?;
 413            let status = response.status();
 414            if status.is_success() {
 415                let includes_status_messages = response
 416                    .headers()
 417                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
 418                    .is_some();
 419
 420                let tool_use_limit_reached = response
 421                    .headers()
 422                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
 423                    .is_some();
 424
 425                let usage = if includes_status_messages {
 426                    None
 427                } else {
 428                    ModelRequestUsage::from_headers(response.headers()).ok()
 429                };
 430
 431                return Ok(PerformLlmCompletionResponse {
 432                    response,
 433                    usage,
 434                    includes_status_messages,
 435                    tool_use_limit_reached,
 436                });
 437            }
 438
 439            if !refreshed_token
 440                && response
 441                    .headers()
 442                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 443                    .is_some()
 444            {
 445                token = llm_api_token.refresh(&client).await?;
 446                refreshed_token = true;
 447                continue;
 448            }
 449
 450            if status == StatusCode::FORBIDDEN
 451                && response
 452                    .headers()
 453                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 454                    .is_some()
 455            {
 456                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
 457                    .headers()
 458                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 459                    .and_then(|resource| resource.to_str().ok())
 460                    && let Some(plan) = response
 461                        .headers()
 462                        .get(CURRENT_PLAN_HEADER_NAME)
 463                        .and_then(|plan| plan.to_str().ok())
 464                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
 465                        .map(Plan::V1)
 466                {
 467                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
 468                }
 469            } else if status == StatusCode::PAYMENT_REQUIRED {
 470                return Err(anyhow!(PaymentRequiredError));
 471            }
 472
 473            let mut body = String::new();
 474            let headers = response.headers().clone();
 475            response.body_mut().read_to_string(&mut body).await?;
 476            return Err(anyhow!(ApiError {
 477                status,
 478                body,
 479                headers
 480            }));
 481        }
 482    }
 483}
 484
 485#[derive(Debug, Error)]
 486#[error("cloud language model request failed with status {status}: {body}")]
 487struct ApiError {
 488    status: StatusCode,
 489    body: String,
 490    headers: HeaderMap<HeaderValue>,
 491}
 492
 493/// Represents error responses from Zed's cloud API.
 494///
 495/// Example JSON for an upstream HTTP error:
 496/// ```json
 497/// {
 498///   "code": "upstream_http_error",
 499///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
 500///   "upstream_status": 503
 501/// }
 502/// ```
 503#[derive(Debug, serde::Deserialize)]
 504struct CloudApiError {
 505    code: String,
 506    message: String,
 507    #[serde(default)]
 508    #[serde(deserialize_with = "deserialize_optional_status_code")]
 509    upstream_status: Option<StatusCode>,
 510    #[serde(default)]
 511    retry_after: Option<f64>,
 512}
 513
 514fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 515where
 516    D: serde::Deserializer<'de>,
 517{
 518    let opt: Option<u16> = Option::deserialize(deserializer)?;
 519    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 520}
 521
 522impl From<ApiError> for LanguageModelCompletionError {
 523    fn from(error: ApiError) -> Self {
 524        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 525            if cloud_error.code.starts_with("upstream_http_") {
 526                let status = if let Some(status) = cloud_error.upstream_status {
 527                    status
 528                } else if cloud_error.code.ends_with("_error") {
 529                    error.status
 530                } else {
 531                    // If there's a status code in the code string (e.g. "upstream_http_429")
 532                    // then use that; otherwise, see if the JSON contains a status code.
 533                    cloud_error
 534                        .code
 535                        .strip_prefix("upstream_http_")
 536                        .and_then(|code_str| code_str.parse::<u16>().ok())
 537                        .and_then(|code| StatusCode::from_u16(code).ok())
 538                        .unwrap_or(error.status)
 539                };
 540
 541                return LanguageModelCompletionError::UpstreamProviderError {
 542                    message: cloud_error.message,
 543                    status,
 544                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 545                };
 546            }
 547
 548            return LanguageModelCompletionError::from_http_status(
 549                PROVIDER_NAME,
 550                error.status,
 551                cloud_error.message,
 552                None,
 553            );
 554        }
 555
 556        let retry_after = None;
 557        LanguageModelCompletionError::from_http_status(
 558            PROVIDER_NAME,
 559            error.status,
 560            error.body,
 561            retry_after,
 562        )
 563    }
 564}
 565
 566impl LanguageModel for CloudLanguageModel {
 567    fn id(&self) -> LanguageModelId {
 568        self.id.clone()
 569    }
 570
 571    fn name(&self) -> LanguageModelName {
 572        LanguageModelName::from(self.model.display_name.clone())
 573    }
 574
 575    fn provider_id(&self) -> LanguageModelProviderId {
 576        PROVIDER_ID
 577    }
 578
 579    fn provider_name(&self) -> LanguageModelProviderName {
 580        PROVIDER_NAME
 581    }
 582
 583    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 584        use cloud_llm_client::LanguageModelProvider::*;
 585        match self.model.provider {
 586            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
 587            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
 588            Google => language_model::GOOGLE_PROVIDER_ID,
 589            XAi => language_model::X_AI_PROVIDER_ID,
 590        }
 591    }
 592
 593    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 594        use cloud_llm_client::LanguageModelProvider::*;
 595        match self.model.provider {
 596            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
 597            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
 598            Google => language_model::GOOGLE_PROVIDER_NAME,
 599            XAi => language_model::X_AI_PROVIDER_NAME,
 600        }
 601    }
 602
 603    fn supports_tools(&self) -> bool {
 604        self.model.supports_tools
 605    }
 606
 607    fn supports_images(&self) -> bool {
 608        self.model.supports_images
 609    }
 610
 611    fn supports_streaming_tools(&self) -> bool {
 612        self.model.supports_streaming_tools
 613    }
 614
 615    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 616        match choice {
 617            LanguageModelToolChoice::Auto
 618            | LanguageModelToolChoice::Any
 619            | LanguageModelToolChoice::None => true,
 620        }
 621    }
 622
 623    fn supports_burn_mode(&self) -> bool {
 624        self.model.supports_max_mode
 625    }
 626
 627    fn telemetry_id(&self) -> String {
 628        format!("zed.dev/{}", self.model.id)
 629    }
 630
 631    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 632        match self.model.provider {
 633            cloud_llm_client::LanguageModelProvider::Anthropic
 634            | cloud_llm_client::LanguageModelProvider::OpenAi
 635            | cloud_llm_client::LanguageModelProvider::XAi => {
 636                LanguageModelToolSchemaFormat::JsonSchema
 637            }
 638            cloud_llm_client::LanguageModelProvider::Google => {
 639                LanguageModelToolSchemaFormat::JsonSchemaSubset
 640            }
 641        }
 642    }
 643
 644    fn max_token_count(&self) -> u64 {
 645        self.model.max_token_count as u64
 646    }
 647
 648    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
 649        self.model
 650            .max_token_count_in_max_mode
 651            .filter(|_| self.model.supports_max_mode)
 652            .map(|max_token_count| max_token_count as u64)
 653    }
 654
 655    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 656        match &self.model.provider {
 657            cloud_llm_client::LanguageModelProvider::Anthropic => {
 658                Some(LanguageModelCacheConfiguration {
 659                    min_total_token: 2_048,
 660                    should_speculate: true,
 661                    max_cache_anchors: 4,
 662                })
 663            }
 664            cloud_llm_client::LanguageModelProvider::OpenAi
 665            | cloud_llm_client::LanguageModelProvider::XAi
 666            | cloud_llm_client::LanguageModelProvider::Google => None,
 667        }
 668    }
 669
 670    fn count_tokens(
 671        &self,
 672        request: LanguageModelRequest,
 673        cx: &App,
 674    ) -> BoxFuture<'static, Result<u64>> {
 675        match self.model.provider {
 676            cloud_llm_client::LanguageModelProvider::Anthropic => cx
 677                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
 678                .boxed(),
 679            cloud_llm_client::LanguageModelProvider::OpenAi => {
 680                let model = match open_ai::Model::from_id(&self.model.id.0) {
 681                    Ok(model) => model,
 682                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 683                };
 684                count_open_ai_tokens(request, model, cx)
 685            }
 686            cloud_llm_client::LanguageModelProvider::XAi => {
 687                let model = match x_ai::Model::from_id(&self.model.id.0) {
 688                    Ok(model) => model,
 689                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 690                };
 691                count_xai_tokens(request, model, cx)
 692            }
 693            cloud_llm_client::LanguageModelProvider::Google => {
 694                let client = self.client.clone();
 695                let llm_api_token = self.llm_api_token.clone();
 696                let model_id = self.model.id.to_string();
 697                let generate_content_request =
 698                    into_google(request, model_id.clone(), GoogleModelMode::Default);
 699                async move {
 700                    let http_client = &client.http_client();
 701                    let token = llm_api_token.acquire(&client).await?;
 702
 703                    let request_body = CountTokensBody {
 704                        provider: cloud_llm_client::LanguageModelProvider::Google,
 705                        model: model_id,
 706                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 707                            generate_content_request,
 708                        })?,
 709                    };
 710                    let request = http_client::Request::builder()
 711                        .method(Method::POST)
 712                        .uri(
 713                            http_client
 714                                .build_zed_llm_url("/count_tokens", &[])?
 715                                .as_ref(),
 716                        )
 717                        .header("Content-Type", "application/json")
 718                        .header("Authorization", format!("Bearer {token}"))
 719                        .body(serde_json::to_string(&request_body)?.into())?;
 720                    let mut response = http_client.send(request).await?;
 721                    let status = response.status();
 722                    let headers = response.headers().clone();
 723                    let mut response_body = String::new();
 724                    response
 725                        .body_mut()
 726                        .read_to_string(&mut response_body)
 727                        .await?;
 728
 729                    if status.is_success() {
 730                        let response_body: CountTokensResponse =
 731                            serde_json::from_str(&response_body)?;
 732
 733                        Ok(response_body.tokens as u64)
 734                    } else {
 735                        Err(anyhow!(ApiError {
 736                            status,
 737                            body: response_body,
 738                            headers
 739                        }))
 740                    }
 741                }
 742                .boxed()
 743            }
 744        }
 745    }
 746
 747    fn stream_completion(
 748        &self,
 749        request: LanguageModelRequest,
 750        cx: &AsyncApp,
 751    ) -> BoxFuture<
 752        'static,
 753        Result<
 754            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 755            LanguageModelCompletionError,
 756        >,
 757    > {
 758        let thread_id = request.thread_id.clone();
 759        let prompt_id = request.prompt_id.clone();
 760        let intent = request.intent;
 761        let mode = request.mode;
 762        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
 763        let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
 764        let thinking_allowed = request.thinking_allowed;
 765        let provider_name = provider_name(&self.model.provider);
 766        match self.model.provider {
 767            cloud_llm_client::LanguageModelProvider::Anthropic => {
 768                let request = into_anthropic(
 769                    request,
 770                    self.model.id.to_string(),
 771                    1.0,
 772                    self.model.max_output_tokens as u64,
 773                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
 774                        AnthropicModelMode::Thinking {
 775                            budget_tokens: Some(4_096),
 776                        }
 777                    } else {
 778                        AnthropicModelMode::Default
 779                    },
 780                );
 781                let client = self.client.clone();
 782                let llm_api_token = self.llm_api_token.clone();
 783                let future = self.request_limiter.stream(async move {
 784                    let PerformLlmCompletionResponse {
 785                        response,
 786                        usage,
 787                        includes_status_messages,
 788                        tool_use_limit_reached,
 789                    } = Self::perform_llm_completion(
 790                        client.clone(),
 791                        llm_api_token,
 792                        app_version,
 793                        CompletionBody {
 794                            thread_id,
 795                            prompt_id,
 796                            intent,
 797                            mode,
 798                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 799                            model: request.model.clone(),
 800                            provider_request: serde_json::to_value(&request)
 801                                .map_err(|e| anyhow!(e))?,
 802                        },
 803                    )
 804                    .await
 805                    .map_err(|err| match err.downcast::<ApiError>() {
 806                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 807                        Err(err) => anyhow!(err),
 808                    })?;
 809
 810                    let mut mapper = AnthropicEventMapper::new();
 811                    Ok(map_cloud_completion_events(
 812                        Box::pin(
 813                            response_lines(response, includes_status_messages)
 814                                .chain(usage_updated_event(usage))
 815                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 816                        ),
 817                        &provider_name,
 818                        move |event| mapper.map_event(event),
 819                    ))
 820                });
 821                async move { Ok(future.await?.boxed()) }.boxed()
 822            }
 823            cloud_llm_client::LanguageModelProvider::OpenAi => {
 824                let client = self.client.clone();
 825                let llm_api_token = self.llm_api_token.clone();
 826
 827                if use_responses_api {
 828                    let request = into_open_ai_response(
 829                        request,
 830                        &self.model.id.0,
 831                        self.model.supports_parallel_tool_calls,
 832                        true,
 833                        None,
 834                        None,
 835                    );
 836                    let future = self.request_limiter.stream(async move {
 837                        let PerformLlmCompletionResponse {
 838                            response,
 839                            usage,
 840                            includes_status_messages,
 841                            tool_use_limit_reached,
 842                        } = Self::perform_llm_completion(
 843                            client.clone(),
 844                            llm_api_token,
 845                            app_version,
 846                            CompletionBody {
 847                                thread_id,
 848                                prompt_id,
 849                                intent,
 850                                mode,
 851                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 852                                model: request.model.clone(),
 853                                provider_request: serde_json::to_value(&request)
 854                                    .map_err(|e| anyhow!(e))?,
 855                            },
 856                        )
 857                        .await?;
 858
 859                        let mut mapper = OpenAiResponseEventMapper::new();
 860                        Ok(map_cloud_completion_events(
 861                            Box::pin(
 862                                response_lines(response, includes_status_messages)
 863                                    .chain(usage_updated_event(usage))
 864                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 865                            ),
 866                            &provider_name,
 867                            move |event| mapper.map_event(event),
 868                        ))
 869                    });
 870                    async move { Ok(future.await?.boxed()) }.boxed()
 871                } else {
 872                    let request = into_open_ai(
 873                        request,
 874                        &self.model.id.0,
 875                        self.model.supports_parallel_tool_calls,
 876                        true,
 877                        None,
 878                        None,
 879                    );
 880                    let future = self.request_limiter.stream(async move {
 881                        let PerformLlmCompletionResponse {
 882                            response,
 883                            usage,
 884                            includes_status_messages,
 885                            tool_use_limit_reached,
 886                        } = Self::perform_llm_completion(
 887                            client.clone(),
 888                            llm_api_token,
 889                            app_version,
 890                            CompletionBody {
 891                                thread_id,
 892                                prompt_id,
 893                                intent,
 894                                mode,
 895                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 896                                model: request.model.clone(),
 897                                provider_request: serde_json::to_value(&request)
 898                                    .map_err(|e| anyhow!(e))?,
 899                            },
 900                        )
 901                        .await?;
 902
 903                        let mut mapper = OpenAiEventMapper::new();
 904                        Ok(map_cloud_completion_events(
 905                            Box::pin(
 906                                response_lines(response, includes_status_messages)
 907                                    .chain(usage_updated_event(usage))
 908                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 909                            ),
 910                            &provider_name,
 911                            move |event| mapper.map_event(event),
 912                        ))
 913                    });
 914                    async move { Ok(future.await?.boxed()) }.boxed()
 915                }
 916            }
 917            cloud_llm_client::LanguageModelProvider::XAi => {
 918                let client = self.client.clone();
 919                let request = into_open_ai(
 920                    request,
 921                    &self.model.id.0,
 922                    self.model.supports_parallel_tool_calls,
 923                    false,
 924                    None,
 925                    None,
 926                );
 927                let llm_api_token = self.llm_api_token.clone();
 928                let future = self.request_limiter.stream(async move {
 929                    let PerformLlmCompletionResponse {
 930                        response,
 931                        usage,
 932                        includes_status_messages,
 933                        tool_use_limit_reached,
 934                    } = Self::perform_llm_completion(
 935                        client.clone(),
 936                        llm_api_token,
 937                        app_version,
 938                        CompletionBody {
 939                            thread_id,
 940                            prompt_id,
 941                            intent,
 942                            mode,
 943                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 944                            model: request.model.clone(),
 945                            provider_request: serde_json::to_value(&request)
 946                                .map_err(|e| anyhow!(e))?,
 947                        },
 948                    )
 949                    .await?;
 950
 951                    let mut mapper = OpenAiEventMapper::new();
 952                    Ok(map_cloud_completion_events(
 953                        Box::pin(
 954                            response_lines(response, includes_status_messages)
 955                                .chain(usage_updated_event(usage))
 956                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 957                        ),
 958                        &provider_name,
 959                        move |event| mapper.map_event(event),
 960                    ))
 961                });
 962                async move { Ok(future.await?.boxed()) }.boxed()
 963            }
 964            cloud_llm_client::LanguageModelProvider::Google => {
 965                let client = self.client.clone();
 966                let request =
 967                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 968                let llm_api_token = self.llm_api_token.clone();
 969                let future = self.request_limiter.stream(async move {
 970                    let PerformLlmCompletionResponse {
 971                        response,
 972                        usage,
 973                        includes_status_messages,
 974                        tool_use_limit_reached,
 975                    } = Self::perform_llm_completion(
 976                        client.clone(),
 977                        llm_api_token,
 978                        app_version,
 979                        CompletionBody {
 980                            thread_id,
 981                            prompt_id,
 982                            intent,
 983                            mode,
 984                            provider: cloud_llm_client::LanguageModelProvider::Google,
 985                            model: request.model.model_id.clone(),
 986                            provider_request: serde_json::to_value(&request)
 987                                .map_err(|e| anyhow!(e))?,
 988                        },
 989                    )
 990                    .await?;
 991
 992                    let mut mapper = GoogleEventMapper::new();
 993                    Ok(map_cloud_completion_events(
 994                        Box::pin(
 995                            response_lines(response, includes_status_messages)
 996                                .chain(usage_updated_event(usage))
 997                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 998                        ),
 999                        &provider_name,
1000                        move |event| mapper.map_event(event),
1001                    ))
1002                });
1003                async move { Ok(future.await?.boxed()) }.boxed()
1004            }
1005        }
1006    }
1007}
1008
1009fn map_cloud_completion_events<T, F>(
1010    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
1011    provider: &LanguageModelProviderName,
1012    mut map_callback: F,
1013) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1014where
1015    T: DeserializeOwned + 'static,
1016    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1017        + Send
1018        + 'static,
1019{
1020    let provider = provider.clone();
1021    stream
1022        .flat_map(move |event| {
1023            futures::stream::iter(match event {
1024                Err(error) => {
1025                    vec![Err(LanguageModelCompletionError::from(error))]
1026                }
1027                Ok(CompletionEvent::Status(event)) => {
1028                    vec![
1029                        LanguageModelCompletionEvent::from_completion_request_status(
1030                            event,
1031                            provider.clone(),
1032                        ),
1033                    ]
1034                }
1035                Ok(CompletionEvent::Event(event)) => map_callback(event),
1036            })
1037        })
1038        .boxed()
1039}
1040
1041fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1042    match provider {
1043        cloud_llm_client::LanguageModelProvider::Anthropic => {
1044            language_model::ANTHROPIC_PROVIDER_NAME
1045        }
1046        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1047        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1048        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1049    }
1050}
1051
1052fn usage_updated_event<T>(
1053    usage: Option<ModelRequestUsage>,
1054) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1055    futures::stream::iter(usage.map(|usage| {
1056        Ok(CompletionEvent::Status(
1057            CompletionRequestStatus::UsageUpdated {
1058                amount: usage.amount as usize,
1059                limit: usage.limit,
1060            },
1061        ))
1062    }))
1063}
1064
1065fn tool_use_limit_reached_event<T>(
1066    tool_use_limit_reached: bool,
1067) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1068    futures::stream::iter(tool_use_limit_reached.then(|| {
1069        Ok(CompletionEvent::Status(
1070            CompletionRequestStatus::ToolUseLimitReached,
1071        ))
1072    }))
1073}
1074
1075fn response_lines<T: DeserializeOwned>(
1076    response: Response<AsyncBody>,
1077    includes_status_messages: bool,
1078) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1079    futures::stream::try_unfold(
1080        (String::new(), BufReader::new(response.into_body())),
1081        move |(mut line, mut body)| async move {
1082            match body.read_line(&mut line).await {
1083                Ok(0) => Ok(None),
1084                Ok(_) => {
1085                    let event = if includes_status_messages {
1086                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1087                    } else {
1088                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1089                    };
1090
1091                    line.clear();
1092                    Ok(Some((event, (line, body))))
1093                }
1094                Err(e) => Err(e.into()),
1095            }
1096        },
1097    )
1098}
1099
/// Configuration content describing the user's access to Zed's hosted models.
///
/// Built by `ConfigurationView` from the user store, and previewed as an onboarding component.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when false, only the sign-in prompt is rendered.
    is_connected: bool,
    /// The user's plan, if one is known.
    plan: Option<Plan>,
    /// The current subscription period, presumably `(start, end)` — TODO confirm.
    /// The plan-specific copy is only shown when this is `Some`.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the user may still start a Pro trial (no trial has been started yet).
    eligible_for_trial: bool,
    /// When true, the young-account banner and an upgrade button replace the plan summary.
    account_too_young: bool,
    /// Invoked by the "Sign In" button when the user is not connected.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1109
1110impl RenderOnce for ZedAiConfiguration {
1111    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1112        let is_pro = self.plan.is_some_and(|plan| {
1113            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
1114        });
1115        let subscription_text = match (self.plan, self.subscription_period) {
1116            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
1117                "You have access to Zed's hosted models through your Pro subscription."
1118            }
1119            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
1120                "You have access to Zed's hosted models through your Pro trial."
1121            }
1122            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
1123                "You have basic access to Zed's hosted models through the Free plan."
1124            }
1125            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
1126                if self.eligible_for_trial {
1127                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1128                } else {
1129                    "Subscribe for access to Zed's hosted models."
1130                }
1131            }
1132            _ => {
1133                if self.eligible_for_trial {
1134                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1135                } else {
1136                    "Subscribe for access to Zed's hosted models."
1137                }
1138            }
1139        };
1140
1141        let manage_subscription_buttons = if is_pro {
1142            Button::new("manage_settings", "Manage Subscription")
1143                .full_width()
1144                .style(ButtonStyle::Tinted(TintColor::Accent))
1145                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1146                .into_any_element()
1147        } else if self.plan.is_none() || self.eligible_for_trial {
1148            Button::new("start_trial", "Start 14-day Free Pro Trial")
1149                .full_width()
1150                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1151                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1152                .into_any_element()
1153        } else {
1154            Button::new("upgrade", "Upgrade to Pro")
1155                .full_width()
1156                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1157                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1158                .into_any_element()
1159        };
1160
1161        if !self.is_connected {
1162            return v_flex()
1163                .gap_2()
1164                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1165                .child(
1166                    Button::new("sign_in", "Sign In to use Zed AI")
1167                        .icon_color(Color::Muted)
1168                        .icon(IconName::Github)
1169                        .icon_size(IconSize::Small)
1170                        .icon_position(IconPosition::Start)
1171                        .full_width()
1172                        .on_click({
1173                            let callback = self.sign_in_callback.clone();
1174                            move |_, window, cx| (callback)(window, cx)
1175                        }),
1176                );
1177        }
1178
1179        v_flex().gap_2().w_full().map(|this| {
1180            if self.account_too_young {
1181                this.child(YoungAccountBanner).child(
1182                    Button::new("upgrade", "Upgrade to Pro")
1183                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1184                        .full_width()
1185                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1186                )
1187            } else {
1188                this.text_sm()
1189                    .child(subscription_text)
1190                    .child(manage_subscription_buttons)
1191            }
1192        })
1193    }
1194}
1195
/// Configuration view for Zed's hosted models; renders a `ZedAiConfiguration` built from
/// the provider `State` and its user store.
struct ConfigurationView {
    /// Provider state; read for sign-in status and the user store.
    state: Entity<State>,
    /// Passed to the "Sign In" button; starts authentication on `state`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1200
1201impl ConfigurationView {
1202    fn new(state: Entity<State>) -> Self {
1203        let sign_in_callback = Arc::new({
1204            let state = state.clone();
1205            move |_window: &mut Window, cx: &mut App| {
1206                state.update(cx, |state, cx| {
1207                    state.authenticate(cx).detach_and_log_err(cx);
1208                });
1209            }
1210        });
1211
1212        Self {
1213            state,
1214            sign_in_callback,
1215        }
1216    }
1217}
1218
1219impl Render for ConfigurationView {
1220    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1221        let state = self.state.read(cx);
1222        let user_store = state.user_store.read(cx);
1223
1224        ZedAiConfiguration {
1225            is_connected: !state.is_signed_out(cx),
1226            plan: user_store.plan(),
1227            subscription_period: user_store.subscription_period(),
1228            eligible_for_trial: user_store.trial_started_at().is_none(),
1229            account_too_young: user_store.account_too_young(),
1230            sign_in_callback: self.sign_in_callback.clone(),
1231        }
1232    }
1233}
1234
1235impl Component for ZedAiConfiguration {
1236    fn name() -> &'static str {
1237        "AI Configuration Content"
1238    }
1239
1240    fn sort_name() -> &'static str {
1241        "AI Configuration Content"
1242    }
1243
1244    fn scope() -> ComponentScope {
1245        ComponentScope::Onboarding
1246    }
1247
1248    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1249        fn configuration(
1250            is_connected: bool,
1251            plan: Option<Plan>,
1252            eligible_for_trial: bool,
1253            account_too_young: bool,
1254        ) -> AnyElement {
1255            ZedAiConfiguration {
1256                is_connected,
1257                plan,
1258                subscription_period: plan
1259                    .is_some()
1260                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1261                eligible_for_trial,
1262                account_too_young,
1263                sign_in_callback: Arc::new(|_, _| {}),
1264            }
1265            .into_any_element()
1266        }
1267
1268        Some(
1269            v_flex()
1270                .p_4()
1271                .gap_4()
1272                .children(vec![
1273                    single_example("Not connected", configuration(false, None, false, false)),
1274                    single_example(
1275                        "Accept Terms of Service",
1276                        configuration(true, None, true, false),
1277                    ),
1278                    single_example(
1279                        "No Plan - Not eligible for trial",
1280                        configuration(true, None, false, false),
1281                    ),
1282                    single_example(
1283                        "No Plan - Eligible for trial",
1284                        configuration(true, None, true, false),
1285                    ),
1286                    single_example(
1287                        "Free Plan",
1288                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
1289                    ),
1290                    single_example(
1291                        "Zed Pro Trial Plan",
1292                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
1293                    ),
1294                    single_example(
1295                        "Zed Pro Plan",
1296                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
1297                    ),
1298                ])
1299                .into_any_element(),
1300        )
1301    }
1302}
1303
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Converts an `ApiError` with a 500 status and the given body into a
    /// `LanguageModelCompletionError`.
    fn completion_error_for(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    /// Asserts that `body` converts into an `UpstreamProviderError` carrying `expected_message`.
    /// `case` labels the failure message.
    fn assert_upstream_message(body: &str, expected_message: &str, case: &str) {
        match completion_error_for(body) {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(message, expected_message);
            }
            other => panic!("Expected UpstreamProviderError for {case}, got: {other:?}"),
        }
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An `upstream_http_error` is surfaced as an `UpstreamProviderError` carrying the
        // upstream message, whatever the upstream status (503, 500 or 429).
        assert_upstream_message(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
            "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
            "upstream 503",
        );
        assert_upstream_message(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
            "Received an error from the OpenAI API: internal server error",
            "upstream 500",
        );
        assert_upstream_message(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
            "Received an error from the Google API: rate limit exceeded",
            "upstream 429",
        );

        // A regular 500 without an upstream error code stays an `ApiInternalServerError`
        // attributed to Zed's own provider.
        match completion_error_for("Regular internal server error") {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!("Expected ApiInternalServerError for regular 500, got: {other:?}"),
        }

        // `upstream_http_429` becomes an `UpstreamProviderError` with a 429 status and the
        // server-provided retry delay.
        match completion_error_for(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        ) {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!("Expected UpstreamProviderError for upstream_http_429, got: {other:?}"),
        }

        // A body that is not JSON falls back to regular error handling.
        match completion_error_for("Not JSON at all") {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!("Expected ApiInternalServerError for invalid JSON, got: {other:?}"),
        }
    }
}