cloud.rs

use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

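/// Returns the extra models baked in at compile time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable.
///
/// The JSON is parsed lazily on first access and must deserialize into a
/// `Vec<AvailableModel>`; malformed JSON will panic via the `unwrap` below.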
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

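// For reference, a minimal `AvailableModel` entry as it would appear in
// settings (or in `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`). Given the
// serde attributes below, only `provider`, `name`, and `max_tokens` are
// required; the `Option` fields and `extra_beta_headers` may be omitted.
// The values here are illustrative, not a real deployment:
//
//   [{ "provider": "anthropic", "name": "claude-3-5-sonnet-20240620", "max_tokens": 200000 }]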
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

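/// The inference mode for a model.
///
/// Serialized with an internal `type` tag, so in JSON this is written as,
/// e.g., `{"type": "default"}` or `{"type": "thinking", "budget_tokens": 4096}`
/// (the budget value is illustrative).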
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

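/// Shared provider state: the connected client, the cached LLM API token,
/// the user's sign-in status, and the in-flight terms-of-service acceptance
/// task, if any. The subscriptions keep the token fresh and re-render when
/// settings change.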
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Merge in models from settings (and the closed-beta list); entries
        // with the same id override the built-ins inserted above.
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

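/// Renders the terms-of-service prompt for the given view, or `None` once
/// the user has already accepted. The layout adapts to `view_kind`.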
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

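/// A single model hosted behind Zed's LLM service. Requests are
/// authenticated with `llm_api_token` and throttled by `request_limiter`.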
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

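/// The raw streaming response from the completions endpoint, plus metadata
/// lifted from its headers.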
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

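    /// POSTs `body` to the completions endpoint (or to `ZED_COMPLETIONS_URL`,
    /// when that environment variable is set) and returns the streaming
    /// response.
    ///
    /// An expired-token response triggers a token refresh and another
    /// attempt; 5xx responses are retried up to `MAX_RETRIES` times with
    /// exponential backoff (1s, 2s, 4s). Spend, plan-limit, and payment
    /// failures are mapped to typed errors for the UI.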
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    RequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

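    // Anthropic and OpenAI counts are delegated to the shared helpers;
    // Google models call Zed's `/count_tokens` endpoint (overridable via
    // the `ZED_COUNT_TOKENS_URL` environment variable).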
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let request = into_google(request, model.id().into());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_builder = http_client::Request::builder().method(Method::POST);
                    let request_builder =
                        if let Ok(count_tokens_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
                            request_builder.uri(count_tokens_url)
                        } else {
                            request_builder.uri(
                                http_client
                                    .build_zed_llm_url("/count_tokens", &[])?
                                    .as_ref(),
                            )
                        };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model.id().into(),
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            contents: request.contents,
                        })?,
                    };
                    let request = request_builder
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

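    // All three provider arms follow the same shape: convert the request,
    // POST it via `perform_llm_completion`, then map the newline-delimited
    // JSON response lines (plus trailing usage/tool-limit status events)
    // into `LanguageModelCompletionEvent`s. Only the Anthropic arm
    // additionally remaps BAD_REQUEST "prompt too long" errors to a typed
    // context-window error.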
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

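/// One line of the completion stream: either a status message from Zed's
/// server or a provider event of type `T`. With the derived serde
/// representation, lines look like `{"status": ...}` or `{"event": ...}`.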
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

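/// Flattens a stream of `CloudCompletionEvent`s into UI-facing completion
/// events: status messages pass through as `StatusUpdate`s, provider events
/// go through `map_callback`, and stream errors become
/// `LanguageModelCompletionError::Other`.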
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

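/// Emits a single `UsageUpdated` status event when usage was reported via
/// response headers (i.e. when the server did not stream status messages).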
fn usage_updated_event<T>(
    usage: Option<RequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

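/// Emits a single `ToolUseLimitReached` status event when the corresponding
/// response header was present.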
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

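/// Parses the response body as newline-delimited JSON. When the server
/// streams status messages, each line is a tagged `CloudCompletionEvent`;
/// otherwise each line is a bare provider event and is wrapped in
/// `CloudCompletionEvent::Event`.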
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

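/// The provider's configuration panel: sign-in, terms-of-service, and
/// subscription state.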
struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedProFeatureFlag>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}