cloud.rs

use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}
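
// An illustrative value for `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`; the
// JSON is an array whose entries deserialize into `AvailableModel` below, so
// field names must match (the model name here is only an example):
//
//     [{"provider": "anthropic", "name": "claude-3-5-sonnet-20240620", "max_tokens": 200000}]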

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. claude-3-5-sonnet-20240620.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
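
// With the internally tagged representation above, a thinking-mode entry in
// settings JSON looks like this (illustrative):
//
//     { "type": "thinking", "budget_tokens": 4096 }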

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

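    /// Sends the completion request to the Zed LLM service, refreshing the
    /// LLM token and retrying when the server reports an expired token or a
    /// 5xx error.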
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

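        // 5xx responses are retried with exponential backoff (1s, 2s, 4s with
        // MAX_RETRIES = 3); an expired token is refreshed and retried immediately.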
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    RequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = model.id().to_string();
                let generate_content_request = into_google(request, model_id.clone());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_builder = http_client::Request::builder().method(Method::POST);
                    let request_builder =
                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
                            request_builder.uri(completions_url)
                        } else {
                            request_builder.uri(
                                http_client
                                    .build_zed_llm_url("/count_tokens", &[])?
                                    .as_ref(),
                            )
                        };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = request_builder
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

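/// Flattens a stream of `CloudCompletionEvent`s into completion events: status
/// messages pass through as `StatusUpdate`, provider events are expanded via
/// `map_callback`, and errors are wrapped as `LanguageModelCompletionError::Other`.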
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn usage_updated_event<T>(
    usage: Option<RequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

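/// Reads the streaming response body line by line; each line is a standalone
/// JSON document. When the server advertises status-message support, lines use
/// the externally tagged `CloudCompletionEvent` shape, e.g. `{"status": ...}`
/// or `{"event": ...}` (illustrative); otherwise each line is a bare provider
/// event deserialized directly as `T`.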
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedProFeatureFlag>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}
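
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of how `AvailableModel` and `ModelMode` deserialize from
    // settings JSON; the model name and token budget below are illustrative values.
    #[test]
    fn deserialize_available_model_with_thinking_mode() {
        let json = r#"{
            "provider": "anthropic",
            "name": "claude-3-7-sonnet",
            "max_tokens": 200000,
            "mode": { "type": "thinking", "budget_tokens": 4096 }
        }"#;
        let model: AvailableModel = serde_json::from_str(json).unwrap();
        assert_eq!(model.provider, AvailableProvider::Anthropic);
        assert_eq!(model.max_tokens, 200000);
        assert_eq!(
            model.mode,
            Some(ModelMode::Thinking {
                budget_tokens: Some(4096)
            })
        );
        // The settings-level mode converts into the Anthropic-specific mode.
        assert!(matches!(
            AnthropicModelMode::from(model.mode.clone().unwrap()),
            AnthropicModelMode::Thinking {
                budget_tokens: Some(4096)
            }
        ));
    }
}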