cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, UserStore, zed_urls};
   6use cloud_api_types::{OrganizationId, Plan};
   7use cloud_llm_client::{
   8    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   9    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
  10    CountTokensBody, CountTokensResponse, ListModelsResponse,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  12};
  13use futures::{
  14    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  15    future::BoxFuture,
  16    stream::{self, BoxStream},
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  25    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  26    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  27    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::collections::VecDeque;
  39use std::pin::Pin;
  40use std::str::FromStr;
  41use std::sync::Arc;
  42use std::task::Poll;
  43use std::time::Duration;
  44use thiserror::Error;
  45use ui::{TintColor, prelude::*};
  46
  47use crate::provider::anthropic::{
  48    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  49};
  50use crate::provider::google::{GoogleEventMapper, into_google};
  51use crate::provider::open_ai::{
  52    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  53    into_open_ai_response,
  54};
  55use crate::provider::x_ai::count_xai_tokens;
  56
/// Identifier for the Zed-hosted ("zed.dev") language model provider.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Human-readable name for the Zed-hosted provider.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  59
/// User-configurable settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models exposed through the user's settings file.
    pub available_models: Vec<AvailableModel>,
}
/// Reasoning mode for a cloud-hosted model, as configured in settings.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion with no extended reasoning.
    #[default]
    Default,
    /// Extended-thinking mode with an optional reasoning-token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  74
  75impl From<ModelMode> for AnthropicModelMode {
  76    fn from(value: ModelMode) -> Self {
  77        match value {
  78            ModelMode::Default => AnthropicModelMode::Default,
  79            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  80        }
  81    }
  82}
  83
/// `LanguageModelProvider` implementation backed by Zed's cloud LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    // Background task mirroring the client's connection status into `state`;
    // dropped (and thereby cancelled) together with the provider.
    _maintain_client_status: Task<()>,
}
  89
/// Observable provider state: connection status, the fetched model catalog,
/// and the subscriptions that keep both up to date.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    /// All models returned by the cloud models endpoint.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default model, if it appears in `models`.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated fast default model, if it appears in `models`.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Subset of `models` the server recommends surfacing prominently.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 103
 104impl State {
 105    fn new(
 106        client: Arc<Client>,
 107        user_store: Entity<UserStore>,
 108        status: client::Status,
 109        cx: &mut Context<Self>,
 110    ) -> Self {
 111        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 112        Self {
 113            client: client.clone(),
 114            llm_api_token: LlmApiToken::default(),
 115            user_store: user_store.clone(),
 116            status,
 117            models: Vec::new(),
 118            default_model: None,
 119            default_fast_model: None,
 120            recommended_models: Vec::new(),
 121            _user_store_subscription: cx.subscribe(
 122                &user_store,
 123                move |this, _user_store, event, cx| match event {
 124                    client::user::Event::PrivateUserInfoUpdated => {
 125                        let status = *client.status().borrow();
 126                        if status.is_signed_out() {
 127                            return;
 128                        }
 129
 130                        let client = this.client.clone();
 131                        let llm_api_token = this.llm_api_token.clone();
 132                        let organization_id = this
 133                            .user_store
 134                            .read(cx)
 135                            .current_organization()
 136                            .map(|organization| organization.id.clone());
 137                        cx.spawn(async move |this, cx| {
 138                            let response =
 139                                Self::fetch_models(client, llm_api_token, organization_id).await?;
 140                            this.update(cx, |this, cx| this.update_models(response, cx))
 141                        })
 142                        .detach_and_log_err(cx);
 143                    }
 144                    _ => {}
 145                },
 146            ),
 147            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 148                cx.notify();
 149            }),
 150            _llm_token_subscription: cx.subscribe(
 151                &refresh_llm_token_listener,
 152                move |this, _listener, _event, cx| {
 153                    let client = this.client.clone();
 154                    let llm_api_token = this.llm_api_token.clone();
 155                    let organization_id = this
 156                        .user_store
 157                        .read(cx)
 158                        .current_organization()
 159                        .map(|o| o.id.clone());
 160                    cx.spawn(async move |this, cx| {
 161                        llm_api_token
 162                            .refresh(&client, organization_id.clone())
 163                            .await?;
 164                        let response =
 165                            Self::fetch_models(client, llm_api_token, organization_id).await?;
 166                        this.update(cx, |this, cx| {
 167                            this.update_models(response, cx);
 168                        })
 169                    })
 170                    .detach_and_log_err(cx);
 171                },
 172            ),
 173        }
 174    }
 175
 176    fn is_signed_out(&self, cx: &App) -> bool {
 177        self.user_store.read(cx).current_user().is_none()
 178    }
 179
 180    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 181        let client = self.client.clone();
 182        cx.spawn(async move |state, cx| {
 183            client.sign_in_with_optional_connect(true, cx).await?;
 184            state.update(cx, |_, cx| cx.notify())
 185        })
 186    }
 187
 188    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 189        let mut models = Vec::new();
 190
 191        for model in response.models {
 192            models.push(Arc::new(model.clone()));
 193        }
 194
 195        self.default_model = models
 196            .iter()
 197            .find(|model| {
 198                response
 199                    .default_model
 200                    .as_ref()
 201                    .is_some_and(|default_model_id| &model.id == default_model_id)
 202            })
 203            .cloned();
 204        self.default_fast_model = models
 205            .iter()
 206            .find(|model| {
 207                response
 208                    .default_fast_model
 209                    .as_ref()
 210                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 211            })
 212            .cloned();
 213        self.recommended_models = response
 214            .recommended_models
 215            .iter()
 216            .filter_map(|id| models.iter().find(|model| &model.id == id))
 217            .cloned()
 218            .collect();
 219        self.models = models;
 220        cx.notify();
 221    }
 222
 223    async fn fetch_models(
 224        client: Arc<Client>,
 225        llm_api_token: LlmApiToken,
 226        organization_id: Option<OrganizationId>,
 227    ) -> Result<ListModelsResponse> {
 228        let http_client = &client.http_client();
 229        let token = llm_api_token.acquire(&client, organization_id).await?;
 230
 231        let request = http_client::Request::builder()
 232            .method(Method::GET)
 233            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 234            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 235            .header("Authorization", format!("Bearer {token}"))
 236            .body(AsyncBody::empty())?;
 237        let mut response = http_client
 238            .send(request)
 239            .await
 240            .context("failed to send list models request")?;
 241
 242        if response.status().is_success() {
 243            let mut body = String::new();
 244            response.body_mut().read_to_string(&mut body).await?;
 245            Ok(serde_json::from_str(&body)?)
 246        } else {
 247            let mut body = String::new();
 248            response.body_mut().read_to_string(&mut body).await?;
 249            anyhow::bail!(
 250                "error listing models.\nStatus: {:?}\nBody: {body}",
 251                response.status(),
 252            );
 253        }
 254    }
 255}
 256
 257impl CloudLanguageModelProvider {
 258    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 259        let mut status_rx = client.status();
 260        let status = *status_rx.borrow();
 261
 262        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 263
 264        let state_ref = state.downgrade();
 265        let maintain_client_status = cx.spawn(async move |cx| {
 266            while let Some(status) = status_rx.next().await {
 267                if let Some(this) = state_ref.upgrade() {
 268                    _ = this.update(cx, |this, cx| {
 269                        if this.status != status {
 270                            this.status = status;
 271                            cx.notify();
 272                        }
 273                    });
 274                } else {
 275                    break;
 276                }
 277            }
 278        });
 279
 280        Self {
 281            client,
 282            state,
 283            _maintain_client_status: maintain_client_status,
 284        }
 285    }
 286
 287    fn create_language_model(
 288        &self,
 289        model: Arc<cloud_llm_client::LanguageModel>,
 290        llm_api_token: LlmApiToken,
 291        user_store: Entity<UserStore>,
 292    ) -> Arc<dyn LanguageModel> {
 293        Arc::new(CloudLanguageModel {
 294            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 295            model,
 296            llm_api_token,
 297            user_store,
 298            client: self.client.clone(),
 299            request_limiter: RateLimiter::new(4),
 300        })
 301    }
 302}
 303
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes `State` so callers can observe model-list and status changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 311
 312impl LanguageModelProvider for CloudLanguageModelProvider {
 313    fn id(&self) -> LanguageModelProviderId {
 314        PROVIDER_ID
 315    }
 316
 317    fn name(&self) -> LanguageModelProviderName {
 318        PROVIDER_NAME
 319    }
 320
 321    fn icon(&self) -> IconOrSvg {
 322        IconOrSvg::Icon(IconName::AiZed)
 323    }
 324
 325    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 326        let state = self.state.read(cx);
 327        let default_model = state.default_model.clone()?;
 328        let llm_api_token = state.llm_api_token.clone();
 329        let user_store = state.user_store.clone();
 330        Some(self.create_language_model(default_model, llm_api_token, user_store))
 331    }
 332
 333    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 334        let state = self.state.read(cx);
 335        let default_fast_model = state.default_fast_model.clone()?;
 336        let llm_api_token = state.llm_api_token.clone();
 337        let user_store = state.user_store.clone();
 338        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
 339    }
 340
 341    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 342        let state = self.state.read(cx);
 343        let llm_api_token = state.llm_api_token.clone();
 344        let user_store = state.user_store.clone();
 345        state
 346            .recommended_models
 347            .iter()
 348            .cloned()
 349            .map(|model| {
 350                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 351            })
 352            .collect()
 353    }
 354
 355    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 356        let state = self.state.read(cx);
 357        let llm_api_token = state.llm_api_token.clone();
 358        let user_store = state.user_store.clone();
 359        state
 360            .models
 361            .iter()
 362            .cloned()
 363            .map(|model| {
 364                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 365            })
 366            .collect()
 367    }
 368
 369    fn is_authenticated(&self, cx: &App) -> bool {
 370        let state = self.state.read(cx);
 371        !state.is_signed_out(cx)
 372    }
 373
 374    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 375        Task::ready(Ok(()))
 376    }
 377
 378    fn configuration_view(
 379        &self,
 380        _target_agent: language_model::ConfigurationViewTargetAgent,
 381        _: &mut Window,
 382        cx: &mut App,
 383    ) -> AnyView {
 384        cx.new(|_| ConfigurationView::new(self.state.clone()))
 385            .into()
 386    }
 387
 388    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 389        Task::ready(Ok(()))
 390    }
 391}
 392
/// A single cloud-hosted model, implementing `LanguageModel` by proxying
/// requests through Zed's LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    // Limits concurrent in-flight requests for this model (capacity 4 at construction).
    request_limiter: RateLimiter,
}
 401
/// Successful response from the completions endpoint, plus whether the server
/// advertised support for in-stream status messages via its response headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}
 406
impl CloudLanguageModel {
    /// Sends `body` to the cloud `/completions` endpoint.
    ///
    /// If the response indicates the LLM token is stale, the token is
    /// refreshed once and the request retried; 402 responses are surfaced as
    /// `PaymentRequiredError`, and any other non-success status becomes an
    /// `ApiError` carrying the status, body, and headers.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token
            .acquire(&client, organization_id.clone())
            .await?;
        let mut refreshed_token = false;

        // Loops at most twice: once with the cached token, and at most once
        // more after a single token refresh.
        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                // The server signals status-message support via a response header.
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Retry exactly once with a refreshed token when the server says
            // the current one is no longer valid.
            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token
                    .refresh(&client, organization_id.clone())
                    .await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture headers before consuming the body so both reach the error.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 472
/// Non-success HTTP response from the cloud LLM service, retained for
/// conversion into a `LanguageModelCompletionError`.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 480
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable code, e.g. `"upstream_http_error"` or `"upstream_http_429"`.
    code: String,
    /// Human-readable error description.
    message: String,
    /// HTTP status reported by the upstream provider, when relayed by the server.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Seconds to wait before retrying, when the server supplies one.
    #[serde(default)]
    retry_after: Option<f64>,
}
 501
 502fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 503where
 504    D: serde::Deserializer<'de>,
 505{
 506    let opt: Option<u16> = Option::deserialize(deserializer)?;
 507    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 508}
 509
 510impl From<ApiError> for LanguageModelCompletionError {
 511    fn from(error: ApiError) -> Self {
 512        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 513            if cloud_error.code.starts_with("upstream_http_") {
 514                let status = if let Some(status) = cloud_error.upstream_status {
 515                    status
 516                } else if cloud_error.code.ends_with("_error") {
 517                    error.status
 518                } else {
 519                    // If there's a status code in the code string (e.g. "upstream_http_429")
 520                    // then use that; otherwise, see if the JSON contains a status code.
 521                    cloud_error
 522                        .code
 523                        .strip_prefix("upstream_http_")
 524                        .and_then(|code_str| code_str.parse::<u16>().ok())
 525                        .and_then(|code| StatusCode::from_u16(code).ok())
 526                        .unwrap_or(error.status)
 527                };
 528
 529                return LanguageModelCompletionError::UpstreamProviderError {
 530                    message: cloud_error.message,
 531                    status,
 532                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 533                };
 534            }
 535
 536            return LanguageModelCompletionError::from_http_status(
 537                PROVIDER_NAME,
 538                error.status,
 539                cloud_error.message,
 540                None,
 541            );
 542        }
 543
 544        let retry_after = None;
 545        LanguageModelCompletionError::from_http_status(
 546            PROVIDER_NAME,
 547            error.status,
 548            error.body,
 549            retry_after,
 550        )
 551    }
 552}
 553
 554impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    /// All cloud models report the zed.dev provider id/name; the upstream
    /// provider is exposed separately below.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// The id of the provider actually serving this model behind the proxy.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    // The capability getters below simply forward the server-provided
    // capability flags on the model description.
    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }
 610
 611    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 612        self.model
 613            .supported_effort_levels
 614            .iter()
 615            .map(|effort_level| LanguageModelEffortLevel {
 616                name: effort_level.name.clone().into(),
 617                value: effort_level.value.clone().into(),
 618                is_default: effort_level.is_default.unwrap_or(false),
 619            })
 620            .collect()
 621    }
 622
    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    /// Every tool-choice mode (auto / any / none) is supported for cloud models.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Split input/output token display is enabled only for OpenAI-provided models.
    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    /// Google models accept only a subset of JSON Schema for tool inputs;
    /// the other providers take full JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }
 656
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    /// Prompt caching is configured for Anthropic models only; the other
    /// providers get no cache configuration here.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }
 679
    /// Estimates the token count for `request`.
    ///
    /// Anthropic, OpenAI, and xAI counts are computed locally; Google counts
    /// require a round-trip to the cloud `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            // Anthropic counts are approximated locally with tiktoken.
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    // Unknown model id: resolve to an immediate error future.
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|o| o.id.clone());
                let model_id = self.model.id.to_string();
                // Build the Google request on this thread (it borrows `cx`),
                // then perform the HTTP exchange in the returned future.
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client, organization_id).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 761
 762    fn stream_completion(
 763        &self,
 764        request: LanguageModelRequest,
 765        cx: &AsyncApp,
 766    ) -> BoxFuture<
 767        'static,
 768        Result<
 769            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 770            LanguageModelCompletionError,
 771        >,
 772    > {
 773        let thread_id = request.thread_id.clone();
 774        let prompt_id = request.prompt_id.clone();
 775        let intent = request.intent;
 776        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
 777        let user_store = self.user_store.clone();
 778        let organization_id = cx.update(|cx| {
 779            user_store
 780                .read(cx)
 781                .current_organization()
 782                .map(|o| o.id.clone())
 783        });
 784        let thinking_allowed = request.thinking_allowed;
 785        let enable_thinking = thinking_allowed && self.model.supports_thinking;
 786        let provider_name = provider_name(&self.model.provider);
 787        match self.model.provider {
 788            cloud_llm_client::LanguageModelProvider::Anthropic => {
 789                let effort = request
 790                    .thinking_effort
 791                    .as_ref()
 792                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());
 793
 794                let mut request = into_anthropic(
 795                    request,
 796                    self.model.id.to_string(),
 797                    1.0,
 798                    self.model.max_output_tokens as u64,
 799                    if enable_thinking {
 800                        AnthropicModelMode::Thinking {
 801                            budget_tokens: Some(4_096),
 802                        }
 803                    } else {
 804                        AnthropicModelMode::Default
 805                    },
 806                );
 807
 808                if enable_thinking && effort.is_some() {
 809                    request.thinking = Some(anthropic::Thinking::Adaptive);
 810                    request.output_config = Some(anthropic::OutputConfig { effort });
 811                }
 812
 813                let client = self.client.clone();
 814                let llm_api_token = self.llm_api_token.clone();
 815                let organization_id = organization_id.clone();
 816                let future = self.request_limiter.stream(async move {
 817                    let PerformLlmCompletionResponse {
 818                        response,
 819                        includes_status_messages,
 820                    } = Self::perform_llm_completion(
 821                        client.clone(),
 822                        llm_api_token,
 823                        organization_id,
 824                        app_version,
 825                        CompletionBody {
 826                            thread_id,
 827                            prompt_id,
 828                            intent,
 829                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 830                            model: request.model.clone(),
 831                            provider_request: serde_json::to_value(&request)
 832                                .map_err(|e| anyhow!(e))?,
 833                        },
 834                    )
 835                    .await
 836                    .map_err(|err| match err.downcast::<ApiError>() {
 837                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 838                        Err(err) => anyhow!(err),
 839                    })?;
 840
 841                    let mut mapper = AnthropicEventMapper::new();
 842                    Ok(map_cloud_completion_events(
 843                        Box::pin(response_lines(response, includes_status_messages)),
 844                        &provider_name,
 845                        move |event| mapper.map_event(event),
 846                    ))
 847                });
 848                async move { Ok(future.await?.boxed()) }.boxed()
 849            }
 850            cloud_llm_client::LanguageModelProvider::OpenAi => {
 851                let client = self.client.clone();
 852                let llm_api_token = self.llm_api_token.clone();
 853                let organization_id = organization_id.clone();
 854                let effort = request
 855                    .thinking_effort
 856                    .as_ref()
 857                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
 858
 859                let mut request = into_open_ai_response(
 860                    request,
 861                    &self.model.id.0,
 862                    self.model.supports_parallel_tool_calls,
 863                    true,
 864                    None,
 865                    None,
 866                );
 867
 868                if enable_thinking && let Some(effort) = effort {
 869                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
 870                }
 871
 872                let future = self.request_limiter.stream(async move {
 873                    let PerformLlmCompletionResponse {
 874                        response,
 875                        includes_status_messages,
 876                    } = Self::perform_llm_completion(
 877                        client.clone(),
 878                        llm_api_token,
 879                        organization_id,
 880                        app_version,
 881                        CompletionBody {
 882                            thread_id,
 883                            prompt_id,
 884                            intent,
 885                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 886                            model: request.model.clone(),
 887                            provider_request: serde_json::to_value(&request)
 888                                .map_err(|e| anyhow!(e))?,
 889                        },
 890                    )
 891                    .await?;
 892
 893                    let mut mapper = OpenAiResponseEventMapper::new();
 894                    Ok(map_cloud_completion_events(
 895                        Box::pin(response_lines(response, includes_status_messages)),
 896                        &provider_name,
 897                        move |event| mapper.map_event(event),
 898                    ))
 899                });
 900                async move { Ok(future.await?.boxed()) }.boxed()
 901            }
 902            cloud_llm_client::LanguageModelProvider::XAi => {
 903                let client = self.client.clone();
 904                let request = into_open_ai(
 905                    request,
 906                    &self.model.id.0,
 907                    self.model.supports_parallel_tool_calls,
 908                    false,
 909                    None,
 910                    None,
 911                );
 912                let llm_api_token = self.llm_api_token.clone();
 913                let organization_id = organization_id.clone();
 914                let future = self.request_limiter.stream(async move {
 915                    let PerformLlmCompletionResponse {
 916                        response,
 917                        includes_status_messages,
 918                    } = Self::perform_llm_completion(
 919                        client.clone(),
 920                        llm_api_token,
 921                        organization_id,
 922                        app_version,
 923                        CompletionBody {
 924                            thread_id,
 925                            prompt_id,
 926                            intent,
 927                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 928                            model: request.model.clone(),
 929                            provider_request: serde_json::to_value(&request)
 930                                .map_err(|e| anyhow!(e))?,
 931                        },
 932                    )
 933                    .await?;
 934
 935                    let mut mapper = OpenAiEventMapper::new();
 936                    Ok(map_cloud_completion_events(
 937                        Box::pin(response_lines(response, includes_status_messages)),
 938                        &provider_name,
 939                        move |event| mapper.map_event(event),
 940                    ))
 941                });
 942                async move { Ok(future.await?.boxed()) }.boxed()
 943            }
 944            cloud_llm_client::LanguageModelProvider::Google => {
 945                let client = self.client.clone();
 946                let request =
 947                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 948                let llm_api_token = self.llm_api_token.clone();
 949                let future = self.request_limiter.stream(async move {
 950                    let PerformLlmCompletionResponse {
 951                        response,
 952                        includes_status_messages,
 953                    } = Self::perform_llm_completion(
 954                        client.clone(),
 955                        llm_api_token,
 956                        organization_id,
 957                        app_version,
 958                        CompletionBody {
 959                            thread_id,
 960                            prompt_id,
 961                            intent,
 962                            provider: cloud_llm_client::LanguageModelProvider::Google,
 963                            model: request.model.model_id.clone(),
 964                            provider_request: serde_json::to_value(&request)
 965                                .map_err(|e| anyhow!(e))?,
 966                        },
 967                    )
 968                    .await?;
 969
 970                    let mut mapper = GoogleEventMapper::new();
 971                    Ok(map_cloud_completion_events(
 972                        Box::pin(response_lines(response, includes_status_messages)),
 973                        &provider_name,
 974                        move |event| mapper.map_event(event),
 975                    ))
 976                });
 977                async move { Ok(future.await?.boxed()) }.boxed()
 978            }
 979        }
 980    }
 981}
 982
/// Adapts a raw cloud completion stream into a stream of
/// `LanguageModelCompletionEvent`s.
///
/// Each incoming item is either expanded via `map_callback` (for
/// provider-specific events) or converted from a `CompletionRequestStatus`;
/// the resulting events are buffered in a queue and yielded one at a time.
/// The server's `StreamEnded` status marker is consumed rather than
/// forwarded; if the underlying stream finishes without one, a single
/// `StreamEndedUnexpectedly` error is emitted so callers can distinguish a
/// truncated response from a clean finish.
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    // Fuse so it is safe to keep polling after the inner stream completes
    // while `pending` is drained.
    let mut stream = stream.fuse();

    // Whether the server explicitly signaled the end of the stream.
    let mut saw_stream_ended = false;

    // `done` latches once the inner stream is exhausted (and any missing
    // end-marker error has been emitted); `pending` buffers mapped events
    // awaiting delivery.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Deliver buffered events before polling for more input.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // Record the clean-shutdown marker; emit nothing.
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            // A status converts to zero or one events.
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // The stream closed without the server's explicit
                    // end-of-stream marker: surface that as an error.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
1052
1053fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1054    match provider {
1055        cloud_llm_client::LanguageModelProvider::Anthropic => {
1056            language_model::ANTHROPIC_PROVIDER_NAME
1057        }
1058        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1059        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1060        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1061    }
1062}
1063
1064fn response_lines<T: DeserializeOwned>(
1065    response: Response<AsyncBody>,
1066    includes_status_messages: bool,
1067) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1068    futures::stream::try_unfold(
1069        (String::new(), BufReader::new(response.into_body())),
1070        move |(mut line, mut body)| async move {
1071            match body.read_line(&mut line).await {
1072                Ok(0) => Ok(None),
1073                Ok(_) => {
1074                    let event = if includes_status_messages {
1075                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1076                    } else {
1077                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1078                    };
1079
1080                    line.clear();
1081                    Ok(Some((event, (line, body))))
1082                }
1083                Err(e) => Err(e.into()),
1084            }
1085        },
1086    )
1087}
1088
/// Renderable state for the Zed AI configuration panel: shows either a
/// sign-in prompt or the user's current plan/subscription status with the
/// matching subscribe/upgrade actions.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is currently signed in.
    is_connected: bool,
    // The user's current plan, if any.
    plan: Option<Plan>,
    // Start/end timestamps of the active subscription period, when one exists.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // Whether the account can still start the free Pro trial.
    eligible_for_trial: bool,
    // Account flagged as too young; render shows the YoungAccountBanner
    // instead of the normal subscription copy.
    account_too_young: bool,
    // Invoked when the user clicks the sign-in button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1098
1099impl RenderOnce for ZedAiConfiguration {
1100    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1101        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
1102        let subscription_text = match (self.plan, self.subscription_period) {
1103            (Some(Plan::ZedPro), Some(_)) => {
1104                "You have access to Zed's hosted models through your Pro subscription."
1105            }
1106            (Some(Plan::ZedProTrial), Some(_)) => {
1107                "You have access to Zed's hosted models through your Pro trial."
1108            }
1109            (Some(Plan::ZedFree), Some(_)) => {
1110                if self.eligible_for_trial {
1111                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1112                } else {
1113                    "Subscribe for access to Zed's hosted models."
1114                }
1115            }
1116            _ => {
1117                if self.eligible_for_trial {
1118                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1119                } else {
1120                    "Subscribe for access to Zed's hosted models."
1121                }
1122            }
1123        };
1124
1125        let manage_subscription_buttons = if is_pro {
1126            Button::new("manage_settings", "Manage Subscription")
1127                .full_width()
1128                .style(ButtonStyle::Tinted(TintColor::Accent))
1129                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1130                .into_any_element()
1131        } else if self.plan.is_none() || self.eligible_for_trial {
1132            Button::new("start_trial", "Start 14-day Free Pro Trial")
1133                .full_width()
1134                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1135                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1136                .into_any_element()
1137        } else {
1138            Button::new("upgrade", "Upgrade to Pro")
1139                .full_width()
1140                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1141                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1142                .into_any_element()
1143        };
1144
1145        if !self.is_connected {
1146            return v_flex()
1147                .gap_2()
1148                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1149                .child(
1150                    Button::new("sign_in", "Sign In to use Zed AI")
1151                        .icon_color(Color::Muted)
1152                        .icon(IconName::Github)
1153                        .icon_size(IconSize::Small)
1154                        .icon_position(IconPosition::Start)
1155                        .full_width()
1156                        .on_click({
1157                            let callback = self.sign_in_callback.clone();
1158                            move |_, window, cx| (callback)(window, cx)
1159                        }),
1160                );
1161        }
1162
1163        v_flex().gap_2().w_full().map(|this| {
1164            if self.account_too_young {
1165                this.child(YoungAccountBanner).child(
1166                    Button::new("upgrade", "Upgrade to Pro")
1167                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1168                        .full_width()
1169                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1170                )
1171            } else {
1172                this.text_sm()
1173                    .child(subscription_text)
1174                    .child(manage_subscription_buttons)
1175            }
1176        })
1177    }
1178}
1179
/// Configuration view for the Zed AI provider; reads the provider `State`
/// each frame and delegates rendering to `ZedAiConfiguration`.
struct ConfigurationView {
    // Shared provider state (authentication status + user store).
    state: Entity<State>,
    // Sign-in action handed to the rendered configuration component.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1184
1185impl ConfigurationView {
1186    fn new(state: Entity<State>) -> Self {
1187        let sign_in_callback = Arc::new({
1188            let state = state.clone();
1189            move |_window: &mut Window, cx: &mut App| {
1190                state.update(cx, |state, cx| {
1191                    state.authenticate(cx).detach_and_log_err(cx);
1192                });
1193            }
1194        });
1195
1196        Self {
1197            state,
1198            sign_in_callback,
1199        }
1200    }
1201}
1202
1203impl Render for ConfigurationView {
1204    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1205        let state = self.state.read(cx);
1206        let user_store = state.user_store.read(cx);
1207
1208        ZedAiConfiguration {
1209            is_connected: !state.is_signed_out(cx),
1210            plan: user_store.plan(),
1211            subscription_period: user_store.subscription_period(),
1212            eligible_for_trial: user_store.trial_started_at().is_none(),
1213            account_too_young: user_store.account_too_young(),
1214            sign_in_callback: self.sign_in_callback.clone(),
1215        }
1216    }
1217}
1218
1219impl Component for ZedAiConfiguration {
1220    fn name() -> &'static str {
1221        "AI Configuration Content"
1222    }
1223
1224    fn sort_name() -> &'static str {
1225        "AI Configuration Content"
1226    }
1227
1228    fn scope() -> ComponentScope {
1229        ComponentScope::Onboarding
1230    }
1231
1232    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1233        fn configuration(
1234            is_connected: bool,
1235            plan: Option<Plan>,
1236            eligible_for_trial: bool,
1237            account_too_young: bool,
1238        ) -> AnyElement {
1239            ZedAiConfiguration {
1240                is_connected,
1241                plan,
1242                subscription_period: plan
1243                    .is_some()
1244                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1245                eligible_for_trial,
1246                account_too_young,
1247                sign_in_callback: Arc::new(|_, _| {}),
1248            }
1249            .into_any_element()
1250        }
1251
1252        Some(
1253            v_flex()
1254                .p_4()
1255                .gap_4()
1256                .children(vec![
1257                    single_example("Not connected", configuration(false, None, false, false)),
1258                    single_example(
1259                        "Accept Terms of Service",
1260                        configuration(true, None, true, false),
1261                    ),
1262                    single_example(
1263                        "No Plan - Not eligible for trial",
1264                        configuration(true, None, false, false),
1265                    ),
1266                    single_example(
1267                        "No Plan - Eligible for trial",
1268                        configuration(true, None, true, false),
1269                    ),
1270                    single_example(
1271                        "Free Plan",
1272                        configuration(true, Some(Plan::ZedFree), true, false),
1273                    ),
1274                    single_example(
1275                        "Zed Pro Trial Plan",
1276                        configuration(true, Some(Plan::ZedProTrial), true, false),
1277                    ),
1278                    single_example(
1279                        "Zed Pro Plan",
1280                        configuration(true, Some(Plan::ZedPro), true, false),
1281                    ),
1282                ])
1283                .into_any_element(),
1284        )
1285    }
1286}
1287
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds an `ApiError` with an HTTP 500 status, empty headers, and the
    /// given body, then converts it into a `LanguageModelCompletionError` —
    /// the same conversion the completion path applies to server errors.
    fn completion_error_for_body(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should become
        // UpstreamProviderError carrying the provider's message.
        let completion_error = completion_error_for_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also become
        // UpstreamProviderError (the upstream provider failed, not Zed).
        let completion_error = completion_error_for_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should become
        // UpstreamProviderError as well.
        let completion_error = completion_error_for_body(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 error without upstream_http_error should remain
        // ApiInternalServerError attributed to Zed.
        let completion_error = completion_error_for_body("Regular internal server error");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 format should be converted to
        // UpstreamProviderError with status and retry_after preserved.
        let completion_error = completion_error_for_body(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error
        // handling.
        let completion_error = completion_error_for_body("Not JSON at all");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}