cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
    PlanV1, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

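/// Settings for the Zed-hosted (zed.dev) language model provider.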
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
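
/// The mode in which a hosted model runs: either the provider default or an
/// extended "thinking" mode with an optional reasoning budget.
///
/// Serialized with a lowercase `type` tag, e.g. `{"type":"thinking","budget_tokens":4096}`.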
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

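/// A `LanguageModelProvider` backed by Zed's cloud LLM service.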
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

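/// Provider state: the signed-in client, the LLM API token, and the cached
/// model list, kept fresh by background tasks and subscriptions.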
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
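    /// Creates the provider state and spawns a background task that waits for
    /// the user to sign in before fetching the model list.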
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }
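
    /// Replaces the cached model list from a `ListModelsResponse`, inserting a
    /// separate "thinking" variant for each model that supports thinking and
    /// recomputing the default and recommended models.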
    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| model.id == response.default_model)
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| model.id == response.default_fast_model)
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

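    /// Fetches the list of available models from the cloud LLM API,
    /// authorizing the request with the current LLM API token.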
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

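/// A single model hosted behind Zed's cloud provider.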
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

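/// A successful completion response along with usage and limit metadata
/// extracted from the response headers.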
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
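    /// Sends a completion request to the cloud endpoint. If the server reports
    /// an expired LLM token, the token is refreshed and the request retried
    /// once; subscription-limit and payment failures are surfaced as typed errors.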
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

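/// Deserializes an optional numeric HTTP status code, mapping invalid values
/// to `None` instead of failing deserialization.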
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

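/// Converts a cloud API error into a completion error, surfacing
/// `upstream_http_*` error bodies as `UpstreamProviderError`.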
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body)
            && cloud_error.code.starts_with("upstream_http_")
        {
            let status = if let Some(status) = cloud_error.upstream_status {
                status
            } else if cloud_error.code.ends_with("_error") {
                error.status
            } else {
                // The status code is embedded in the code string itself (e.g.
                // "upstream_http_429"); parse it, falling back to the
                // response's own status if parsing fails.
                cloud_error
                    .code
                    .strip_prefix("upstream_http_")
                    .and_then(|code_str| code_str.parse::<u16>().ok())
                    .and_then(|code| StatusCode::from_u16(code).ok())
                    .unwrap_or(error.status)
            };

            return LanguageModelCompletionError::UpstreamProviderError {
                message: cloud_error.message,
                status,
                retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
            };
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    model.supports_prompt_cache_key(),
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

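/// Adapts a stream of raw cloud completion events into
/// `LanguageModelCompletionEvent`s: status events pass through directly, while
/// provider-specific events are expanded by `map_callback`.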
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

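/// Emits a single `UsageUpdated` status event if usage information was parsed
/// from the response headers.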
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

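/// Emits a single `ToolUseLimitReached` status event if the server set the
/// corresponding response header.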
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

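/// Reads the response body as newline-delimited JSON, yielding one event per
/// line. When the server supports status messages, each line is a full
/// `CompletionEvent`; otherwise each line is a bare provider event.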
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

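/// The Zed AI configuration UI: sign-in, plan status, and subscription
/// management controls.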
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let young_account_banner = YoungAccountBanner;

        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(young_account_banner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

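/// Renders `ZedAiConfiguration` from the current provider `State`.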
struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with upstream_status 503 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with upstream_status 500 should also become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with upstream_status 429 should likewise become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 without an upstream_http_error code should map to ApiInternalServerError for the Zed provider
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}