cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::{
    AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, ToolResultPart,
    Usage,
};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
    LanguageModelToolUseId, LlmApiToken, MessageContent, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, Role, StopReason,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
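
/// Whether a model runs in its default mode or in an extended "thinking"
/// mode with a reasoning-token budget.
///
/// As a sketch, given the serde attributes below this deserializes from
/// settings JSON shaped like:
///
/// ```json
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```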
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

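/// Observable state for the Zed cloud provider: the authenticated client,
/// the current LLM API token, and the model lists fetched from the server.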
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
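            // For example (hypothetical id), "some-model" also yields a
            // "some-model-thinking" variant with a " Thinking" display-name suffix.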
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

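    /// Fetches the available models from the Zed LLM API (`GET /models`),
    /// authenticating with the current LLM token.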
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone())).into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
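    /// Sends a completion request to the Zed LLM API (`POST /completions`).
    ///
    /// If the server reports an expired LLM token, the token is refreshed once
    /// and the request is retried before any other error handling applies.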
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
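///
/// Rate-limit errors may instead encode the status in the `code` itself and
/// carry a retry hint, as exercised in the tests below:
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```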
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

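/// Deserializes an optional numeric HTTP status, mapping values that aren't
/// valid status codes to `None` rather than failing deserialization.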
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

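/// Converts an [`ApiError`] into a [`LanguageModelCompletionError`], preferring
/// the structured [`CloudApiError`] body when it parses and falling back to the
/// raw HTTP status and body otherwise.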
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If the code itself embeds a status (e.g. "upstream_http_429"),
                    // use that; otherwise fall back to the response's own status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

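/// Adapts a raw stream of [`CompletionEvent`]s into the unified completion
/// event stream: status events are handled generically here, while
/// provider-specific events are translated by `map_callback`.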
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![
                        LanguageModelCompletionEvent::from_completion_request_status(
                            event,
                            provider.clone(),
                        ),
                    ]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

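/// Emits a single `UsageUpdated` status event when usage information was
/// provided via response headers, and nothing otherwise.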
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

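/// Emits a single `ToolUseLimitReached` status event when the server set the
/// corresponding header, and nothing otherwise.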
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

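/// Parses the response body as newline-delimited JSON. When the server sends
/// status messages, each line is a full [`CompletionEvent`]; otherwise each
/// line is a bare provider event wrapped in [`CompletionEvent::Event`].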
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

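/// The Zed AI configuration panel: prompts for sign-in when disconnected, and
/// otherwise summarizes the user's plan and offers subscription actions.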
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_button = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to access Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_button)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should become UpstreamProviderError as well
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 without a structured upstream_http_error body should become
        // ApiInternalServerError attributed to Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should be converted to UpstreamProviderError,
        // with the status parsed from the code and the retry hint from retry_after
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to plain HTTP-status handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
1399}
1400
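/// Estimates token usage for an Anthropic request.
///
/// No Anthropic tokenizer is available locally, so this approximates the
/// count with tiktoken's "gpt-4" tokenizer plus a per-image estimate; the
/// result may differ from Anthropic's own accounting.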
1401fn count_anthropic_tokens(
1402    request: LanguageModelRequest,
1403    cx: &App,
1404) -> BoxFuture<'static, Result<u64>> {
1405    use gpui::AppContext as _;
1406    cx.background_spawn(async move {
1407        let messages = request.messages;
1408        let mut tokens_from_images = 0;
1409        let mut string_messages = Vec::with_capacity(messages.len());
1410
1411        for message in messages {
1412            let mut string_contents = String::new();
1413
1414            for content in message.content {
1415                match content {
1416                    MessageContent::Text(text) => {
1417                        string_contents.push_str(&text);
1418                    }
1419                    MessageContent::Thinking { .. } => {}
1420                    MessageContent::RedactedThinking(_) => {}
1421                    MessageContent::Image(image) => {
1422                        tokens_from_images += image.estimate_tokens();
1423                    }
1424                    MessageContent::ToolUse(_tool_use) => {}
1425                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
1426                        LanguageModelToolResultContent::Text(text) => {
1427                            string_contents.push_str(text);
1428                        }
1429                        LanguageModelToolResultContent::Image(image) => {
1430                            tokens_from_images += image.estimate_tokens();
1431                        }
1432                    },
1433                }
1434            }
1435
1436            if !string_contents.is_empty() {
1437                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
1438                    role: match message.role {
1439                        Role::User => "user".into(),
1440                        Role::Assistant => "assistant".into(),
1441                        Role::System => "system".into(),
1442                    },
1443                    content: Some(string_contents),
1444                    name: None,
1445                    function_call: None,
1446                });
1447            }
1448        }
1449
1450        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
1451            .map(|tokens| (tokens + tokens_from_images) as u64)
1452    })
1453    .boxed()
1454}
1455
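/// Converts a `LanguageModelRequest` into an Anthropic `Request`,
/// concatenating system messages, merging consecutive same-role turns,
/// and applying cache-control and extended-thinking settings.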
1456fn into_anthropic(
1457    request: LanguageModelRequest,
1458    model: String,
1459    default_temperature: f32,
1460    max_output_tokens: u64,
1461    mode: AnthropicModelMode,
1462) -> anthropic::Request {
1463    let mut new_messages: Vec<anthropic::Message> = Vec::new();
1464    let mut system_message = String::new();
1465
1466    for message in request.messages {
1467        if message.contents_empty() {
1468            continue;
1469        }
1470
1471        match message.role {
1472            Role::User | Role::Assistant => {
1473                let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
1474                    .content
1475                    .into_iter()
1476                    .filter_map(|content| match content {
1477                        MessageContent::Text(text) => {
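                            // The Anthropic API rejects a final assistant
                            // message that ends in whitespace, so trim any
                            // trailing whitespace from text blocks.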
1478                            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
1479                                text.trim_end().to_string()
1480                            } else {
1481                                text
1482                            };
1483                            if !text.is_empty() {
1484                                Some(anthropic::RequestContent::Text {
1485                                    text,
1486                                    cache_control: None,
1487                                })
1488                            } else {
1489                                None
1490                            }
1491                        }
1492                        MessageContent::Thinking {
1493                            text: thinking,
1494                            signature,
1495                        } => {
1496                            if !thinking.is_empty() {
1497                                Some(anthropic::RequestContent::Thinking {
1498                                    thinking,
1499                                    signature: signature.unwrap_or_default(),
1500                                    cache_control: None,
1501                                })
1502                            } else {
1503                                None
1504                            }
1505                        }
1506                        MessageContent::RedactedThinking(data) => {
1507                            if !data.is_empty() {
1508                                Some(anthropic::RequestContent::RedactedThinking { data })
1509                            } else {
1510                                None
1511                            }
1512                        }
1513                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
1514                            source: anthropic::ImageSource {
1515                                source_type: "base64".to_string(),
1516                                media_type: "image/png".to_string(),
1517                                data: image.source.to_string(),
1518                            },
1519                            cache_control: None,
1520                        }),
1521                        MessageContent::ToolUse(tool_use) => {
1522                            Some(anthropic::RequestContent::ToolUse {
1523                                id: tool_use.id.to_string(),
1524                                name: tool_use.name.to_string(),
1525                                input: tool_use.input,
1526                                cache_control: None,
1527                            })
1528                        }
1529                        MessageContent::ToolResult(tool_result) => {
1530                            Some(anthropic::RequestContent::ToolResult {
1531                                tool_use_id: tool_result.tool_use_id.to_string(),
1532                                is_error: tool_result.is_error,
1533                                content: match tool_result.content {
1534                                    LanguageModelToolResultContent::Text(text) => {
1535                                        ToolResultContent::Plain(text.to_string())
1536                                    }
1537                                    LanguageModelToolResultContent::Image(image) => {
1538                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
1539                                            source: anthropic::ImageSource {
1540                                                source_type: "base64".to_string(),
1541                                                media_type: "image/png".to_string(),
1542                                                data: image.source.to_string(),
1543                                            },
1544                                        }])
1545                                    }
1546                                },
1547                                cache_control: None,
1548                            })
1549                        }
1550                    })
1551                    .collect();
1552                let anthropic_role = match message.role {
1553                    Role::User => anthropic::Role::User,
1554                    Role::Assistant => anthropic::Role::Assistant,
1555                    Role::System => unreachable!("System role should never occur here"),
1556                };
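                // The Anthropic API requires user/assistant turns to
                // alternate, so merge this message into the previous one
                // when the roles match.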
1557                if let Some(last_message) = new_messages.last_mut()
1558                    && last_message.role == anthropic_role
1559                {
1560                    last_message.content.extend(anthropic_message_content);
1561                    continue;
1562                }
1563
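                // Attach an ephemeral cache breakpoint to the last content
                // block that can carry one (redacted thinking blocks cannot).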
1564                if message.cache {
1565                    let cache_control_value = Some(anthropic::CacheControl {
1566                        cache_type: anthropic::CacheControlType::Ephemeral,
1567                    });
1568                    for message_content in anthropic_message_content.iter_mut().rev() {
1569                        match message_content {
1570                            anthropic::RequestContent::RedactedThinking { .. } => {}
1571                            anthropic::RequestContent::Text { cache_control, .. }
1572                            | anthropic::RequestContent::Thinking { cache_control, .. }
1573                            | anthropic::RequestContent::Image { cache_control, .. }
1574                            | anthropic::RequestContent::ToolUse { cache_control, .. }
1575                            | anthropic::RequestContent::ToolResult { cache_control, .. } => {
1576                                *cache_control = cache_control_value;
1577                                break;
1578                            }
1579                        }
1580                    }
1581                }
1582
1583                new_messages.push(anthropic::Message {
1584                    role: anthropic_role,
1585                    content: anthropic_message_content,
1586                });
1587            }
1588            Role::System => {
1589                if !system_message.is_empty() {
1590                    system_message.push_str("\n\n");
1591                }
1592                system_message.push_str(&message.string_contents());
1593            }
1594        }
1595    }
1596
1597    anthropic::Request {
1598        model,
1599        messages: new_messages,
1600        max_tokens: max_output_tokens,
1601        system: if system_message.is_empty() {
1602            None
1603        } else {
1604            Some(anthropic::StringOrContents::String(system_message))
1605        },
1606        thinking: if request.thinking_allowed
1607            && let AnthropicModelMode::Thinking { budget_tokens } = mode
1608        {
1609            Some(anthropic::Thinking::Enabled { budget_tokens })
1610        } else {
1611            None
1612        },
1613        tools: request
1614            .tools
1615            .into_iter()
1616            .map(|tool| anthropic::Tool {
1617                name: tool.name,
1618                description: tool.description,
1619                input_schema: tool.input_schema,
1620            })
1621            .collect(),
1622        tool_choice: request.tool_choice.map(|choice| match choice {
1623            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
1624            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
1625            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
1626        }),
1627        metadata: None,
1628        stop_sequences: Vec::new(),
1629        temperature: request.temperature.or(Some(default_temperature)),
1630        top_k: None,
1631        top_p: None,
1632    }
1633}
1634
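/// Maps streamed Anthropic events onto `LanguageModelCompletionEvent`s,
/// accumulating partial tool-use JSON, token usage, and the stop reason
/// across the lifetime of a response.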
1635struct AnthropicEventMapper {
1636    tool_uses_by_index: collections::HashMap<usize, RawToolUse>,
1637    usage: Usage,
1638    stop_reason: StopReason,
1639}
1640
1641impl AnthropicEventMapper {
1642    fn new() -> Self {
1643        Self {
1644            tool_uses_by_index: collections::HashMap::default(),
1645            usage: Usage::default(),
1646            stop_reason: StopReason::EndTurn,
1647        }
1648    }
1649
1650    fn map_event(
1651        &mut self,
1652        event: Event,
1653    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
1654        match event {
1655            Event::ContentBlockStart {
1656                index,
1657                content_block,
1658            } => match content_block {
1659                ResponseContent::Text { text } => {
1660                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
1661                }
1662                ResponseContent::Thinking { thinking } => {
1663                    vec![Ok(LanguageModelCompletionEvent::Thinking {
1664                        text: thinking,
1665                        signature: None,
1666                    })]
1667                }
1668                ResponseContent::RedactedThinking { data } => {
1669                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
1670                }
1671                ResponseContent::ToolUse { id, name, .. } => {
1672                    self.tool_uses_by_index.insert(
1673                        index,
1674                        RawToolUse {
1675                            id,
1676                            name,
1677                            input_json: String::new(),
1678                        },
1679                    );
1680                    Vec::new()
1681                }
1682            },
1683            Event::ContentBlockDelta { index, delta } => match delta {
1684                ContentDelta::TextDelta { text } => {
1685                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
1686                }
1687                ContentDelta::ThinkingDelta { thinking } => {
1688                    vec![Ok(LanguageModelCompletionEvent::Thinking {
1689                        text: thinking,
1690                        signature: None,
1691                    })]
1692                }
1693                ContentDelta::SignatureDelta { signature } => {
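                    // A signature delta carries no text, so emit an empty
                    // thinking event that transports only the signature.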
1694                    vec![Ok(LanguageModelCompletionEvent::Thinking {
1695                        text: String::new(),
1696                        signature: Some(signature),
1697                    })]
1698                }
1699                ContentDelta::InputJsonDelta { partial_json } => {
1700                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
1701                        tool_use.input_json.push_str(&partial_json);
1702
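                        // Emit a partial tool-use event only once the
                        // accumulated JSON parses and its canonical
                        // re-serialization is a prefix of the raw input,
                        // filtering out fragments that parse by coincidence.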
1703                        let event = serde_json::from_str::<serde_json::Value>(&tool_use.input_json)
1704                            .ok()
1705                            .and_then(|input| {
1706                                let input_json_roundtripped = serde_json::to_string(&input).ok()?;
1707
1708                                if !tool_use.input_json.starts_with(&input_json_roundtripped) {
1709                                    return None;
1710                                }
1711
1712                                Some(LanguageModelCompletionEvent::ToolUse(
1713                                    LanguageModelToolUse {
1714                                        id: LanguageModelToolUseId::from(tool_use.id.clone()),
1715                                        name: tool_use.name.clone().into(),
1716                                        raw_input: tool_use.input_json.clone(),
1717                                        input,
1718                                        is_input_complete: false,
1719                                        thought_signature: None,
1720                                    },
1721                                ))
1722                            });
1723
1724                        if let Some(event) = event {
1725                            vec![Ok(event)]
1726                        } else {
1727                            Vec::new()
1728                        }
1729                    } else {
1730                        Vec::new()
1731                    }
1732                }
1733            },
1734            Event::ContentBlockStop { index } => {
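                // The content block is complete; parse the accumulated
                // input JSON into either a finished tool use or a parse
                // error that the caller can surface.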
1735                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
1736                    let event_result = match serde_json::from_str(&tool_use.input_json) {
1737                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
1738                            LanguageModelToolUse {
1739                                id: LanguageModelToolUseId::from(tool_use.id),
1740                                name: tool_use.name.into(),
1741                                raw_input: tool_use.input_json,
1742                                input,
1743                                is_input_complete: true,
1744                                thought_signature: None,
1745                            },
1746                        )),
1747                        Err(json_parse_err) => {
1748                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1749                                id: LanguageModelToolUseId::from(tool_use.id),
1750                                tool_name: tool_use.name.into(),
1751                                raw_input: tool_use.input_json.into(),
1752                                json_parse_error: json_parse_err.to_string(),
1753                            })
1754                        }
1755                    };
1756
1757                    vec![event_result]
1758                } else {
1759                    Vec::new()
1760                }
1761            }
1762            Event::MessageStart { message } => {
1763                update_anthropic_usage(&mut self.usage, &message.usage);
1764                vec![
1765                    Ok(LanguageModelCompletionEvent::UsageUpdate(
1766                        convert_anthropic_usage(&self.usage),
1767                    )),
1768                    Ok(LanguageModelCompletionEvent::StartMessage {
1769                        message_id: message.id,
1770                    }),
1771                ]
1772            }
1773            Event::MessageDelta { delta, usage } => {
1774                update_anthropic_usage(&mut self.usage, &usage);
1775                if let Some(stop_reason) = delta.stop_reason.as_deref() {
1776                    self.stop_reason = match stop_reason {
1777                        "end_turn" => StopReason::EndTurn,
1778                        "max_tokens" => StopReason::MaxTokens,
1779                        "tool_use" => StopReason::ToolUse,
1780                        "refusal" => StopReason::Refusal,
1781                        _ => {
1782                            log::error!("Unexpected Anthropic stop_reason: {stop_reason}");
1783                            StopReason::EndTurn
1784                        }
1785                    };
1786                }
1787                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
1788                    convert_anthropic_usage(&self.usage),
1789                ))]
1790            }
1791            Event::MessageStop => {
1792                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
1793            }
1794            Event::Error { error } => {
1795                vec![Err(error.into())]
1796            }
1797            _ => Vec::new(),
1798        }
1799    }
1800}
1801
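/// A tool use whose input JSON is still streaming in and may be incomplete.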
1802struct RawToolUse {
1803    id: String,
1804    name: String,
1805    input_json: String,
1806}
1807
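/// Folds an incremental usage report into the running totals, overwriting
/// only the fields that the new report populates.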
1808fn update_anthropic_usage(usage: &mut Usage, new: &Usage) {
1809    if let Some(input_tokens) = new.input_tokens {
1810        usage.input_tokens = Some(input_tokens);
1811    }
1812    if let Some(output_tokens) = new.output_tokens {
1813        usage.output_tokens = Some(output_tokens);
1814    }
1815    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
1816        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
1817    }
1818    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
1819        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
1820    }
1821}
1822
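/// Converts accumulated Anthropic usage into `language_model::TokenUsage`,
/// treating missing counts as zero.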
1823fn convert_anthropic_usage(usage: &Usage) -> language_model::TokenUsage {
1824    language_model::TokenUsage {
1825        input_tokens: usage.input_tokens.unwrap_or(0),
1826        output_tokens: usage.output_tokens.unwrap_or(0),
1827        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1828        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1829    }
1830}