cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
use cloud_api_types::Plan;
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use feature_flags::{CloudThinkingEffortFeatureFlag, FeatureFlagAppExt as _};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

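/// Shared state for the Zed cloud provider: the connected client, the LLM API
/// token, and the model lists most recently fetched from the cloud.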
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

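    /// Rebuilds the model lists from a `ListModelsResponse`. When the
    /// thinking-effort feature flag is disabled, thinking-capable models are
    /// also inserted as separate "-thinking" variants.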
    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let is_thinking_effort_enabled = cx.has_flag::<CloudThinkingEffortFeatureFlag>();

        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            if !is_thinking_effort_enabled {
                // Right now we represent thinking variants of models as separate models on the client,
                // so we need to insert variants for any model that supports thinking.
                if model.supports_thinking {
                    models.push(Arc::new(cloud_llm_client::LanguageModel {
                        id: cloud_llm_client::LanguageModelId(
                            format!("{}-thinking", model.id).into(),
                        ),
                        display_name: format!("{} Thinking", model.display_name),
                        ..model
                    }));
                }
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

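    /// Fetches the available models from the cloud `/models` endpoint,
    /// authenticating with the current LLM API token.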
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

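/// A language model served through Zed's cloud API, which proxies requests to
/// an upstream provider (Anthropic, OpenAI, Google, or xAI).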
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
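    /// Sends a completion request to the cloud `/completions` endpoint.
    /// If the LLM token is rejected as expired, it is refreshed once and the
    /// request retried; other failures surface as `PaymentRequiredError` or
    /// `ApiError`.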
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

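/// Deserializes an optional numeric status code, yielding `None` for values
/// that are not valid HTTP status codes rather than failing deserialization.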
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

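// Maps `ApiError` bodies onto `LanguageModelCompletionError`, promoting
// `upstream_http_*` codes (e.g. "upstream_http_429") to `UpstreamProviderError`
// with the upstream status and optional retry delay.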
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code embedded in the code string (e.g.
                    // "upstream_http_429"), use that; otherwise fall back to the
                    // status of the response itself.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

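    /// Streams a completion by converting the request into the upstream
    /// provider's wire format, posting it to the cloud completions endpoint,
    /// and mapping the streamed events back through the provider-specific
    /// event mapper.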
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        let is_thinking_effort_enabled = cx
            .update(|cx| cx.has_flag::<CloudThinkingEffortFeatureFlag>())
            .unwrap_or(false);
        let enable_thinking = if is_thinking_effort_enabled {
            thinking_allowed && self.model.supports_thinking
        } else {
            thinking_allowed && self.model.id.0.ends_with("-thinking")
        };
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig { effort });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

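/// Adapts a raw cloud completion stream into `LanguageModelCompletionEvent`s:
/// status messages are handled directly, while provider events pass through
/// `map_callback`, which may expand a single raw event into several events.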
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
    // let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                            // saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
                    //
                    // if !saw_stream_ended {
                    //     return Poll::Ready(Some(Err(
                    //         LanguageModelCompletionError::StreamEndedUnexpectedly {
                    //             provider: provider.clone(),
                    //         },
                    //     )));
                    // }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

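/// Converts a streaming response body into a stream of completion events,
/// decoding one JSON event per line. When the server advertised status-message
/// support, each line is decoded as a full `CompletionEvent`; otherwise it is
/// decoded as a bare provider event.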
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

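/// The configuration UI for Zed AI, offering sign-in, trial, and subscription
/// management controls appropriate to the user's current plan.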
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::ZedFree), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 error without an upstream_http_error body should become
        // ApiInternalServerError attributed to Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}