cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use client::{Client, UserStore, zed_urls};
   5use cloud_api_types::{OrganizationId, Plan};
   6use cloud_llm_client::{
   7    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   8    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
   9    CountTokensBody, CountTokensResponse, ListModelsResponse,
  10    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  11};
  12use futures::{
  13    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  14    future::BoxFuture,
  15    stream::{self, BoxStream},
  16};
  17use google_ai::GoogleModelMode;
  18use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  19use http_client::http::{HeaderMap, HeaderValue};
  20use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  21use language_model::{
  22    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID,
  23    GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
  25    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  26    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  27    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
  28    OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
  29    RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
  30    ZED_CLOUD_PROVIDER_NAME,
  31};
  32use release_channel::AppVersion;
  33use schemars::JsonSchema;
  34use semver::Version;
  35use serde::{Deserialize, Serialize, de::DeserializeOwned};
  36use settings::SettingsStore;
  37pub use settings::ZedDotDevAvailableModel as AvailableModel;
  38pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  39use smol::io::{AsyncReadExt, BufReader};
  40use std::collections::VecDeque;
  41use std::pin::Pin;
  42use std::str::FromStr;
  43use std::sync::Arc;
  44use std::task::Poll;
  45use std::time::Duration;
  46use thiserror::Error;
  47use ui::{TintColor, prelude::*};
  48
  49use crate::provider::anthropic::{
  50    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  51};
  52use crate::provider::google::{GoogleEventMapper, into_google};
  53use crate::provider::open_ai::{
  54    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  55    into_open_ai_response,
  56};
  57use crate::provider::x_ai::count_xai_tokens;
  58
/// Provider ID reported by [`CloudLanguageModelProvider`] and every [`CloudLanguageModel`].
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Provider name reported by this provider and attached to completion errors it produces.
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
  61
/// Settings for the Zed cloud (zed.dev) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Available-model entries from `settings` (see [`AvailableModel`]).
    pub available_models: Vec<AvailableModel>,
}
/// Reasoning mode for a cloud model.
///
/// Serialized as an internally tagged object whose `type` field is the
/// lowercase variant name (`"default"` or `"thinking"`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Plain completion without extended reasoning.
    #[default]
    Default,
    /// Extended reasoning ("thinking") mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  76
  77impl From<ModelMode> for AnthropicModelMode {
  78    fn from(value: ModelMode) -> Self {
  79        match value {
  80            ModelMode::Default => AnthropicModelMode::Default,
  81            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  82        }
  83    }
  84}
  85
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Shared observable state (model list, connection status, token).
    state: Entity<State>,
    /// Background task that copies client status changes into `state`; held for
    /// the provider's lifetime (presumably dropping it cancels the task).
    _maintain_client_status: Task<()>,
}
  91
/// Observable state of the Zed cloud provider: authentication context, the
/// most recently fetched model list, and the subscriptions that keep it current.
pub struct State {
    client: Arc<Client>,
    /// Token used as the `Bearer` credential for cloud LLM requests.
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    /// Last observed client connection status.
    status: client::Status,
    /// All models from the most recent `/models` response.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` matching the server's `default_model`, if any.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` matching the server's `default_fast_model`, if any.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entries of `models` named in the server's `recommended_models`, in server order.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 105
 106impl State {
 107    fn new(
 108        client: Arc<Client>,
 109        user_store: Entity<UserStore>,
 110        status: client::Status,
 111        cx: &mut Context<Self>,
 112    ) -> Self {
 113        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 114        let llm_api_token = LlmApiToken::global(cx);
 115        Self {
 116            client: client.clone(),
 117            llm_api_token,
 118            user_store: user_store.clone(),
 119            status,
 120            models: Vec::new(),
 121            default_model: None,
 122            default_fast_model: None,
 123            recommended_models: Vec::new(),
 124            _user_store_subscription: cx.subscribe(
 125                &user_store,
 126                move |this, _user_store, event, cx| match event {
 127                    client::user::Event::PrivateUserInfoUpdated => {
 128                        let status = *client.status().borrow();
 129                        if status.is_signed_out() {
 130                            return;
 131                        }
 132
 133                        let client = this.client.clone();
 134                        let llm_api_token = this.llm_api_token.clone();
 135                        let organization_id = this
 136                            .user_store
 137                            .read(cx)
 138                            .current_organization()
 139                            .map(|organization| organization.id.clone());
 140                        cx.spawn(async move |this, cx| {
 141                            let response =
 142                                Self::fetch_models(client, llm_api_token, organization_id).await?;
 143                            this.update(cx, |this, cx| this.update_models(response, cx))
 144                        })
 145                        .detach_and_log_err(cx);
 146                    }
 147                    _ => {}
 148                },
 149            ),
 150            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 151                cx.notify();
 152            }),
 153            _llm_token_subscription: cx.subscribe(
 154                &refresh_llm_token_listener,
 155                move |this, _listener, _event, cx| {
 156                    let client = this.client.clone();
 157                    let llm_api_token = this.llm_api_token.clone();
 158                    let organization_id = this
 159                        .user_store
 160                        .read(cx)
 161                        .current_organization()
 162                        .map(|organization| organization.id.clone());
 163                    cx.spawn(async move |this, cx| {
 164                        let response =
 165                            Self::fetch_models(client, llm_api_token, organization_id).await?;
 166                        this.update(cx, |this, cx| {
 167                            this.update_models(response, cx);
 168                        })
 169                    })
 170                    .detach_and_log_err(cx);
 171                },
 172            ),
 173        }
 174    }
 175
 176    fn is_signed_out(&self, cx: &App) -> bool {
 177        self.user_store.read(cx).current_user().is_none()
 178    }
 179
 180    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 181        let client = self.client.clone();
 182        cx.spawn(async move |state, cx| {
 183            client.sign_in_with_optional_connect(true, cx).await?;
 184            state.update(cx, |_, cx| cx.notify())
 185        })
 186    }
 187
 188    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 189        let mut models = Vec::new();
 190
 191        for model in response.models {
 192            models.push(Arc::new(model.clone()));
 193        }
 194
 195        self.default_model = models
 196            .iter()
 197            .find(|model| {
 198                response
 199                    .default_model
 200                    .as_ref()
 201                    .is_some_and(|default_model_id| &model.id == default_model_id)
 202            })
 203            .cloned();
 204        self.default_fast_model = models
 205            .iter()
 206            .find(|model| {
 207                response
 208                    .default_fast_model
 209                    .as_ref()
 210                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 211            })
 212            .cloned();
 213        self.recommended_models = response
 214            .recommended_models
 215            .iter()
 216            .filter_map(|id| models.iter().find(|model| &model.id == id))
 217            .cloned()
 218            .collect();
 219        self.models = models;
 220        cx.notify();
 221    }
 222
 223    async fn fetch_models(
 224        client: Arc<Client>,
 225        llm_api_token: LlmApiToken,
 226        organization_id: Option<OrganizationId>,
 227    ) -> Result<ListModelsResponse> {
 228        let http_client = &client.http_client();
 229        let token = llm_api_token.acquire(&client, organization_id).await?;
 230
 231        let request = http_client::Request::builder()
 232            .method(Method::GET)
 233            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 234            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 235            .header("Authorization", format!("Bearer {token}"))
 236            .body(AsyncBody::empty())?;
 237        let mut response = http_client
 238            .send(request)
 239            .await
 240            .context("failed to send list models request")?;
 241
 242        if response.status().is_success() {
 243            let mut body = String::new();
 244            response.body_mut().read_to_string(&mut body).await?;
 245            Ok(serde_json::from_str(&body)?)
 246        } else {
 247            let mut body = String::new();
 248            response.body_mut().read_to_string(&mut body).await?;
 249            anyhow::bail!(
 250                "error listing models.\nStatus: {:?}\nBody: {body}",
 251                response.status(),
 252            );
 253        }
 254    }
 255}
 256
 257impl CloudLanguageModelProvider {
 258    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 259        let mut status_rx = client.status();
 260        let status = *status_rx.borrow();
 261
 262        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 263
 264        let state_ref = state.downgrade();
 265        let maintain_client_status = cx.spawn(async move |cx| {
 266            while let Some(status) = status_rx.next().await {
 267                if let Some(this) = state_ref.upgrade() {
 268                    _ = this.update(cx, |this, cx| {
 269                        if this.status != status {
 270                            this.status = status;
 271                            cx.notify();
 272                        }
 273                    });
 274                } else {
 275                    break;
 276                }
 277            }
 278        });
 279
 280        Self {
 281            client,
 282            state,
 283            _maintain_client_status: maintain_client_status,
 284        }
 285    }
 286
 287    fn create_language_model(
 288        &self,
 289        model: Arc<cloud_llm_client::LanguageModel>,
 290        llm_api_token: LlmApiToken,
 291        user_store: Entity<UserStore>,
 292    ) -> Arc<dyn LanguageModel> {
 293        Arc::new(CloudLanguageModel {
 294            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 295            model,
 296            llm_api_token,
 297            user_store,
 298            client: self.client.clone(),
 299            request_limiter: RateLimiter::new(4),
 300        })
 301    }
 302}
 303
 304impl LanguageModelProviderState for CloudLanguageModelProvider {
 305    type ObservableEntity = State;
 306
 307    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 308        Some(self.state.clone())
 309    }
 310}
 311
 312impl LanguageModelProvider for CloudLanguageModelProvider {
 313    fn id(&self) -> LanguageModelProviderId {
 314        PROVIDER_ID
 315    }
 316
 317    fn name(&self) -> LanguageModelProviderName {
 318        PROVIDER_NAME
 319    }
 320
 321    fn icon(&self) -> IconOrSvg {
 322        IconOrSvg::Icon(IconName::AiZed)
 323    }
 324
 325    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 326        let state = self.state.read(cx);
 327        let default_model = state.default_model.clone()?;
 328        let llm_api_token = state.llm_api_token.clone();
 329        let user_store = state.user_store.clone();
 330        Some(self.create_language_model(default_model, llm_api_token, user_store))
 331    }
 332
 333    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 334        let state = self.state.read(cx);
 335        let default_fast_model = state.default_fast_model.clone()?;
 336        let llm_api_token = state.llm_api_token.clone();
 337        let user_store = state.user_store.clone();
 338        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
 339    }
 340
 341    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 342        let state = self.state.read(cx);
 343        let llm_api_token = state.llm_api_token.clone();
 344        let user_store = state.user_store.clone();
 345        state
 346            .recommended_models
 347            .iter()
 348            .cloned()
 349            .map(|model| {
 350                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 351            })
 352            .collect()
 353    }
 354
 355    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 356        let state = self.state.read(cx);
 357        let llm_api_token = state.llm_api_token.clone();
 358        let user_store = state.user_store.clone();
 359        state
 360            .models
 361            .iter()
 362            .cloned()
 363            .map(|model| {
 364                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
 365            })
 366            .collect()
 367    }
 368
 369    fn is_authenticated(&self, cx: &App) -> bool {
 370        let state = self.state.read(cx);
 371        !state.is_signed_out(cx)
 372    }
 373
 374    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 375        Task::ready(Ok(()))
 376    }
 377
 378    fn configuration_view(
 379        &self,
 380        _target_agent: language_model::ConfigurationViewTargetAgent,
 381        _: &mut Window,
 382        cx: &mut App,
 383    ) -> AnyView {
 384        cx.new(|_| ConfigurationView::new(self.state.clone()))
 385            .into()
 386    }
 387
 388    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 389        Task::ready(Ok(()))
 390    }
 391}
 392
/// A single model served through Zed's cloud, as exposed to the editor.
pub struct CloudLanguageModel {
    /// ID derived from the server-provided model ID.
    id: LanguageModelId,
    /// Server-provided model description (provider, limits, capabilities).
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    /// Constructed with `RateLimiter::new(4)`; presumably caps concurrent
    /// in-flight requests for this model — TODO confirm `RateLimiter` semantics.
    request_limiter: RateLimiter,
}
 401
/// Successful response from the cloud `/completions` endpoint.
struct PerformLlmCompletionResponse {
    /// The raw HTTP response whose body carries the completion stream.
    response: Response<AsyncBody>,
    /// Whether the server sent `SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME`.
    includes_status_messages: bool,
}
 406
 407impl CloudLanguageModel {
 408    async fn perform_llm_completion(
 409        client: Arc<Client>,
 410        llm_api_token: LlmApiToken,
 411        organization_id: Option<OrganizationId>,
 412        app_version: Option<Version>,
 413        body: CompletionBody,
 414    ) -> Result<PerformLlmCompletionResponse> {
 415        let http_client = &client.http_client();
 416
 417        let mut token = llm_api_token
 418            .acquire(&client, organization_id.clone())
 419            .await?;
 420        let mut refreshed_token = false;
 421
 422        loop {
 423            let request = http_client::Request::builder()
 424                .method(Method::POST)
 425                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
 426                .when_some(app_version.as_ref(), |builder, app_version| {
 427                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 428                })
 429                .header("Content-Type", "application/json")
 430                .header("Authorization", format!("Bearer {token}"))
 431                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
 432                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
 433                .body(serde_json::to_string(&body)?.into())?;
 434
 435            let mut response = http_client.send(request).await?;
 436            let status = response.status();
 437            if status.is_success() {
 438                let includes_status_messages = response
 439                    .headers()
 440                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
 441                    .is_some();
 442
 443                return Ok(PerformLlmCompletionResponse {
 444                    response,
 445                    includes_status_messages,
 446                });
 447            }
 448
 449            if !refreshed_token && response.needs_llm_token_refresh() {
 450                token = llm_api_token
 451                    .refresh(&client, organization_id.clone())
 452                    .await?;
 453                refreshed_token = true;
 454                continue;
 455            }
 456
 457            if status == StatusCode::PAYMENT_REQUIRED {
 458                return Err(anyhow!(PaymentRequiredError));
 459            }
 460
 461            let mut body = String::new();
 462            let headers = response.headers().clone();
 463            response.body_mut().read_to_string(&mut body).await?;
 464            return Err(anyhow!(ApiError {
 465                status,
 466                body,
 467                headers
 468            }));
 469        }
 470    }
 471}
 472
/// A non-success HTTP response from Zed's cloud LLM endpoints.
///
/// Produced by the completion and Google token-counting requests in this file
/// and converted into a [`LanguageModelCompletionError`] through its `From` impl.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    /// HTTP status code of the failed response.
    status: StatusCode,
    /// Raw response body; may contain a JSON [`CloudApiError`].
    body: String,
    /// Response headers, captured before the body was read.
    // NOTE(review): not read within the visible code; presumably consumed elsewhere.
    headers: HeaderMap<HeaderValue>,
}
 480
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable error code; codes starting with `upstream_http_` are
    /// attributed to the upstream provider.
    code: String,
    /// Error message forwarded into the resulting completion error.
    message: String,
    /// Status reported by the upstream provider. Missing, `null`, or invalid
    /// status values all become `None`.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested retry delay, in seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
 501
 502fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 503where
 504    D: serde::Deserializer<'de>,
 505{
 506    let opt: Option<u16> = Option::deserialize(deserializer)?;
 507    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 508}
 509
 510impl From<ApiError> for LanguageModelCompletionError {
 511    fn from(error: ApiError) -> Self {
 512        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 513            if cloud_error.code.starts_with("upstream_http_") {
 514                let status = if let Some(status) = cloud_error.upstream_status {
 515                    status
 516                } else if cloud_error.code.ends_with("_error") {
 517                    error.status
 518                } else {
 519                    // If there's a status code in the code string (e.g. "upstream_http_429")
 520                    // then use that; otherwise, see if the JSON contains a status code.
 521                    cloud_error
 522                        .code
 523                        .strip_prefix("upstream_http_")
 524                        .and_then(|code_str| code_str.parse::<u16>().ok())
 525                        .and_then(|code| StatusCode::from_u16(code).ok())
 526                        .unwrap_or(error.status)
 527                };
 528
 529                return LanguageModelCompletionError::UpstreamProviderError {
 530                    message: cloud_error.message,
 531                    status,
 532                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 533                };
 534            }
 535
 536            return LanguageModelCompletionError::from_http_status(
 537                PROVIDER_NAME,
 538                error.status,
 539                cloud_error.message,
 540                None,
 541            );
 542        }
 543
 544        let retry_after = None;
 545        LanguageModelCompletionError::from_http_status(
 546            PROVIDER_NAME,
 547            error.status,
 548            error.body,
 549            retry_after,
 550        )
 551    }
 552}
 553
 554impl LanguageModel for CloudLanguageModel {
 555    fn id(&self) -> LanguageModelId {
 556        self.id.clone()
 557    }
 558
 559    fn name(&self) -> LanguageModelName {
 560        LanguageModelName::from(self.model.display_name.clone())
 561    }
 562
 563    fn provider_id(&self) -> LanguageModelProviderId {
 564        PROVIDER_ID
 565    }
 566
 567    fn provider_name(&self) -> LanguageModelProviderName {
 568        PROVIDER_NAME
 569    }
 570
 571    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 572        use cloud_llm_client::LanguageModelProvider::*;
 573        match self.model.provider {
 574            Anthropic => ANTHROPIC_PROVIDER_ID,
 575            OpenAi => OPEN_AI_PROVIDER_ID,
 576            Google => GOOGLE_PROVIDER_ID,
 577            XAi => X_AI_PROVIDER_ID,
 578        }
 579    }
 580
 581    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 582        use cloud_llm_client::LanguageModelProvider::*;
 583        match self.model.provider {
 584            Anthropic => ANTHROPIC_PROVIDER_NAME,
 585            OpenAi => OPEN_AI_PROVIDER_NAME,
 586            Google => GOOGLE_PROVIDER_NAME,
 587            XAi => X_AI_PROVIDER_NAME,
 588        }
 589    }
 590
 591    fn is_latest(&self) -> bool {
 592        self.model.is_latest
 593    }
 594
 595    fn supports_tools(&self) -> bool {
 596        self.model.supports_tools
 597    }
 598
 599    fn supports_images(&self) -> bool {
 600        self.model.supports_images
 601    }
 602
 603    fn supports_thinking(&self) -> bool {
 604        self.model.supports_thinking
 605    }
 606
 607    fn supports_fast_mode(&self) -> bool {
 608        self.model.supports_fast_mode
 609    }
 610
 611    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 612        self.model
 613            .supported_effort_levels
 614            .iter()
 615            .map(|effort_level| LanguageModelEffortLevel {
 616                name: effort_level.name.clone().into(),
 617                value: effort_level.value.clone().into(),
 618                is_default: effort_level.is_default.unwrap_or(false),
 619            })
 620            .collect()
 621    }
 622
 623    fn supports_streaming_tools(&self) -> bool {
 624        self.model.supports_streaming_tools
 625    }
 626
 627    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 628        match choice {
 629            LanguageModelToolChoice::Auto
 630            | LanguageModelToolChoice::Any
 631            | LanguageModelToolChoice::None => true,
 632        }
 633    }
 634
 635    fn supports_split_token_display(&self) -> bool {
 636        use cloud_llm_client::LanguageModelProvider::*;
 637        matches!(self.model.provider, OpenAi | XAi)
 638    }
 639
 640    fn telemetry_id(&self) -> String {
 641        format!("zed.dev/{}", self.model.id)
 642    }
 643
 644    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 645        match self.model.provider {
 646            cloud_llm_client::LanguageModelProvider::Anthropic
 647            | cloud_llm_client::LanguageModelProvider::OpenAi => {
 648                LanguageModelToolSchemaFormat::JsonSchema
 649            }
 650            cloud_llm_client::LanguageModelProvider::Google
 651            | cloud_llm_client::LanguageModelProvider::XAi => {
 652                LanguageModelToolSchemaFormat::JsonSchemaSubset
 653            }
 654        }
 655    }
 656
 657    fn max_token_count(&self) -> u64 {
 658        self.model.max_token_count as u64
 659    }
 660
 661    fn max_output_tokens(&self) -> Option<u64> {
 662        Some(self.model.max_output_tokens as u64)
 663    }
 664
 665    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 666        match &self.model.provider {
 667            cloud_llm_client::LanguageModelProvider::Anthropic => {
 668                Some(LanguageModelCacheConfiguration {
 669                    min_total_token: 2_048,
 670                    should_speculate: true,
 671                    max_cache_anchors: 4,
 672                })
 673            }
 674            cloud_llm_client::LanguageModelProvider::OpenAi
 675            | cloud_llm_client::LanguageModelProvider::XAi
 676            | cloud_llm_client::LanguageModelProvider::Google => None,
 677        }
 678    }
 679
 680    fn count_tokens(
 681        &self,
 682        request: LanguageModelRequest,
 683        cx: &App,
 684    ) -> BoxFuture<'static, Result<u64>> {
 685        match self.model.provider {
 686            cloud_llm_client::LanguageModelProvider::Anthropic => cx
 687                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
 688                .boxed(),
 689            cloud_llm_client::LanguageModelProvider::OpenAi => {
 690                let model = match open_ai::Model::from_id(&self.model.id.0) {
 691                    Ok(model) => model,
 692                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 693                };
 694                count_open_ai_tokens(request, model, cx)
 695            }
 696            cloud_llm_client::LanguageModelProvider::XAi => {
 697                let model = match x_ai::Model::from_id(&self.model.id.0) {
 698                    Ok(model) => model,
 699                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 700                };
 701                count_xai_tokens(request, model, cx)
 702            }
 703            cloud_llm_client::LanguageModelProvider::Google => {
 704                let client = self.client.clone();
 705                let llm_api_token = self.llm_api_token.clone();
 706                let organization_id = self
 707                    .user_store
 708                    .read(cx)
 709                    .current_organization()
 710                    .map(|organization| organization.id.clone());
 711                let model_id = self.model.id.to_string();
 712                let generate_content_request =
 713                    into_google(request, model_id.clone(), GoogleModelMode::Default);
 714                async move {
 715                    let http_client = &client.http_client();
 716                    let token = llm_api_token.acquire(&client, organization_id).await?;
 717
 718                    let request_body = CountTokensBody {
 719                        provider: cloud_llm_client::LanguageModelProvider::Google,
 720                        model: model_id,
 721                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 722                            generate_content_request,
 723                        })?,
 724                    };
 725                    let request = http_client::Request::builder()
 726                        .method(Method::POST)
 727                        .uri(
 728                            http_client
 729                                .build_zed_llm_url("/count_tokens", &[])?
 730                                .as_ref(),
 731                        )
 732                        .header("Content-Type", "application/json")
 733                        .header("Authorization", format!("Bearer {token}"))
 734                        .body(serde_json::to_string(&request_body)?.into())?;
 735                    let mut response = http_client.send(request).await?;
 736                    let status = response.status();
 737                    let headers = response.headers().clone();
 738                    let mut response_body = String::new();
 739                    response
 740                        .body_mut()
 741                        .read_to_string(&mut response_body)
 742                        .await?;
 743
 744                    if status.is_success() {
 745                        let response_body: CountTokensResponse =
 746                            serde_json::from_str(&response_body)?;
 747
 748                        Ok(response_body.tokens as u64)
 749                    } else {
 750                        Err(anyhow!(ApiError {
 751                            status,
 752                            body: response_body,
 753                            headers
 754                        }))
 755                    }
 756                }
 757                .boxed()
 758            }
 759        }
 760    }
 761
 762    fn stream_completion(
 763        &self,
 764        request: LanguageModelRequest,
 765        cx: &AsyncApp,
 766    ) -> BoxFuture<
 767        'static,
 768        Result<
 769            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 770            LanguageModelCompletionError,
 771        >,
 772    > {
 773        let thread_id = request.thread_id.clone();
 774        let prompt_id = request.prompt_id.clone();
 775        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
 776        let user_store = self.user_store.clone();
 777        let organization_id = cx.update(|cx| {
 778            user_store
 779                .read(cx)
 780                .current_organization()
 781                .map(|organization| organization.id.clone())
 782        });
 783        let thinking_allowed = request.thinking_allowed;
 784        let enable_thinking = thinking_allowed && self.model.supports_thinking;
 785        let provider_name = provider_name(&self.model.provider);
 786        match self.model.provider {
 787            cloud_llm_client::LanguageModelProvider::Anthropic => {
 788                let effort = request
 789                    .thinking_effort
 790                    .as_ref()
 791                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());
 792
 793                let mut request = into_anthropic(
 794                    request,
 795                    self.model.id.to_string(),
 796                    1.0,
 797                    self.model.max_output_tokens as u64,
 798                    if enable_thinking {
 799                        AnthropicModelMode::Thinking {
 800                            budget_tokens: Some(4_096),
 801                        }
 802                    } else {
 803                        AnthropicModelMode::Default
 804                    },
 805                );
 806
 807                if enable_thinking && effort.is_some() {
 808                    request.thinking = Some(anthropic::Thinking::Adaptive);
 809                    request.output_config = Some(anthropic::OutputConfig { effort });
 810                }
 811
 812                let client = self.client.clone();
 813                let llm_api_token = self.llm_api_token.clone();
 814                let organization_id = organization_id.clone();
 815                let future = self.request_limiter.stream(async move {
 816                    let PerformLlmCompletionResponse {
 817                        response,
 818                        includes_status_messages,
 819                    } = Self::perform_llm_completion(
 820                        client.clone(),
 821                        llm_api_token,
 822                        organization_id,
 823                        app_version,
 824                        CompletionBody {
 825                            thread_id,
 826                            prompt_id,
 827                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 828                            model: request.model.clone(),
 829                            provider_request: serde_json::to_value(&request)
 830                                .map_err(|e| anyhow!(e))?,
 831                        },
 832                    )
 833                    .await
 834                    .map_err(|err| match err.downcast::<ApiError>() {
 835                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 836                        Err(err) => anyhow!(err),
 837                    })?;
 838
 839                    let mut mapper = AnthropicEventMapper::new();
 840                    Ok(map_cloud_completion_events(
 841                        Box::pin(response_lines(response, includes_status_messages)),
 842                        &provider_name,
 843                        move |event| mapper.map_event(event),
 844                    ))
 845                });
 846                async move { Ok(future.await?.boxed()) }.boxed()
 847            }
 848            cloud_llm_client::LanguageModelProvider::OpenAi => {
 849                let client = self.client.clone();
 850                let llm_api_token = self.llm_api_token.clone();
 851                let organization_id = organization_id.clone();
 852                let effort = request
 853                    .thinking_effort
 854                    .as_ref()
 855                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
 856
 857                let mut request = into_open_ai_response(
 858                    request,
 859                    &self.model.id.0,
 860                    self.model.supports_parallel_tool_calls,
 861                    true,
 862                    None,
 863                    None,
 864                );
 865
 866                if enable_thinking && let Some(effort) = effort {
 867                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
 868                        effort,
 869                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
 870                    });
 871                }
 872
 873                let future = self.request_limiter.stream(async move {
 874                    let PerformLlmCompletionResponse {
 875                        response,
 876                        includes_status_messages,
 877                    } = Self::perform_llm_completion(
 878                        client.clone(),
 879                        llm_api_token,
 880                        organization_id,
 881                        app_version,
 882                        CompletionBody {
 883                            thread_id,
 884                            prompt_id,
 885                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 886                            model: request.model.clone(),
 887                            provider_request: serde_json::to_value(&request)
 888                                .map_err(|e| anyhow!(e))?,
 889                        },
 890                    )
 891                    .await?;
 892
 893                    let mut mapper = OpenAiResponseEventMapper::new();
 894                    Ok(map_cloud_completion_events(
 895                        Box::pin(response_lines(response, includes_status_messages)),
 896                        &provider_name,
 897                        move |event| mapper.map_event(event),
 898                    ))
 899                });
 900                async move { Ok(future.await?.boxed()) }.boxed()
 901            }
 902            cloud_llm_client::LanguageModelProvider::XAi => {
 903                let client = self.client.clone();
 904                let request = into_open_ai(
 905                    request,
 906                    &self.model.id.0,
 907                    self.model.supports_parallel_tool_calls,
 908                    false,
 909                    None,
 910                    None,
 911                );
 912                let llm_api_token = self.llm_api_token.clone();
 913                let organization_id = organization_id.clone();
 914                let future = self.request_limiter.stream(async move {
 915                    let PerformLlmCompletionResponse {
 916                        response,
 917                        includes_status_messages,
 918                    } = Self::perform_llm_completion(
 919                        client.clone(),
 920                        llm_api_token,
 921                        organization_id,
 922                        app_version,
 923                        CompletionBody {
 924                            thread_id,
 925                            prompt_id,
 926                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 927                            model: request.model.clone(),
 928                            provider_request: serde_json::to_value(&request)
 929                                .map_err(|e| anyhow!(e))?,
 930                        },
 931                    )
 932                    .await?;
 933
 934                    let mut mapper = OpenAiEventMapper::new();
 935                    Ok(map_cloud_completion_events(
 936                        Box::pin(response_lines(response, includes_status_messages)),
 937                        &provider_name,
 938                        move |event| mapper.map_event(event),
 939                    ))
 940                });
 941                async move { Ok(future.await?.boxed()) }.boxed()
 942            }
 943            cloud_llm_client::LanguageModelProvider::Google => {
 944                let client = self.client.clone();
 945                let request =
 946                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 947                let llm_api_token = self.llm_api_token.clone();
 948                let future = self.request_limiter.stream(async move {
 949                    let PerformLlmCompletionResponse {
 950                        response,
 951                        includes_status_messages,
 952                    } = Self::perform_llm_completion(
 953                        client.clone(),
 954                        llm_api_token,
 955                        organization_id,
 956                        app_version,
 957                        CompletionBody {
 958                            thread_id,
 959                            prompt_id,
 960                            provider: cloud_llm_client::LanguageModelProvider::Google,
 961                            model: request.model.model_id.clone(),
 962                            provider_request: serde_json::to_value(&request)
 963                                .map_err(|e| anyhow!(e))?,
 964                        },
 965                    )
 966                    .await?;
 967
 968                    let mut mapper = GoogleEventMapper::new();
 969                    Ok(map_cloud_completion_events(
 970                        Box::pin(response_lines(response, includes_status_messages)),
 971                        &provider_name,
 972                        move |event| mapper.map_event(event),
 973                    ))
 974                });
 975                async move { Ok(future.await?.boxed()) }.boxed()
 976            }
 977        }
 978    }
 979}
 980
 981fn map_cloud_completion_events<T, F>(
 982    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
 983    provider: &LanguageModelProviderName,
 984    mut map_callback: F,
 985) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 986where
 987    T: DeserializeOwned + 'static,
 988    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 989        + Send
 990        + 'static,
 991{
 992    let provider = provider.clone();
 993    let mut stream = stream.fuse();
 994
 995    let mut saw_stream_ended = false;
 996
 997    let mut done = false;
 998    let mut pending = VecDeque::new();
 999
1000    stream::poll_fn(move |cx| {
1001        loop {
1002            if let Some(item) = pending.pop_front() {
1003                return Poll::Ready(Some(item));
1004            }
1005
1006            if done {
1007                return Poll::Ready(None);
1008            }
1009
1010            match stream.poll_next_unpin(cx) {
1011                Poll::Ready(Some(event)) => {
1012                    let items = match event {
1013                        Err(error) => {
1014                            vec![Err(LanguageModelCompletionError::from(error))]
1015                        }
1016                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
1017                            saw_stream_ended = true;
1018                            vec![]
1019                        }
1020                        Ok(CompletionEvent::Status(status)) => {
1021                            LanguageModelCompletionEvent::from_completion_request_status(
1022                                status,
1023                                provider.clone(),
1024                            )
1025                            .transpose()
1026                            .map(|event| vec![event])
1027                            .unwrap_or_default()
1028                        }
1029                        Ok(CompletionEvent::Event(event)) => map_callback(event),
1030                    };
1031                    pending.extend(items);
1032                }
1033                Poll::Ready(None) => {
1034                    done = true;
1035
1036                    if !saw_stream_ended {
1037                        return Poll::Ready(Some(Err(
1038                            LanguageModelCompletionError::StreamEndedUnexpectedly {
1039                                provider: provider.clone(),
1040                            },
1041                        )));
1042                    }
1043                }
1044                Poll::Pending => return Poll::Pending,
1045            }
1046        }
1047    })
1048    .boxed()
1049}
1050
1051fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1052    match provider {
1053        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
1054        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
1055        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
1056        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
1057    }
1058}
1059
1060fn response_lines<T: DeserializeOwned>(
1061    response: Response<AsyncBody>,
1062    includes_status_messages: bool,
1063) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1064    futures::stream::try_unfold(
1065        (String::new(), BufReader::new(response.into_body())),
1066        move |(mut line, mut body)| async move {
1067            match body.read_line(&mut line).await {
1068                Ok(0) => Ok(None),
1069                Ok(_) => {
1070                    let event = if includes_status_messages {
1071                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1072                    } else {
1073                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1074                    };
1075
1076                    line.clear();
1077                    Ok(Some((event, (line, body))))
1078                }
1079                Err(e) => Err(e.into()),
1080            }
1081        },
1082    )
1083}
1084
/// Presentational content of the Zed AI provider configuration panel.
///
/// Built from live user-store data by `ConfigurationView::render` and with fixed
/// inputs by the component preview.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when `false` only the sign-in prompt is rendered.
    is_connected: bool,
    /// The user's current plan, or `None` when no plan is available.
    plan: Option<Plan>,
    /// Whether the trial call-to-action may be offered (`ConfigurationView` sets this to
    /// `trial_started_at().is_none()`).
    eligible_for_trial: bool,
    /// When `true`, a `YoungAccountBanner` and an upgrade button replace the
    /// subscription summary.
    account_too_young: bool,
    /// Invoked when the "Sign In to use Zed AI" button is clicked.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1093
1094impl RenderOnce for ZedAiConfiguration {
1095    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1096        let (subscription_text, has_paid_plan) = match self.plan {
1097            Some(Plan::ZedPro) => (
1098                "You have access to Zed's hosted models through your Pro subscription.",
1099                true,
1100            ),
1101            Some(Plan::ZedProTrial) => (
1102                "You have access to Zed's hosted models through your Pro trial.",
1103                false,
1104            ),
1105            Some(Plan::ZedStudent) => (
1106                "You have access to Zed's hosted models through your Student subscription.",
1107                true,
1108            ),
1109            Some(Plan::ZedBusiness) => (
1110                "You have access to Zed's hosted models through your Organization.",
1111                true,
1112            ),
1113            Some(Plan::ZedFree) | None => (
1114                if self.eligible_for_trial {
1115                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1116                } else {
1117                    "Subscribe for access to Zed's hosted models."
1118                },
1119                false,
1120            ),
1121        };
1122
1123        let manage_subscription_buttons = if has_paid_plan {
1124            Button::new("manage_settings", "Manage Subscription")
1125                .full_width()
1126                .label_size(LabelSize::Small)
1127                .style(ButtonStyle::Tinted(TintColor::Accent))
1128                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1129                .into_any_element()
1130        } else if self.plan.is_none() || self.eligible_for_trial {
1131            Button::new("start_trial", "Start 14-day Free Pro Trial")
1132                .full_width()
1133                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1134                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1135                .into_any_element()
1136        } else {
1137            Button::new("upgrade", "Upgrade to Pro")
1138                .full_width()
1139                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1140                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1141                .into_any_element()
1142        };
1143
1144        if !self.is_connected {
1145            return v_flex()
1146                .gap_2()
1147                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1148                .child(
1149                    Button::new("sign_in", "Sign In to use Zed AI")
1150                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
1151                        .full_width()
1152                        .on_click({
1153                            let callback = self.sign_in_callback.clone();
1154                            move |_, window, cx| (callback)(window, cx)
1155                        }),
1156                );
1157        }
1158
1159        v_flex().gap_2().w_full().map(|this| {
1160            if self.account_too_young {
1161                this.child(YoungAccountBanner).child(
1162                    Button::new("upgrade", "Upgrade to Pro")
1163                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1164                        .full_width()
1165                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1166                )
1167            } else {
1168                this.text_sm()
1169                    .child(subscription_text)
1170                    .child(manage_subscription_buttons)
1171            }
1172        })
1173    }
1174}
1175
/// Provider configuration view for Zed's hosted models.
///
/// Renders a `ZedAiConfiguration` from the provider `State` and its user store.
struct ConfigurationView {
    /// Provider state; its `user_store` supplies plan, trial, and account-age data.
    state: Entity<State>,
    /// Sign-in handler passed to `ZedAiConfiguration`; built in `ConfigurationView::new`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1180
1181impl ConfigurationView {
1182    fn new(state: Entity<State>) -> Self {
1183        let sign_in_callback = Arc::new({
1184            let state = state.clone();
1185            move |_window: &mut Window, cx: &mut App| {
1186                state.update(cx, |state, cx| {
1187                    state.authenticate(cx).detach_and_log_err(cx);
1188                });
1189            }
1190        });
1191
1192        Self {
1193            state,
1194            sign_in_callback,
1195        }
1196    }
1197}
1198
1199impl Render for ConfigurationView {
1200    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1201        let state = self.state.read(cx);
1202        let user_store = state.user_store.read(cx);
1203
1204        ZedAiConfiguration {
1205            is_connected: !state.is_signed_out(cx),
1206            plan: user_store.plan(),
1207            eligible_for_trial: user_store.trial_started_at().is_none(),
1208            account_too_young: user_store.account_too_young(),
1209            sign_in_callback: self.sign_in_callback.clone(),
1210        }
1211    }
1212}
1213
1214impl Component for ZedAiConfiguration {
1215    fn name() -> &'static str {
1216        "AI Configuration Content"
1217    }
1218
1219    fn sort_name() -> &'static str {
1220        "AI Configuration Content"
1221    }
1222
1223    fn scope() -> ComponentScope {
1224        ComponentScope::Onboarding
1225    }
1226
1227    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1228        fn configuration(
1229            is_connected: bool,
1230            plan: Option<Plan>,
1231            eligible_for_trial: bool,
1232            account_too_young: bool,
1233        ) -> AnyElement {
1234            ZedAiConfiguration {
1235                is_connected,
1236                plan,
1237                eligible_for_trial,
1238                account_too_young,
1239                sign_in_callback: Arc::new(|_, _| {}),
1240            }
1241            .into_any_element()
1242        }
1243
1244        Some(
1245            v_flex()
1246                .p_4()
1247                .gap_4()
1248                .children(vec![
1249                    single_example("Not connected", configuration(false, None, false, false)),
1250                    single_example(
1251                        "Accept Terms of Service",
1252                        configuration(true, None, true, false),
1253                    ),
1254                    single_example(
1255                        "No Plan - Not eligible for trial",
1256                        configuration(true, None, false, false),
1257                    ),
1258                    single_example(
1259                        "No Plan - Eligible for trial",
1260                        configuration(true, None, true, false),
1261                    ),
1262                    single_example(
1263                        "Free Plan",
1264                        configuration(true, Some(Plan::ZedFree), true, false),
1265                    ),
1266                    single_example(
1267                        "Zed Pro Trial Plan",
1268                        configuration(true, Some(Plan::ZedProTrial), true, false),
1269                    ),
1270                    single_example(
1271                        "Zed Pro Plan",
1272                        configuration(true, Some(Plan::ZedPro), true, false),
1273                    ),
1274                ])
1275                .into_any_element(),
1276        )
1277    }
1278}
1279
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Converts an `ApiError` with HTTP 500, no headers, and the given body into a
    /// `LanguageModelCompletionError` via its `From` impl.
    fn convert_internal_server_error(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An `upstream_http_error` body is surfaced as `UpstreamProviderError` carrying
        // the upstream message, whatever the upstream status is (503, 500, or 429).
        let cases = [
            (
                503,
                r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
                "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
            ),
            (
                500,
                r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
                "Received an error from the OpenAI API: internal server error",
            ),
            (
                429,
                r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
                "Received an error from the Google API: rate limit exceeded",
            ),
        ];

        for (upstream_status, body, expected_message) in cases {
            match convert_internal_server_error(body) {
                LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                    assert_eq!(message, expected_message);
                }
                other => panic!(
                    "Expected UpstreamProviderError for upstream {upstream_status}, got: {other:?}"
                ),
            }
        }
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_429() {
        // The `upstream_http_429` code becomes a 429 `UpstreamProviderError` that keeps
        // the `retry_after` hint.
        let body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        match convert_internal_server_error(body) {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!("Expected UpstreamProviderError for upstream_http_429, got: {other:?}"),
        }
    }

    #[test]
    fn test_api_error_conversion_with_regular_500() {
        // A 500 without an `upstream_http_error` payload stays an
        // `ApiInternalServerError` attributed to Zed, with the body as its message.
        match convert_internal_server_error("Regular internal server error") {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!("Expected ApiInternalServerError for regular 500, got: {other:?}"),
        }
    }

    #[test]
    fn test_api_error_conversion_with_invalid_json() {
        // An unparseable body falls back to the regular 500 handling.
        match convert_internal_server_error("Not JSON at all") {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!("Expected ApiInternalServerError for invalid JSON, got: {other:?}"),
        }
    }
}