cloud.rs

   1use ai_onboarding::YoungAccountBanner;
   2use anthropic::AnthropicModelMode;
   3use anyhow::{Context as _, Result, anyhow};
   4use chrono::{DateTime, Utc};
   5use client::{Client, ModelRequestUsage, UserStore, zed_urls};
   6use cloud_llm_client::{
   7    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
   8    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
   9    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
  10    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
  11    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
  12    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  13};
  14use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
  15use futures::{
  16    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
  17};
  18use google_ai::GoogleModelMode;
  19use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
  20use http_client::http::{HeaderMap, HeaderValue};
  21use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
  22use language_model::{
  23    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  24    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  25    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  26    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  27    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
  28    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  29};
  30use release_channel::AppVersion;
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use settings::SettingsStore;
  35pub use settings::ZedDotDevAvailableModel as AvailableModel;
  36pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
  37use smol::io::{AsyncReadExt, BufReader};
  38use std::pin::Pin;
  39use std::str::FromStr as _;
  40use std::sync::Arc;
  41use std::time::Duration;
  42use thiserror::Error;
  43use ui::{TintColor, prelude::*};
  44use util::{ResultExt as _, maybe};
  45
  46use crate::provider::anthropic::{
  47    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  48};
  49use crate::provider::google::{GoogleEventMapper, into_google};
  50use crate::provider::open_ai::{
  51    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  52    into_open_ai_response,
  53};
  54use crate::provider::x_ai::count_xai_tokens;
  55
/// Provider id used for every model exposed by this module (alias of
/// `language_model::ZED_CLOUD_PROVIDER_ID`).
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Provider display name used for every model exposed by this module (alias of
/// `language_model::ZED_CLOUD_PROVIDER_NAME`).
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  58
/// Settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models declared in settings (`settings::ZedDotDevAvailableModel`); not read by the
    /// code visible in this module — NOTE(review): confirm where it is consumed.
    pub available_models: Vec<AvailableModel>,
}
// Execution mode for a cloud model. With `tag = "type"` and lowercase variant names this
// serializes as `{"type": "default"}` or `{"type": "thinking", "budget_tokens": 4096}`.
// (Plain `//` comments on purpose: `///` text would leak into the derived JSON schema.)
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  73
  74impl From<ModelMode> for AnthropicModelMode {
  75    fn from(value: ModelMode) -> Self {
  76        match value {
  77            ModelMode::Default => AnthropicModelMode::Default,
  78            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  79        }
  80    }
  81}
  82
/// [`LanguageModelProvider`] for models served through Zed's LLM endpoints.
pub struct CloudLanguageModelProvider {
    /// Client used for sign-in and HTTP; cloned into every model this provider creates.
    client: Arc<Client>,
    /// Observable provider state (model list, connection status, LLM token).
    state: Entity<State>,
    /// Task spawned in `new` that copies `client.status()` changes into `state`.
    _maintain_client_status: Task<()>,
}
  88
/// Observable state backing [`CloudLanguageModelProvider`].
///
/// Holds the model list most recently returned by the `/models` endpoint; observers are
/// notified when it changes, when the connection status changes, or when settings change.
pub struct State {
    client: Arc<Client>,
    /// Token for the LLM endpoints; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_api_token: LlmApiToken,
    /// Consulted to decide whether a user is signed in.
    user_store: Entity<UserStore>,
    /// Last client connection status seen by the provider.
    status: client::Status,
    /// Every selectable model, including synthesized `-thinking` variants.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` whose id matches the server's `default_model`, if any.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` whose id matches the server's `default_fast_model`, if any.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entries of `models` named in the server's `recommended_models`, in server order.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Initial model fetch; held here so the task is not dropped.
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
 102
 103impl State {
 104    fn new(
 105        client: Arc<Client>,
 106        user_store: Entity<UserStore>,
 107        status: client::Status,
 108        cx: &mut Context<Self>,
 109    ) -> Self {
 110        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 111        let mut current_user = user_store.read(cx).watch_current_user();
 112        Self {
 113            client: client.clone(),
 114            llm_api_token: LlmApiToken::default(),
 115            user_store,
 116            status,
 117            models: Vec::new(),
 118            default_model: None,
 119            default_fast_model: None,
 120            recommended_models: Vec::new(),
 121            _fetch_models_task: cx.spawn(async move |this, cx| {
 122                maybe!(async move {
 123                    let (client, llm_api_token) = this
 124                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
 125
 126                    while current_user.borrow().is_none() {
 127                        current_user.next().await;
 128                    }
 129
 130                    let response =
 131                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
 132                    this.update(cx, |this, cx| this.update_models(response, cx))?;
 133                    anyhow::Ok(())
 134                })
 135                .await
 136                .context("failed to fetch Zed models")
 137                .log_err();
 138            }),
 139            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 140                cx.notify();
 141            }),
 142            _llm_token_subscription: cx.subscribe(
 143                &refresh_llm_token_listener,
 144                move |this, _listener, _event, cx| {
 145                    let client = this.client.clone();
 146                    let llm_api_token = this.llm_api_token.clone();
 147                    cx.spawn(async move |this, cx| {
 148                        llm_api_token.refresh(&client).await?;
 149                        let response = Self::fetch_models(client, llm_api_token).await?;
 150                        this.update(cx, |this, cx| {
 151                            this.update_models(response, cx);
 152                        })
 153                    })
 154                    .detach_and_log_err(cx);
 155                },
 156            ),
 157        }
 158    }
 159
 160    fn is_signed_out(&self, cx: &App) -> bool {
 161        self.user_store.read(cx).current_user().is_none()
 162    }
 163
 164    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 165        let client = self.client.clone();
 166        cx.spawn(async move |state, cx| {
 167            client.sign_in_with_optional_connect(true, cx).await?;
 168            state.update(cx, |_, cx| cx.notify())
 169        })
 170    }
 171    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
 172        let mut models = Vec::new();
 173
 174        for model in response.models {
 175            models.push(Arc::new(model.clone()));
 176
 177            // Right now we represent thinking variants of models as separate models on the client,
 178            // so we need to insert variants for any model that supports thinking.
 179            if model.supports_thinking {
 180                models.push(Arc::new(cloud_llm_client::LanguageModel {
 181                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
 182                    display_name: format!("{} Thinking", model.display_name),
 183                    ..model
 184                }));
 185            }
 186        }
 187
 188        self.default_model = models
 189            .iter()
 190            .find(|model| {
 191                response
 192                    .default_model
 193                    .as_ref()
 194                    .is_some_and(|default_model_id| &model.id == default_model_id)
 195            })
 196            .cloned();
 197        self.default_fast_model = models
 198            .iter()
 199            .find(|model| {
 200                response
 201                    .default_fast_model
 202                    .as_ref()
 203                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 204            })
 205            .cloned();
 206        self.recommended_models = response
 207            .recommended_models
 208            .iter()
 209            .filter_map(|id| models.iter().find(|model| &model.id == id))
 210            .cloned()
 211            .collect();
 212        self.models = models;
 213        cx.notify();
 214    }
 215
 216    async fn fetch_models(
 217        client: Arc<Client>,
 218        llm_api_token: LlmApiToken,
 219    ) -> Result<ListModelsResponse> {
 220        let http_client = &client.http_client();
 221        let token = llm_api_token.acquire(&client).await?;
 222
 223        let request = http_client::Request::builder()
 224            .method(Method::GET)
 225            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 226            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 227            .header("Authorization", format!("Bearer {token}"))
 228            .body(AsyncBody::empty())?;
 229        let mut response = http_client
 230            .send(request)
 231            .await
 232            .context("failed to send list models request")?;
 233
 234        if response.status().is_success() {
 235            let mut body = String::new();
 236            response.body_mut().read_to_string(&mut body).await?;
 237            Ok(serde_json::from_str(&body)?)
 238        } else {
 239            let mut body = String::new();
 240            response.body_mut().read_to_string(&mut body).await?;
 241            anyhow::bail!(
 242                "error listing models.\nStatus: {:?}\nBody: {body}",
 243                response.status(),
 244            );
 245        }
 246    }
 247}
 248
 249impl CloudLanguageModelProvider {
 250    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 251        let mut status_rx = client.status();
 252        let status = *status_rx.borrow();
 253
 254        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 255
 256        let state_ref = state.downgrade();
 257        let maintain_client_status = cx.spawn(async move |cx| {
 258            while let Some(status) = status_rx.next().await {
 259                if let Some(this) = state_ref.upgrade() {
 260                    _ = this.update(cx, |this, cx| {
 261                        if this.status != status {
 262                            this.status = status;
 263                            cx.notify();
 264                        }
 265                    });
 266                } else {
 267                    break;
 268                }
 269            }
 270        });
 271
 272        Self {
 273            client,
 274            state,
 275            _maintain_client_status: maintain_client_status,
 276        }
 277    }
 278
 279    fn create_language_model(
 280        &self,
 281        model: Arc<cloud_llm_client::LanguageModel>,
 282        llm_api_token: LlmApiToken,
 283    ) -> Arc<dyn LanguageModel> {
 284        Arc::new(CloudLanguageModel {
 285            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 286            model,
 287            llm_api_token,
 288            client: self.client.clone(),
 289            request_limiter: RateLimiter::new(4),
 290        })
 291    }
 292}
 293
/// Exposes the provider's [`State`] entity so callers can observe it for changes.
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Always `Some`: the provider owns exactly one state entity.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 301
 302impl LanguageModelProvider for CloudLanguageModelProvider {
 303    fn id(&self) -> LanguageModelProviderId {
 304        PROVIDER_ID
 305    }
 306
 307    fn name(&self) -> LanguageModelProviderName {
 308        PROVIDER_NAME
 309    }
 310
 311    fn icon(&self) -> IconOrSvg {
 312        IconOrSvg::Icon(IconName::AiZed)
 313    }
 314
 315    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 316        let default_model = self.state.read(cx).default_model.clone()?;
 317        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 318        Some(self.create_language_model(default_model, llm_api_token))
 319    }
 320
 321    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 322        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 323        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 324        Some(self.create_language_model(default_fast_model, llm_api_token))
 325    }
 326
 327    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 328        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 329        self.state
 330            .read(cx)
 331            .recommended_models
 332            .iter()
 333            .cloned()
 334            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 335            .collect()
 336    }
 337
 338    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 339        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 340        self.state
 341            .read(cx)
 342            .models
 343            .iter()
 344            .cloned()
 345            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 346            .collect()
 347    }
 348
 349    fn is_authenticated(&self, cx: &App) -> bool {
 350        let state = self.state.read(cx);
 351        !state.is_signed_out(cx)
 352    }
 353
 354    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 355        Task::ready(Ok(()))
 356    }
 357
 358    fn configuration_view(
 359        &self,
 360        _target_agent: language_model::ConfigurationViewTargetAgent,
 361        _: &mut Window,
 362        cx: &mut App,
 363    ) -> AnyView {
 364        cx.new(|_| ConfigurationView::new(self.state.clone()))
 365            .into()
 366    }
 367
 368    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 369        Task::ready(Ok(()))
 370    }
 371}
 372
/// A single model reached through Zed's LLM endpoints.
pub struct CloudLanguageModel {
    /// Id exposed to the rest of the app; built from `model.id`.
    id: LanguageModelId,
    /// Model description from `/models` (possibly a synthesized `-thinking` variant).
    model: Arc<cloud_llm_client::LanguageModel>,
    /// Token shared with the provider state; acquired/refreshed per request.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Request limiter (constructed with `RateLimiter::new(4)`); presumably caps concurrent
    /// in-flight requests — TODO confirm `RateLimiter` semantics.
    request_limiter: RateLimiter,
}
 380
/// Successful result of `CloudLanguageModel::perform_llm_completion`.
struct PerformLlmCompletionResponse {
    /// The raw HTTP response; its body is consumed as the completion event stream.
    response: Response<AsyncBody>,
    /// Usage parsed from response headers. `None` when the server advertised status-message
    /// support or when the headers could not be parsed.
    usage: Option<ModelRequestUsage>,
    /// Whether the response carried the tool-use-limit-reached header.
    tool_use_limit_reached: bool,
    /// Whether the response carried the server-supports-status-messages header.
    includes_status_messages: bool,
}
 387
 388impl CloudLanguageModel {
 389    async fn perform_llm_completion(
 390        client: Arc<Client>,
 391        llm_api_token: LlmApiToken,
 392        app_version: Option<Version>,
 393        body: CompletionBody,
 394    ) -> Result<PerformLlmCompletionResponse> {
 395        let http_client = &client.http_client();
 396
 397        let mut token = llm_api_token.acquire(&client).await?;
 398        let mut refreshed_token = false;
 399
 400        loop {
 401            let request = http_client::Request::builder()
 402                .method(Method::POST)
 403                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
 404                .when_some(app_version.as_ref(), |builder, app_version| {
 405                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 406                })
 407                .header("Content-Type", "application/json")
 408                .header("Authorization", format!("Bearer {token}"))
 409                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
 410                .body(serde_json::to_string(&body)?.into())?;
 411
 412            let mut response = http_client.send(request).await?;
 413            let status = response.status();
 414            if status.is_success() {
 415                let includes_status_messages = response
 416                    .headers()
 417                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
 418                    .is_some();
 419
 420                let tool_use_limit_reached = response
 421                    .headers()
 422                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
 423                    .is_some();
 424
 425                let usage = if includes_status_messages {
 426                    None
 427                } else {
 428                    ModelRequestUsage::from_headers(response.headers()).ok()
 429                };
 430
 431                return Ok(PerformLlmCompletionResponse {
 432                    response,
 433                    usage,
 434                    includes_status_messages,
 435                    tool_use_limit_reached,
 436                });
 437            }
 438
 439            if !refreshed_token
 440                && response
 441                    .headers()
 442                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 443                    .is_some()
 444            {
 445                token = llm_api_token.refresh(&client).await?;
 446                refreshed_token = true;
 447                continue;
 448            }
 449
 450            if status == StatusCode::FORBIDDEN
 451                && response
 452                    .headers()
 453                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 454                    .is_some()
 455            {
 456                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
 457                    .headers()
 458                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 459                    .and_then(|resource| resource.to_str().ok())
 460                    && let Some(plan) = response
 461                        .headers()
 462                        .get(CURRENT_PLAN_HEADER_NAME)
 463                        .and_then(|plan| plan.to_str().ok())
 464                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
 465                        .map(Plan::V1)
 466                {
 467                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
 468                }
 469            } else if status == StatusCode::PAYMENT_REQUIRED {
 470                return Err(anyhow!(PaymentRequiredError));
 471            }
 472
 473            let mut body = String::new();
 474            let headers = response.headers().clone();
 475            response.body_mut().read_to_string(&mut body).await?;
 476            return Err(anyhow!(ApiError {
 477                status,
 478                body,
 479                headers
 480            }));
 481        }
 482    }
 483}
 484
/// Non-success response from a cloud LLM endpoint, keeping what is needed to classify it.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    /// HTTP status returned by the cloud endpoint.
    status: StatusCode,
    /// Raw response body; may be a JSON-encoded `CloudApiError`.
    body: String,
    /// Response headers, cloned before the body was read.
    headers: HeaderMap<HeaderValue>,
}
 492
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable code, e.g. `upstream_http_error` or `upstream_http_429`.
    code: String,
    /// Human-readable description of the failure.
    message: String,
    /// Status reported by the upstream provider. `None` when the field is absent, `null`, or
    /// not a valid HTTP status code.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested wait before retrying, in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
 513
 514fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 515where
 516    D: serde::Deserializer<'de>,
 517{
 518    let opt: Option<u16> = Option::deserialize(deserializer)?;
 519    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 520}
 521
 522impl From<ApiError> for LanguageModelCompletionError {
 523    fn from(error: ApiError) -> Self {
 524        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 525            if cloud_error.code.starts_with("upstream_http_") {
 526                let status = if let Some(status) = cloud_error.upstream_status {
 527                    status
 528                } else if cloud_error.code.ends_with("_error") {
 529                    error.status
 530                } else {
 531                    // If there's a status code in the code string (e.g. "upstream_http_429")
 532                    // then use that; otherwise, see if the JSON contains a status code.
 533                    cloud_error
 534                        .code
 535                        .strip_prefix("upstream_http_")
 536                        .and_then(|code_str| code_str.parse::<u16>().ok())
 537                        .and_then(|code| StatusCode::from_u16(code).ok())
 538                        .unwrap_or(error.status)
 539                };
 540
 541                return LanguageModelCompletionError::UpstreamProviderError {
 542                    message: cloud_error.message,
 543                    status,
 544                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 545                };
 546            }
 547
 548            return LanguageModelCompletionError::from_http_status(
 549                PROVIDER_NAME,
 550                error.status,
 551                cloud_error.message,
 552                None,
 553            );
 554        }
 555
 556        let retry_after = None;
 557        LanguageModelCompletionError::from_http_status(
 558            PROVIDER_NAME,
 559            error.status,
 560            error.body,
 561            retry_after,
 562        )
 563    }
 564}
 565
 566impl LanguageModel for CloudLanguageModel {
    /// This model's id, built from the cloud model id when the model was created.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
 570
 571    fn name(&self) -> LanguageModelName {
 572        LanguageModelName::from(self.model.display_name.clone())
 573    }
 574
    /// Always the Zed cloud provider id, regardless of upstream provider.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }
 578
    /// Always the Zed cloud provider name, regardless of upstream provider.
    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }
 582
 583    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 584        use cloud_llm_client::LanguageModelProvider::*;
 585        match self.model.provider {
 586            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
 587            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
 588            Google => language_model::GOOGLE_PROVIDER_ID,
 589            XAi => language_model::X_AI_PROVIDER_ID,
 590        }
 591    }
 592
 593    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 594        use cloud_llm_client::LanguageModelProvider::*;
 595        match self.model.provider {
 596            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
 597            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
 598            Google => language_model::GOOGLE_PROVIDER_NAME,
 599            XAi => language_model::X_AI_PROVIDER_NAME,
 600        }
 601    }
 602
    /// Tool support as reported by the cloud model list.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }
 606
    /// Image-input support as reported by the cloud model list.
    fn supports_images(&self) -> bool {
        self.model.supports_images
    }
 610
    /// Streaming-tool support as reported by the cloud model list.
    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }
 614
 615    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 616        match choice {
 617            LanguageModelToolChoice::Auto
 618            | LanguageModelToolChoice::Any
 619            | LanguageModelToolChoice::None => true,
 620        }
 621    }
 622
    /// Burn mode maps onto the cloud model's `supports_max_mode` flag.
    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }
 626
 627    fn supports_split_token_display(&self) -> bool {
 628        use cloud_llm_client::LanguageModelProvider::*;
 629        matches!(self.model.provider, OpenAi)
 630    }
 631
 632    fn telemetry_id(&self) -> String {
 633        format!("zed.dev/{}", self.model.id)
 634    }
 635
 636    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 637        match self.model.provider {
 638            cloud_llm_client::LanguageModelProvider::Anthropic
 639            | cloud_llm_client::LanguageModelProvider::OpenAi
 640            | cloud_llm_client::LanguageModelProvider::XAi => {
 641                LanguageModelToolSchemaFormat::JsonSchema
 642            }
 643            cloud_llm_client::LanguageModelProvider::Google => {
 644                LanguageModelToolSchemaFormat::JsonSchemaSubset
 645            }
 646        }
 647    }
 648
    /// Maximum token count reported by the cloud for this model, widened to `u64`.
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }
 652
 653    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
 654        self.model
 655            .max_token_count_in_max_mode
 656            .filter(|_| self.model.supports_max_mode)
 657            .map(|max_token_count| max_token_count as u64)
 658    }
 659
    /// Always `Some`: the cloud model list carries an output-token limit for every model.
    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }
 663
 664    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 665        match &self.model.provider {
 666            cloud_llm_client::LanguageModelProvider::Anthropic => {
 667                Some(LanguageModelCacheConfiguration {
 668                    min_total_token: 2_048,
 669                    should_speculate: true,
 670                    max_cache_anchors: 4,
 671                })
 672            }
 673            cloud_llm_client::LanguageModelProvider::OpenAi
 674            | cloud_llm_client::LanguageModelProvider::XAi
 675            | cloud_llm_client::LanguageModelProvider::Google => None,
 676        }
 677    }
 678
 679    fn count_tokens(
 680        &self,
 681        request: LanguageModelRequest,
 682        cx: &App,
 683    ) -> BoxFuture<'static, Result<u64>> {
 684        match self.model.provider {
 685            cloud_llm_client::LanguageModelProvider::Anthropic => cx
 686                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
 687                .boxed(),
 688            cloud_llm_client::LanguageModelProvider::OpenAi => {
 689                let model = match open_ai::Model::from_id(&self.model.id.0) {
 690                    Ok(model) => model,
 691                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 692                };
 693                count_open_ai_tokens(request, model, cx)
 694            }
 695            cloud_llm_client::LanguageModelProvider::XAi => {
 696                let model = match x_ai::Model::from_id(&self.model.id.0) {
 697                    Ok(model) => model,
 698                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 699                };
 700                count_xai_tokens(request, model, cx)
 701            }
 702            cloud_llm_client::LanguageModelProvider::Google => {
 703                let client = self.client.clone();
 704                let llm_api_token = self.llm_api_token.clone();
 705                let model_id = self.model.id.to_string();
 706                let generate_content_request =
 707                    into_google(request, model_id.clone(), GoogleModelMode::Default);
 708                async move {
 709                    let http_client = &client.http_client();
 710                    let token = llm_api_token.acquire(&client).await?;
 711
 712                    let request_body = CountTokensBody {
 713                        provider: cloud_llm_client::LanguageModelProvider::Google,
 714                        model: model_id,
 715                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 716                            generate_content_request,
 717                        })?,
 718                    };
 719                    let request = http_client::Request::builder()
 720                        .method(Method::POST)
 721                        .uri(
 722                            http_client
 723                                .build_zed_llm_url("/count_tokens", &[])?
 724                                .as_ref(),
 725                        )
 726                        .header("Content-Type", "application/json")
 727                        .header("Authorization", format!("Bearer {token}"))
 728                        .body(serde_json::to_string(&request_body)?.into())?;
 729                    let mut response = http_client.send(request).await?;
 730                    let status = response.status();
 731                    let headers = response.headers().clone();
 732                    let mut response_body = String::new();
 733                    response
 734                        .body_mut()
 735                        .read_to_string(&mut response_body)
 736                        .await?;
 737
 738                    if status.is_success() {
 739                        let response_body: CountTokensResponse =
 740                            serde_json::from_str(&response_body)?;
 741
 742                        Ok(response_body.tokens as u64)
 743                    } else {
 744                        Err(anyhow!(ApiError {
 745                            status,
 746                            body: response_body,
 747                            headers
 748                        }))
 749                    }
 750                }
 751                .boxed()
 752            }
 753        }
 754    }
 755
 756    fn stream_completion(
 757        &self,
 758        request: LanguageModelRequest,
 759        cx: &AsyncApp,
 760    ) -> BoxFuture<
 761        'static,
 762        Result<
 763            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 764            LanguageModelCompletionError,
 765        >,
 766    > {
 767        let thread_id = request.thread_id.clone();
 768        let prompt_id = request.prompt_id.clone();
 769        let intent = request.intent;
 770        let mode = request.mode;
 771        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
 772        let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
 773        let thinking_allowed = request.thinking_allowed;
 774        let provider_name = provider_name(&self.model.provider);
 775        match self.model.provider {
 776            cloud_llm_client::LanguageModelProvider::Anthropic => {
 777                let request = into_anthropic(
 778                    request,
 779                    self.model.id.to_string(),
 780                    1.0,
 781                    self.model.max_output_tokens as u64,
 782                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
 783                        AnthropicModelMode::Thinking {
 784                            budget_tokens: Some(4_096),
 785                        }
 786                    } else {
 787                        AnthropicModelMode::Default
 788                    },
 789                );
 790                let client = self.client.clone();
 791                let llm_api_token = self.llm_api_token.clone();
 792                let future = self.request_limiter.stream(async move {
 793                    let PerformLlmCompletionResponse {
 794                        response,
 795                        usage,
 796                        includes_status_messages,
 797                        tool_use_limit_reached,
 798                    } = Self::perform_llm_completion(
 799                        client.clone(),
 800                        llm_api_token,
 801                        app_version,
 802                        CompletionBody {
 803                            thread_id,
 804                            prompt_id,
 805                            intent,
 806                            mode,
 807                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 808                            model: request.model.clone(),
 809                            provider_request: serde_json::to_value(&request)
 810                                .map_err(|e| anyhow!(e))?,
 811                        },
 812                    )
 813                    .await
 814                    .map_err(|err| match err.downcast::<ApiError>() {
 815                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 816                        Err(err) => anyhow!(err),
 817                    })?;
 818
 819                    let mut mapper = AnthropicEventMapper::new();
 820                    Ok(map_cloud_completion_events(
 821                        Box::pin(
 822                            response_lines(response, includes_status_messages)
 823                                .chain(usage_updated_event(usage))
 824                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 825                        ),
 826                        &provider_name,
 827                        move |event| mapper.map_event(event),
 828                    ))
 829                });
 830                async move { Ok(future.await?.boxed()) }.boxed()
 831            }
 832            cloud_llm_client::LanguageModelProvider::OpenAi => {
 833                let client = self.client.clone();
 834                let llm_api_token = self.llm_api_token.clone();
 835
 836                if use_responses_api {
 837                    let request = into_open_ai_response(
 838                        request,
 839                        &self.model.id.0,
 840                        self.model.supports_parallel_tool_calls,
 841                        true,
 842                        None,
 843                        None,
 844                    );
 845                    let future = self.request_limiter.stream(async move {
 846                        let PerformLlmCompletionResponse {
 847                            response,
 848                            usage,
 849                            includes_status_messages,
 850                            tool_use_limit_reached,
 851                        } = Self::perform_llm_completion(
 852                            client.clone(),
 853                            llm_api_token,
 854                            app_version,
 855                            CompletionBody {
 856                                thread_id,
 857                                prompt_id,
 858                                intent,
 859                                mode,
 860                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 861                                model: request.model.clone(),
 862                                provider_request: serde_json::to_value(&request)
 863                                    .map_err(|e| anyhow!(e))?,
 864                            },
 865                        )
 866                        .await?;
 867
 868                        let mut mapper = OpenAiResponseEventMapper::new();
 869                        Ok(map_cloud_completion_events(
 870                            Box::pin(
 871                                response_lines(response, includes_status_messages)
 872                                    .chain(usage_updated_event(usage))
 873                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 874                            ),
 875                            &provider_name,
 876                            move |event| mapper.map_event(event),
 877                        ))
 878                    });
 879                    async move { Ok(future.await?.boxed()) }.boxed()
 880                } else {
 881                    let request = into_open_ai(
 882                        request,
 883                        &self.model.id.0,
 884                        self.model.supports_parallel_tool_calls,
 885                        true,
 886                        None,
 887                        None,
 888                    );
 889                    let future = self.request_limiter.stream(async move {
 890                        let PerformLlmCompletionResponse {
 891                            response,
 892                            usage,
 893                            includes_status_messages,
 894                            tool_use_limit_reached,
 895                        } = Self::perform_llm_completion(
 896                            client.clone(),
 897                            llm_api_token,
 898                            app_version,
 899                            CompletionBody {
 900                                thread_id,
 901                                prompt_id,
 902                                intent,
 903                                mode,
 904                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 905                                model: request.model.clone(),
 906                                provider_request: serde_json::to_value(&request)
 907                                    .map_err(|e| anyhow!(e))?,
 908                            },
 909                        )
 910                        .await?;
 911
 912                        let mut mapper = OpenAiEventMapper::new();
 913                        Ok(map_cloud_completion_events(
 914                            Box::pin(
 915                                response_lines(response, includes_status_messages)
 916                                    .chain(usage_updated_event(usage))
 917                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 918                            ),
 919                            &provider_name,
 920                            move |event| mapper.map_event(event),
 921                        ))
 922                    });
 923                    async move { Ok(future.await?.boxed()) }.boxed()
 924                }
 925            }
 926            cloud_llm_client::LanguageModelProvider::XAi => {
 927                let client = self.client.clone();
 928                let request = into_open_ai(
 929                    request,
 930                    &self.model.id.0,
 931                    self.model.supports_parallel_tool_calls,
 932                    false,
 933                    None,
 934                    None,
 935                );
 936                let llm_api_token = self.llm_api_token.clone();
 937                let future = self.request_limiter.stream(async move {
 938                    let PerformLlmCompletionResponse {
 939                        response,
 940                        usage,
 941                        includes_status_messages,
 942                        tool_use_limit_reached,
 943                    } = Self::perform_llm_completion(
 944                        client.clone(),
 945                        llm_api_token,
 946                        app_version,
 947                        CompletionBody {
 948                            thread_id,
 949                            prompt_id,
 950                            intent,
 951                            mode,
 952                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 953                            model: request.model.clone(),
 954                            provider_request: serde_json::to_value(&request)
 955                                .map_err(|e| anyhow!(e))?,
 956                        },
 957                    )
 958                    .await?;
 959
 960                    let mut mapper = OpenAiEventMapper::new();
 961                    Ok(map_cloud_completion_events(
 962                        Box::pin(
 963                            response_lines(response, includes_status_messages)
 964                                .chain(usage_updated_event(usage))
 965                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 966                        ),
 967                        &provider_name,
 968                        move |event| mapper.map_event(event),
 969                    ))
 970                });
 971                async move { Ok(future.await?.boxed()) }.boxed()
 972            }
 973            cloud_llm_client::LanguageModelProvider::Google => {
 974                let client = self.client.clone();
 975                let request =
 976                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 977                let llm_api_token = self.llm_api_token.clone();
 978                let future = self.request_limiter.stream(async move {
 979                    let PerformLlmCompletionResponse {
 980                        response,
 981                        usage,
 982                        includes_status_messages,
 983                        tool_use_limit_reached,
 984                    } = Self::perform_llm_completion(
 985                        client.clone(),
 986                        llm_api_token,
 987                        app_version,
 988                        CompletionBody {
 989                            thread_id,
 990                            prompt_id,
 991                            intent,
 992                            mode,
 993                            provider: cloud_llm_client::LanguageModelProvider::Google,
 994                            model: request.model.model_id.clone(),
 995                            provider_request: serde_json::to_value(&request)
 996                                .map_err(|e| anyhow!(e))?,
 997                        },
 998                    )
 999                    .await?;
1000
1001                    let mut mapper = GoogleEventMapper::new();
1002                    Ok(map_cloud_completion_events(
1003                        Box::pin(
1004                            response_lines(response, includes_status_messages)
1005                                .chain(usage_updated_event(usage))
1006                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
1007                        ),
1008                        &provider_name,
1009                        move |event| mapper.map_event(event),
1010                    ))
1011                });
1012                async move { Ok(future.await?.boxed()) }.boxed()
1013            }
1014        }
1015    }
1016}
1017
1018fn map_cloud_completion_events<T, F>(
1019    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
1020    provider: &LanguageModelProviderName,
1021    mut map_callback: F,
1022) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1023where
1024    T: DeserializeOwned + 'static,
1025    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1026        + Send
1027        + 'static,
1028{
1029    let provider = provider.clone();
1030    stream
1031        .flat_map(move |event| {
1032            futures::stream::iter(match event {
1033                Err(error) => {
1034                    vec![Err(LanguageModelCompletionError::from(error))]
1035                }
1036                Ok(CompletionEvent::Status(event)) => {
1037                    vec![
1038                        LanguageModelCompletionEvent::from_completion_request_status(
1039                            event,
1040                            provider.clone(),
1041                        ),
1042                    ]
1043                }
1044                Ok(CompletionEvent::Event(event)) => map_callback(event),
1045            })
1046        })
1047        .boxed()
1048}
1049
1050fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1051    match provider {
1052        cloud_llm_client::LanguageModelProvider::Anthropic => {
1053            language_model::ANTHROPIC_PROVIDER_NAME
1054        }
1055        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1056        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1057        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1058    }
1059}
1060
1061fn usage_updated_event<T>(
1062    usage: Option<ModelRequestUsage>,
1063) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1064    futures::stream::iter(usage.map(|usage| {
1065        Ok(CompletionEvent::Status(
1066            CompletionRequestStatus::UsageUpdated {
1067                amount: usage.amount as usize,
1068                limit: usage.limit,
1069            },
1070        ))
1071    }))
1072}
1073
1074fn tool_use_limit_reached_event<T>(
1075    tool_use_limit_reached: bool,
1076) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1077    futures::stream::iter(tool_use_limit_reached.then(|| {
1078        Ok(CompletionEvent::Status(
1079            CompletionRequestStatus::ToolUseLimitReached,
1080        ))
1081    }))
1082}
1083
1084fn response_lines<T: DeserializeOwned>(
1085    response: Response<AsyncBody>,
1086    includes_status_messages: bool,
1087) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1088    futures::stream::try_unfold(
1089        (String::new(), BufReader::new(response.into_body())),
1090        move |(mut line, mut body)| async move {
1091            match body.read_line(&mut line).await {
1092                Ok(0) => Ok(None),
1093                Ok(_) => {
1094                    let event = if includes_status_messages {
1095                        serde_json::from_str::<CompletionEvent<T>>(&line)?
1096                    } else {
1097                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1098                    };
1099
1100                    line.clear();
1101                    Ok(Some((event, (line, body))))
1102                }
1103                Err(e) => Err(e.into()),
1104            }
1105        },
1106    )
1107}
1108
1109#[derive(IntoElement, RegisterComponent)]
1110struct ZedAiConfiguration {
1111    is_connected: bool,
1112    plan: Option<Plan>,
1113    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
1114    eligible_for_trial: bool,
1115    account_too_young: bool,
1116    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
1117}
1118
1119impl RenderOnce for ZedAiConfiguration {
1120    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1121        let is_pro = self.plan.is_some_and(|plan| {
1122            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
1123        });
1124        let subscription_text = match (self.plan, self.subscription_period) {
1125            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
1126                "You have access to Zed's hosted models through your Pro subscription."
1127            }
1128            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
1129                "You have access to Zed's hosted models through your Pro trial."
1130            }
1131            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
1132                "You have basic access to Zed's hosted models through the Free plan."
1133            }
1134            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
1135                if self.eligible_for_trial {
1136                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1137                } else {
1138                    "Subscribe for access to Zed's hosted models."
1139                }
1140            }
1141            _ => {
1142                if self.eligible_for_trial {
1143                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1144                } else {
1145                    "Subscribe for access to Zed's hosted models."
1146                }
1147            }
1148        };
1149
1150        let manage_subscription_buttons = if is_pro {
1151            Button::new("manage_settings", "Manage Subscription")
1152                .full_width()
1153                .style(ButtonStyle::Tinted(TintColor::Accent))
1154                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1155                .into_any_element()
1156        } else if self.plan.is_none() || self.eligible_for_trial {
1157            Button::new("start_trial", "Start 14-day Free Pro Trial")
1158                .full_width()
1159                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1160                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1161                .into_any_element()
1162        } else {
1163            Button::new("upgrade", "Upgrade to Pro")
1164                .full_width()
1165                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1166                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1167                .into_any_element()
1168        };
1169
1170        if !self.is_connected {
1171            return v_flex()
1172                .gap_2()
1173                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1174                .child(
1175                    Button::new("sign_in", "Sign In to use Zed AI")
1176                        .icon_color(Color::Muted)
1177                        .icon(IconName::Github)
1178                        .icon_size(IconSize::Small)
1179                        .icon_position(IconPosition::Start)
1180                        .full_width()
1181                        .on_click({
1182                            let callback = self.sign_in_callback.clone();
1183                            move |_, window, cx| (callback)(window, cx)
1184                        }),
1185                );
1186        }
1187
1188        v_flex().gap_2().w_full().map(|this| {
1189            if self.account_too_young {
1190                this.child(YoungAccountBanner).child(
1191                    Button::new("upgrade", "Upgrade to Pro")
1192                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1193                        .full_width()
1194                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1195                )
1196            } else {
1197                this.text_sm()
1198                    .child(subscription_text)
1199                    .child(manage_subscription_buttons)
1200            }
1201        })
1202    }
1203}
1204
1205struct ConfigurationView {
1206    state: Entity<State>,
1207    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
1208}
1209
1210impl ConfigurationView {
1211    fn new(state: Entity<State>) -> Self {
1212        let sign_in_callback = Arc::new({
1213            let state = state.clone();
1214            move |_window: &mut Window, cx: &mut App| {
1215                state.update(cx, |state, cx| {
1216                    state.authenticate(cx).detach_and_log_err(cx);
1217                });
1218            }
1219        });
1220
1221        Self {
1222            state,
1223            sign_in_callback,
1224        }
1225    }
1226}
1227
1228impl Render for ConfigurationView {
1229    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1230        let state = self.state.read(cx);
1231        let user_store = state.user_store.read(cx);
1232
1233        ZedAiConfiguration {
1234            is_connected: !state.is_signed_out(cx),
1235            plan: user_store.plan(),
1236            subscription_period: user_store.subscription_period(),
1237            eligible_for_trial: user_store.trial_started_at().is_none(),
1238            account_too_young: user_store.account_too_young(),
1239            sign_in_callback: self.sign_in_callback.clone(),
1240        }
1241    }
1242}
1243
1244impl Component for ZedAiConfiguration {
1245    fn name() -> &'static str {
1246        "AI Configuration Content"
1247    }
1248
1249    fn sort_name() -> &'static str {
1250        "AI Configuration Content"
1251    }
1252
1253    fn scope() -> ComponentScope {
1254        ComponentScope::Onboarding
1255    }
1256
1257    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1258        fn configuration(
1259            is_connected: bool,
1260            plan: Option<Plan>,
1261            eligible_for_trial: bool,
1262            account_too_young: bool,
1263        ) -> AnyElement {
1264            ZedAiConfiguration {
1265                is_connected,
1266                plan,
1267                subscription_period: plan
1268                    .is_some()
1269                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1270                eligible_for_trial,
1271                account_too_young,
1272                sign_in_callback: Arc::new(|_, _| {}),
1273            }
1274            .into_any_element()
1275        }
1276
1277        Some(
1278            v_flex()
1279                .p_4()
1280                .gap_4()
1281                .children(vec![
1282                    single_example("Not connected", configuration(false, None, false, false)),
1283                    single_example(
1284                        "Accept Terms of Service",
1285                        configuration(true, None, true, false),
1286                    ),
1287                    single_example(
1288                        "No Plan - Not eligible for trial",
1289                        configuration(true, None, false, false),
1290                    ),
1291                    single_example(
1292                        "No Plan - Eligible for trial",
1293                        configuration(true, None, true, false),
1294                    ),
1295                    single_example(
1296                        "Free Plan",
1297                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
1298                    ),
1299                    single_example(
1300                        "Zed Pro Trial Plan",
1301                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
1302                    ),
1303                    single_example(
1304                        "Zed Pro Plan",
1305                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
1306                    ),
1307                ])
1308                .into_any_element(),
1309        )
1310    }
1311}
1312
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Converts a 500 response carrying `body` into a completion error, mirroring how the
    /// cloud provider turns a failed request into an `ApiError`.
    fn completion_error_for(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // `upstream_http_error` bodies become UpstreamProviderError carrying the upstream
        // message, whatever the reported upstream status (503, 500 or 429).
        let cases = [
            (
                503,
                r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
                "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
            ),
            (
                500,
                r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
                "Received an error from the OpenAI API: internal server error",
            ),
            (
                429,
                r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
                "Received an error from the Google API: rate limit exceeded",
            ),
        ];

        for (upstream_status, error_body, expected_message) in cases {
            match completion_error_for(error_body) {
                LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                    assert_eq!(message, expected_message);
                }
                other => panic!(
                    "Expected UpstreamProviderError for upstream {}, got: {:?}",
                    upstream_status, other
                ),
            }
        }
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_429() {
        // `upstream_http_429` bodies become UpstreamProviderError with a 429 status and the
        // server-provided retry delay.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        match completion_error_for(error_body) {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            other => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_api_error_conversion_with_regular_internal_server_error() {
        // A plain 500 without an upstream code is attributed to Zed's own API.
        match completion_error_for("Regular internal server error") {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            other => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_api_error_conversion_with_invalid_json_body() {
        // A body that is not JSON falls back to the regular status-based handling.
        match completion_error_for("Not JSON at all") {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            other => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                other
            ),
        }
    }
}