cloud.rs

use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
    CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

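// The env var, read at compile time via `option_env!`, must hold a JSON array
// of `AvailableModel` entries when set. An illustrative (hypothetical) value:
//
//     ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON='[{"provider": "anthropic",
//         "name": "claude-3-5-sonnet-20240620", "max_tokens": 200000}]'
//
// Malformed JSON panics on first access via the `unwrap` below.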
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

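/// The inference mode for a model. In settings JSON this is tagged by `type`
/// (per the serde attributes below); an illustrative value for a
/// thinking-capable model would be `{"type": "thinking", "budget_tokens": 4096}`,
/// where the budget shown is hypothetical.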
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

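/// Result of a successful call to the completions endpoint.
/// `includes_queue_events` mirrors the `x-zed-server-supports-queueing`
/// response header and tells `response_lines` below whether each line of the
/// body is a tagged `CloudCompletionEvent` or a bare provider event.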
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    includes_queue_events: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

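    /// Sends the completion request, refreshing the LLM token once the server
    /// reports it expired, mapping plan- and spend-limit responses to typed
    /// errors, and retrying 5xx responses with exponential backoff.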
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);
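        // With MAX_RETRIES = 3 and the delay doubling after each 5xx response,
        // the schedule works out to waits of 1s, 2s, and 4s before giving up.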

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header("x-zed-client-supports-queueing", "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_queue_events = response
                    .headers()
                    .get("x-zed-server-supports-queueing")
                    .is_some();
                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();
                let usage = RequestUsage::from_headers(response.headers()).ok();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_queue_events,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model completion failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let request = into_google(request, model.id().into());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_builder = http_client::Request::builder().method(Method::POST);
                    let request_builder =
                        if let Ok(count_tokens_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
                            request_builder.uri(count_tokens_url)
                        } else {
                            request_builder.uri(
                                http_client
                                    .build_zed_llm_url("/count_tokens", &[])?
                                    .as_ref(),
                            )
                        };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model.id().into(),
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            contents: request.contents,
                        })?,
                    };
                    let request = request_builder
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        self.stream_completion_with_usage(request, cx)
            .map(|result| result.map(|(stream, _)| stream))
            .boxed()
    }

    fn stream_completion_with_usage(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<(
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            Option<RequestUsage>,
        )>,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_queue_events,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok((
                        map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_queue_events)
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            move |event| mapper.map_event(event),
                        ),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_queue_events,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok((
                        map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_queue_events)
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            move |event| mapper.map_event(event),
                        ),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_queue_events,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok((
                        map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_queue_events)
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            move |event| mapper.map_event(event),
                        ),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
        }
    }
}

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    System(CompletionRequestStatus),
    Event(T),
}

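// Adapts a raw cloud event stream for consumers: server `System` events become
// `QueueUpdate`s, errors pass through as completion errors, and provider
// events are expanded by `map_callback` (e.g. one of the event mappers above).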
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::System(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

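// When the server set the tool-use-limit header, this yields a single
// synthetic `System(ToolUseLimitReached)` event to chain onto the end of the
// completion stream; otherwise it yields nothing.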
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::System(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_queue_events: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_queue_events {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedProFeatureFlag>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}