cloud.rs

use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
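
// With the serde attributes above, an `AvailableModel` entry deserializes
// from JSON shaped like the following (the model name and values here are
// illustrative, not an endorsement of any particular configuration):
//
// {
//   "provider": "anthropic",
//   "name": "claude-3-7-sonnet-latest",
//   "display_name": "Claude 3.7 Sonnet (Thinking)",
//   "max_tokens": 200000,
//   "mode": { "type": "thinking", "budget_tokens": 4096 }
// }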

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    cloud_user_store: Entity<CloudUserStore>,
    status: client::Status,
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        cloud_user_store: Entity<CloudUserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            cloud_user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, cloud_user_store, llm_api_token) =
                        this.read_with(cx, |this, _cx| {
                            (
                                client.clone(),
                                this.cloud_user_store.clone(),
                                this.llm_api_token.clone(),
                            )
                        })?;

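                    // Fetching models requires an LLM token, which is only
                    // issued once the user is signed in, so poll on a short
                    // timer until authentication completes.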
                    loop {
                        let is_authenticated =
                            cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?;
                        if is_authenticated {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    this.update(cx, |this, cx| {
                        this.update_models(response, cx);
                    })
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        !self.cloud_user_store.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.cloud_user_store.read(cx).has_accepted_tos()
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| model.id == response.default_model)
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| model.id == response.default_fast_model)
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(
        user_store: Entity<UserStore>,
        cloud_user_store: Entity<CloudUserStore>,
        client: Arc<Client>,
        cx: &mut App,
    ) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| {
            State::new(
                client.clone(),
                user_store.clone(),
                cloud_user_store.clone(),
                status,
                cx,
            )
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
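            // Cap the number of concurrent completion requests for this model.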
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone())).into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        let state = self.state.read(cx);
        if state.has_accepted_terms_of_service(cx) {
            return None;
        }
        Some(
            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
                let state = self.state.clone();
                move |_window, cx| {
                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
                }
            })
            .into_any_element(),
        )
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::TextThreadPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

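/// A successful response from the Cloud `/completions` endpoint, along with
/// metadata extracted from its headers.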
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

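        // Retry at most once: if the server reports an expired LLM token,
        // refresh it and resend the request; any other failure is mapped to
        // an error below.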
        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
                            cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
                            cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
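///
/// Rate-limit errors may instead embed the status in the `code` field and
/// carry a retry hint in seconds:
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```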
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The status code may be embedded in the code string itself
                    // (e.g. "upstream_http_429"); parse it out, falling back to
                    // the response's own status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
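                    // Thinking variants are separate model IDs on the client
                    // (see `State::update_models`), so detect them by suffix
                    // and give them a default reasoning budget.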
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

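/// Adapts a stream of raw Cloud completion events into language-model
/// completion events: status events pass through as `StatusUpdate`s,
/// provider-specific events are expanded by `map_callback`, and stream
/// errors are converted into `LanguageModelCompletionError`s.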
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

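/// Yields a single trailing `UsageUpdated` status event when usage was
/// reported via response headers (i.e. the server does not stream status
/// messages), and nothing otherwise.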
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

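/// Yields a single trailing `ToolUseLimitReached` status event when the
/// corresponding response header was set, and nothing otherwise.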
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

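/// Parses the response body as newline-delimited JSON. When the server
/// streams status messages, each line is a full `CompletionEvent<T>`;
/// otherwise each line is a bare provider event `T`.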
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    account_too_young: bool,
    accept_terms_of_service_in_progress: bool,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let young_account_banner = YoungAccountBanner;

        let is_pro = self.plan == Some(Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::ZedFree), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_button = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new(
                    "Sign in to access Zed's complete agentic experience with hosted models.",
                ))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex()
            .gap_2()
            .w_full()
            .when(!self.has_accepted_terms_of_service, |this| {
                this.child(render_accept_terms(
                    LanguageModelProviderTosView::Configuration,
                    self.accept_terms_of_service_in_progress,
                    {
                        let callback = self.accept_terms_of_service_callback.clone();
                        move |window, cx| (callback)(window, cx)
                    },
                ))
            })
            .map(|this| {
                if self.has_accepted_terms_of_service && self.account_too_young {
                    this.child(young_account_banner).child(
                        Button::new("upgrade", "Upgrade to Pro")
                            .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                            .full_width()
                            .on_click(|_, _, cx| {
                                cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
                            }),
                    )
                } else if self.has_accepted_terms_of_service {
                    this.text_sm()
                        .child(subscription_text)
                        .child(manage_subscription_button)
                } else {
                    this
                }
            })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let accept_terms_of_service_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.accept_terms_of_service(cx);
                });
            }
        });

        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            accept_terms_of_service_callback,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let cloud_user_store = state.cloud_user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: cloud_user_store.plan(),
            subscription_period: cloud_user_store.subscription_period(),
            eligible_for_trial: cloud_user_store.trial_started_at().is_none(),
            has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
            account_too_young: cloud_user_store.account_too_young(),
            accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
            accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn scope() -> ComponentScope {
        ComponentScope::Agent
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
            has_accepted_terms_of_service: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                has_accepted_terms_of_service,
                account_too_young,
                accept_terms_of_service_in_progress: false,
                accept_terms_of_service_callback: Arc::new(|_, _| {}),
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example(
                        "Not connected",
                        configuration(false, None, false, false, true),
                    ),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false, true),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false, true),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false, true),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An upstream_http_error with a 503 upstream status should become an
        // UpstreamProviderError that preserves the upstream provider's message.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 500 upstream status should likewise
        // become an UpstreamProviderError.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 429 upstream status should likewise
        // become an UpstreamProviderError.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 error without an upstream_http_error body should remain
        // an ApiInternalServerError attributed to Zed.
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should also be converted to an
        // UpstreamProviderError, with the status parsed from the code string
        // and the retry hint preserved.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error
        // handling.
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}
1518}