cloud.rs

   1use anthropic::{AnthropicModelMode, parse_prompt_too_long};
   2use anyhow::{Context as _, Result, anyhow};
   3use client::{Client, UserStore, zed_urls};
   4use futures::{
   5    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
   6};
   7use gpui::{
   8    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
   9};
  10use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
  11use language_model::{
  12    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  13    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
  14    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  15    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
  16    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
  17    ZED_CLOUD_PROVIDER_ID,
  18};
  19use language_model::{
  20    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
  21    RefreshLlmTokenListener,
  22};
  23use proto::Plan;
  24use release_channel::AppVersion;
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize, de::DeserializeOwned};
  27use settings::SettingsStore;
  28use smol::Timer;
  29use smol::io::{AsyncReadExt, BufReader};
  30use std::pin::Pin;
  31use std::str::FromStr as _;
  32use std::sync::Arc;
  33use std::time::Duration;
  34use thiserror::Error;
  35use ui::{TintColor, prelude::*};
  36use util::{ResultExt as _, maybe};
  37use zed_llm_client::{
  38    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
  39    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
  40    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
  41    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
  42    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  43};
  44
  45use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
  46use crate::provider::google::{GoogleEventMapper, into_google};
  47use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
  48
  49pub const PROVIDER_NAME: &str = "Zed";
  50
  51#[derive(Default, Clone, Debug, PartialEq)]
  52pub struct ZedDotDevSettings {
  53    pub available_models: Vec<AvailableModel>,
  54}
  55
  56#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  57#[serde(rename_all = "lowercase")]
  58pub enum AvailableProvider {
  59    Anthropic,
  60    OpenAi,
  61    Google,
  62}
  63
  64#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  65pub struct AvailableModel {
  66    /// The provider of the language model.
  67    pub provider: AvailableProvider,
  68    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
  69    pub name: String,
  70    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
  71    pub display_name: Option<String>,
  72    /// The size of the context window, indicating the maximum number of tokens the model can process.
  73    pub max_tokens: usize,
  74    /// The maximum number of output tokens allowed by the model.
  75    pub max_output_tokens: Option<u32>,
  76    /// The maximum number of completion tokens allowed by the model (o1-* only)
  77    pub max_completion_tokens: Option<u32>,
  78    /// Override this model with a different Anthropic model for tool calls.
  79    pub tool_override: Option<String>,
  80    /// Indicates whether this custom model supports caching.
  81    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
  82    /// The default temperature to use for this model.
  83    pub default_temperature: Option<f32>,
  84    /// Any extra beta headers to provide when using the model.
  85    #[serde(default)]
  86    pub extra_beta_headers: Vec<String>,
  87    /// The model's mode (e.g. thinking)
  88    pub mode: Option<ModelMode>,
  89}
  90
  91#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  92#[serde(tag = "type", rename_all = "lowercase")]
  93pub enum ModelMode {
  94    #[default]
  95    Default,
  96    Thinking {
  97        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
  98        budget_tokens: Option<u32>,
  99    },
 100}
 101
 102impl From<ModelMode> for AnthropicModelMode {
 103    fn from(value: ModelMode) -> Self {
 104        match value {
 105            ModelMode::Default => AnthropicModelMode::Default,
 106            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
 107        }
 108    }
 109}
 110
 111pub struct CloudLanguageModelProvider {
 112    client: Arc<Client>,
 113    state: gpui::Entity<State>,
 114    _maintain_client_status: Task<()>,
 115}
 116
 117pub struct State {
 118    client: Arc<Client>,
 119    llm_api_token: LlmApiToken,
 120    user_store: Entity<UserStore>,
 121    status: client::Status,
 122    accept_terms: Option<Task<Result<()>>>,
 123    models: Vec<Arc<zed_llm_client::LanguageModel>>,
 124    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
 125    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
 126    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
 127    _fetch_models_task: Task<()>,
 128    _settings_subscription: Subscription,
 129    _llm_token_subscription: Subscription,
 130}
 131
 132impl State {
 133    fn new(
 134        client: Arc<Client>,
 135        user_store: Entity<UserStore>,
 136        status: client::Status,
 137        cx: &mut Context<Self>,
 138    ) -> Self {
 139        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 140
 141        Self {
 142            client: client.clone(),
 143            llm_api_token: LlmApiToken::default(),
 144            user_store,
 145            status,
 146            accept_terms: None,
 147            models: Vec::new(),
 148            default_model: None,
 149            default_fast_model: None,
 150            recommended_models: Vec::new(),
 151            _fetch_models_task: cx.spawn(async move |this, cx| {
 152                maybe!(async move {
 153                    let (client, llm_api_token) = this
 154                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
 155
 156                    loop {
 157                        let status = this.read_with(cx, |this, _cx| this.status)?;
 158                        if matches!(status, client::Status::Connected { .. }) {
 159                            break;
 160                        }
 161
 162                        cx.background_executor()
 163                            .timer(Duration::from_millis(100))
 164                            .await;
 165                    }
 166
 167                    let response = Self::fetch_models(client, llm_api_token).await?;
 168                    cx.update(|cx| {
 169                        this.update(cx, |this, cx| {
 170                            let mut models = Vec::new();
 171
 172                            for model in response.models {
 173                                models.push(Arc::new(model.clone()));
 174
 175                                // Right now we represent thinking variants of models as separate models on the client,
 176                                // so we need to insert variants for any model that supports thinking.
 177                                if model.supports_thinking {
 178                                    models.push(Arc::new(zed_llm_client::LanguageModel {
 179                                        id: zed_llm_client::LanguageModelId(
 180                                            format!("{}-thinking", model.id).into(),
 181                                        ),
 182                                        display_name: format!("{} Thinking", model.display_name),
 183                                        ..model
 184                                    }));
 185                                }
 186                            }
 187
 188                            this.default_model = models
 189                                .iter()
 190                                .find(|model| model.id == response.default_model)
 191                                .cloned();
 192                            this.default_fast_model = models
 193                                .iter()
 194                                .find(|model| model.id == response.default_fast_model)
 195                                .cloned();
 196                            this.recommended_models = response
 197                                .recommended_models
 198                                .iter()
 199                                .filter_map(|id| models.iter().find(|model| &model.id == id))
 200                                .cloned()
 201                                .collect();
 202                            this.models = models;
 203                            cx.notify();
 204                        })
 205                    })??;
 206
 207                    anyhow::Ok(())
 208                })
 209                .await
 210                .context("failed to fetch Zed models")
 211                .log_err();
 212            }),
 213            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 214                cx.notify();
 215            }),
 216            _llm_token_subscription: cx.subscribe(
 217                &refresh_llm_token_listener,
 218                |this, _listener, _event, cx| {
 219                    let client = this.client.clone();
 220                    let llm_api_token = this.llm_api_token.clone();
 221                    cx.spawn(async move |_this, _cx| {
 222                        llm_api_token.refresh(&client).await?;
 223                        anyhow::Ok(())
 224                    })
 225                    .detach_and_log_err(cx);
 226                },
 227            ),
 228        }
 229    }
 230
 231    fn is_signed_out(&self) -> bool {
 232        self.status.is_signed_out()
 233    }
 234
 235    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 236        let client = self.client.clone();
 237        cx.spawn(async move |state, cx| {
 238            client
 239                .authenticate_and_connect(true, &cx)
 240                .await
 241                .into_response()?;
 242            state.update(cx, |_, cx| cx.notify())
 243        })
 244    }
 245
 246    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
 247        self.user_store
 248            .read(cx)
 249            .current_user_has_accepted_terms()
 250            .unwrap_or(false)
 251    }
 252
 253    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
 254        let user_store = self.user_store.clone();
 255        self.accept_terms = Some(cx.spawn(async move |this, cx| {
 256            let _ = user_store
 257                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
 258                .await;
 259            this.update(cx, |this, cx| {
 260                this.accept_terms = None;
 261                cx.notify()
 262            })
 263        }));
 264    }
 265
 266    async fn fetch_models(
 267        client: Arc<Client>,
 268        llm_api_token: LlmApiToken,
 269    ) -> Result<ListModelsResponse> {
 270        let http_client = &client.http_client();
 271        let token = llm_api_token.acquire(&client).await?;
 272
 273        let request = http_client::Request::builder()
 274            .method(Method::GET)
 275            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 276            .header("Authorization", format!("Bearer {token}"))
 277            .body(AsyncBody::empty())?;
 278        let mut response = http_client
 279            .send(request)
 280            .await
 281            .context("failed to send list models request")?;
 282
 283        if response.status().is_success() {
 284            let mut body = String::new();
 285            response.body_mut().read_to_string(&mut body).await?;
 286            return Ok(serde_json::from_str(&body)?);
 287        } else {
 288            let mut body = String::new();
 289            response.body_mut().read_to_string(&mut body).await?;
 290            anyhow::bail!(
 291                "error listing models.\nStatus: {:?}\nBody: {body}",
 292                response.status(),
 293            );
 294        }
 295    }
 296}
 297
 298impl CloudLanguageModelProvider {
 299    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 300        let mut status_rx = client.status();
 301        let status = *status_rx.borrow();
 302
 303        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 304
 305        let state_ref = state.downgrade();
 306        let maintain_client_status = cx.spawn(async move |cx| {
 307            while let Some(status) = status_rx.next().await {
 308                if let Some(this) = state_ref.upgrade() {
 309                    _ = this.update(cx, |this, cx| {
 310                        if this.status != status {
 311                            this.status = status;
 312                            cx.notify();
 313                        }
 314                    });
 315                } else {
 316                    break;
 317                }
 318            }
 319        });
 320
 321        Self {
 322            client,
 323            state: state.clone(),
 324            _maintain_client_status: maintain_client_status,
 325        }
 326    }
 327
 328    fn create_language_model(
 329        &self,
 330        model: Arc<zed_llm_client::LanguageModel>,
 331        llm_api_token: LlmApiToken,
 332    ) -> Arc<dyn LanguageModel> {
 333        Arc::new(CloudLanguageModel {
 334            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 335            model,
 336            llm_api_token: llm_api_token.clone(),
 337            client: self.client.clone(),
 338            request_limiter: RateLimiter::new(4),
 339        })
 340    }
 341}
 342
 343impl LanguageModelProviderState for CloudLanguageModelProvider {
 344    type ObservableEntity = State;
 345
 346    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 347        Some(self.state.clone())
 348    }
 349}
 350
 351impl LanguageModelProvider for CloudLanguageModelProvider {
 352    fn id(&self) -> LanguageModelProviderId {
 353        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
 354    }
 355
 356    fn name(&self) -> LanguageModelProviderName {
 357        LanguageModelProviderName(PROVIDER_NAME.into())
 358    }
 359
 360    fn icon(&self) -> IconName {
 361        IconName::AiZed
 362    }
 363
 364    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 365        let default_model = self.state.read(cx).default_model.clone()?;
 366        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 367        Some(self.create_language_model(default_model, llm_api_token))
 368    }
 369
 370    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 371        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 372        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 373        Some(self.create_language_model(default_fast_model, llm_api_token))
 374    }
 375
 376    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 377        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 378        self.state
 379            .read(cx)
 380            .recommended_models
 381            .iter()
 382            .cloned()
 383            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 384            .collect()
 385    }
 386
 387    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 388        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 389        self.state
 390            .read(cx)
 391            .models
 392            .iter()
 393            .cloned()
 394            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 395            .collect()
 396    }
 397
 398    fn is_authenticated(&self, cx: &App) -> bool {
 399        !self.state.read(cx).is_signed_out()
 400    }
 401
 402    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 403        Task::ready(Ok(()))
 404    }
 405
 406    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
 407        cx.new(|_| ConfigurationView {
 408            state: self.state.clone(),
 409        })
 410        .into()
 411    }
 412
 413    fn must_accept_terms(&self, cx: &App) -> bool {
 414        !self.state.read(cx).has_accepted_terms_of_service(cx)
 415    }
 416
 417    fn render_accept_terms(
 418        &self,
 419        view: LanguageModelProviderTosView,
 420        cx: &mut App,
 421    ) -> Option<AnyElement> {
 422        render_accept_terms(self.state.clone(), view, cx)
 423    }
 424
 425    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 426        Task::ready(Ok(()))
 427    }
 428}
 429
 430fn render_accept_terms(
 431    state: Entity<State>,
 432    view_kind: LanguageModelProviderTosView,
 433    cx: &mut App,
 434) -> Option<AnyElement> {
 435    if state.read(cx).has_accepted_terms_of_service(cx) {
 436        return None;
 437    }
 438
 439    let accept_terms_disabled = state.read(cx).accept_terms.is_some();
 440
 441    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
 442    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);
 443
 444    let terms_button = Button::new("terms_of_service", "Terms of Service")
 445        .style(ButtonStyle::Subtle)
 446        .icon(IconName::ArrowUpRight)
 447        .icon_color(Color::Muted)
 448        .icon_size(IconSize::XSmall)
 449        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
 450        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));
 451
 452    let button_container = h_flex().child(
 453        Button::new("accept_terms", "I accept the Terms of Service")
 454            .when(!thread_empty_state, |this| {
 455                this.full_width()
 456                    .style(ButtonStyle::Tinted(TintColor::Accent))
 457                    .icon(IconName::Check)
 458                    .icon_position(IconPosition::Start)
 459                    .icon_size(IconSize::Small)
 460            })
 461            .when(thread_empty_state, |this| {
 462                this.style(ButtonStyle::Tinted(TintColor::Warning))
 463                    .label_size(LabelSize::Small)
 464            })
 465            .disabled(accept_terms_disabled)
 466            .on_click({
 467                let state = state.downgrade();
 468                move |_, _window, cx| {
 469                    state
 470                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
 471                        .ok();
 472                }
 473            }),
 474    );
 475
 476    let form = if thread_empty_state {
 477        h_flex()
 478            .w_full()
 479            .flex_wrap()
 480            .justify_between()
 481            .child(
 482                h_flex()
 483                    .child(
 484                        Label::new("To start using Zed AI, please read and accept the")
 485                            .size(LabelSize::Small),
 486                    )
 487                    .child(terms_button),
 488            )
 489            .child(button_container)
 490    } else {
 491        v_flex()
 492            .w_full()
 493            .gap_2()
 494            .child(
 495                h_flex()
 496                    .flex_wrap()
 497                    .when(thread_fresh_start, |this| this.justify_center())
 498                    .child(Label::new(
 499                        "To start using Zed AI, please read and accept the",
 500                    ))
 501                    .child(terms_button),
 502            )
 503            .child({
 504                match view_kind {
 505                    LanguageModelProviderTosView::PromptEditorPopup => {
 506                        button_container.w_full().justify_end()
 507                    }
 508                    LanguageModelProviderTosView::Configuration => {
 509                        button_container.w_full().justify_start()
 510                    }
 511                    LanguageModelProviderTosView::ThreadFreshStart => {
 512                        button_container.w_full().justify_center()
 513                    }
 514                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
 515                }
 516            })
 517    };
 518
 519    Some(form.into_any())
 520}
 521
 522pub struct CloudLanguageModel {
 523    id: LanguageModelId,
 524    model: Arc<zed_llm_client::LanguageModel>,
 525    llm_api_token: LlmApiToken,
 526    client: Arc<Client>,
 527    request_limiter: RateLimiter,
 528}
 529
 530struct PerformLlmCompletionResponse {
 531    response: Response<AsyncBody>,
 532    usage: Option<RequestUsage>,
 533    tool_use_limit_reached: bool,
 534    includes_status_messages: bool,
 535}
 536
 537impl CloudLanguageModel {
 538    const MAX_RETRIES: usize = 3;
 539
 540    async fn perform_llm_completion(
 541        client: Arc<Client>,
 542        llm_api_token: LlmApiToken,
 543        app_version: Option<SemanticVersion>,
 544        body: CompletionBody,
 545    ) -> Result<PerformLlmCompletionResponse> {
 546        let http_client = &client.http_client();
 547
 548        let mut token = llm_api_token.acquire(&client).await?;
 549        let mut retries_remaining = Self::MAX_RETRIES;
 550        let mut retry_delay = Duration::from_secs(1);
 551
 552        loop {
 553            let request_builder = http_client::Request::builder()
 554                .method(Method::POST)
 555                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
 556            let request_builder = if let Some(app_version) = app_version {
 557                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 558            } else {
 559                request_builder
 560            };
 561
 562            let request = request_builder
 563                .header("Content-Type", "application/json")
 564                .header("Authorization", format!("Bearer {token}"))
 565                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
 566                .body(serde_json::to_string(&body)?.into())?;
 567            let mut response = http_client.send(request).await?;
 568            let status = response.status();
 569            if status.is_success() {
 570                let includes_status_messages = response
 571                    .headers()
 572                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
 573                    .is_some();
 574
 575                let tool_use_limit_reached = response
 576                    .headers()
 577                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
 578                    .is_some();
 579
 580                let usage = if includes_status_messages {
 581                    None
 582                } else {
 583                    RequestUsage::from_headers(response.headers()).ok()
 584                };
 585
 586                return Ok(PerformLlmCompletionResponse {
 587                    response,
 588                    usage,
 589                    includes_status_messages,
 590                    tool_use_limit_reached,
 591                });
 592            } else if response
 593                .headers()
 594                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 595                .is_some()
 596            {
 597                retries_remaining -= 1;
 598                token = llm_api_token.refresh(&client).await?;
 599            } else if status == StatusCode::FORBIDDEN
 600                && response
 601                    .headers()
 602                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 603                    .is_some()
 604            {
 605                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
 606                    .headers()
 607                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
 608                    .and_then(|resource| resource.to_str().ok())
 609                {
 610                    if let Some(plan) = response
 611                        .headers()
 612                        .get(CURRENT_PLAN_HEADER_NAME)
 613                        .and_then(|plan| plan.to_str().ok())
 614                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
 615                    {
 616                        let plan = match plan {
 617                            zed_llm_client::Plan::ZedFree => Plan::Free,
 618                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
 619                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
 620                        };
 621                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
 622                    }
 623                }
 624
 625                anyhow::bail!("Forbidden");
 626            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
 627                // If we encounter an error in the 500 range, retry after a delay.
 628                // We've seen at least these in the wild from API providers:
 629                // * 500 Internal Server Error
 630                // * 502 Bad Gateway
 631                // * 529 Service Overloaded
 632
 633                if retries_remaining == 0 {
 634                    let mut body = String::new();
 635                    response.body_mut().read_to_string(&mut body).await?;
 636                    anyhow::bail!(
 637                        "cloud language model completion failed after {} retries with status {status}: {body}",
 638                        Self::MAX_RETRIES
 639                    );
 640                }
 641
 642                Timer::after(retry_delay).await;
 643
 644                retries_remaining -= 1;
 645                retry_delay *= 2; // If it fails again, wait longer.
 646            } else if status == StatusCode::PAYMENT_REQUIRED {
 647                return Err(anyhow!(PaymentRequiredError));
 648            } else {
 649                let mut body = String::new();
 650                response.body_mut().read_to_string(&mut body).await?;
 651                return Err(anyhow!(ApiError { status, body }));
 652            }
 653        }
 654    }
 655}
 656
 657#[derive(Debug, Error)]
 658#[error("cloud language model request failed with status {status}: {body}")]
 659struct ApiError {
 660    status: StatusCode,
 661    body: String,
 662}
 663
 664impl LanguageModel for CloudLanguageModel {
 665    fn id(&self) -> LanguageModelId {
 666        self.id.clone()
 667    }
 668
 669    fn name(&self) -> LanguageModelName {
 670        LanguageModelName::from(self.model.display_name.clone())
 671    }
 672
 673    fn provider_id(&self) -> LanguageModelProviderId {
 674        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
 675    }
 676
 677    fn provider_name(&self) -> LanguageModelProviderName {
 678        LanguageModelProviderName(PROVIDER_NAME.into())
 679    }
 680
 681    fn supports_tools(&self) -> bool {
 682        self.model.supports_tools
 683    }
 684
 685    fn supports_images(&self) -> bool {
 686        self.model.supports_images
 687    }
 688
 689    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 690        match choice {
 691            LanguageModelToolChoice::Auto
 692            | LanguageModelToolChoice::Any
 693            | LanguageModelToolChoice::None => true,
 694        }
 695    }
 696
 697    fn supports_max_mode(&self) -> bool {
 698        self.model.supports_max_mode
 699    }
 700
 701    fn telemetry_id(&self) -> String {
 702        format!("zed.dev/{}", self.model.id)
 703    }
 704
 705    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 706        match self.model.provider {
 707            zed_llm_client::LanguageModelProvider::Anthropic
 708            | zed_llm_client::LanguageModelProvider::OpenAi => {
 709                LanguageModelToolSchemaFormat::JsonSchema
 710            }
 711            zed_llm_client::LanguageModelProvider::Google => {
 712                LanguageModelToolSchemaFormat::JsonSchemaSubset
 713            }
 714        }
 715    }
 716
 717    fn max_token_count(&self) -> usize {
 718        self.model.max_token_count
 719    }
 720
 721    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 722        match &self.model.provider {
 723            zed_llm_client::LanguageModelProvider::Anthropic => {
 724                Some(LanguageModelCacheConfiguration {
 725                    min_total_token: 2_048,
 726                    should_speculate: true,
 727                    max_cache_anchors: 4,
 728                })
 729            }
 730            zed_llm_client::LanguageModelProvider::OpenAi
 731            | zed_llm_client::LanguageModelProvider::Google => None,
 732        }
 733    }
 734
 735    fn count_tokens(
 736        &self,
 737        request: LanguageModelRequest,
 738        cx: &App,
 739    ) -> BoxFuture<'static, Result<usize>> {
 740        match self.model.provider {
 741            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
 742            zed_llm_client::LanguageModelProvider::OpenAi => {
 743                let model = match open_ai::Model::from_id(&self.model.id.0) {
 744                    Ok(model) => model,
 745                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 746                };
 747                count_open_ai_tokens(request, model, cx)
 748            }
 749            zed_llm_client::LanguageModelProvider::Google => {
 750                let client = self.client.clone();
 751                let llm_api_token = self.llm_api_token.clone();
 752                let model_id = self.model.id.to_string();
 753                let generate_content_request = into_google(request, model_id.clone());
 754                async move {
 755                    let http_client = &client.http_client();
 756                    let token = llm_api_token.acquire(&client).await?;
 757
 758                    let request_body = CountTokensBody {
 759                        provider: zed_llm_client::LanguageModelProvider::Google,
 760                        model: model_id,
 761                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 762                            generate_content_request,
 763                        })?,
 764                    };
 765                    let request = http_client::Request::builder()
 766                        .method(Method::POST)
 767                        .uri(
 768                            http_client
 769                                .build_zed_llm_url("/count_tokens", &[])?
 770                                .as_ref(),
 771                        )
 772                        .header("Content-Type", "application/json")
 773                        .header("Authorization", format!("Bearer {token}"))
 774                        .body(serde_json::to_string(&request_body)?.into())?;
 775                    let mut response = http_client.send(request).await?;
 776                    let status = response.status();
 777                    let mut response_body = String::new();
 778                    response
 779                        .body_mut()
 780                        .read_to_string(&mut response_body)
 781                        .await?;
 782
 783                    if status.is_success() {
 784                        let response_body: CountTokensResponse =
 785                            serde_json::from_str(&response_body)?;
 786
 787                        Ok(response_body.tokens)
 788                    } else {
 789                        Err(anyhow!(ApiError {
 790                            status,
 791                            body: response_body
 792                        }))
 793                    }
 794                }
 795                .boxed()
 796            }
 797        }
 798    }
 799
 800    fn stream_completion(
 801        &self,
 802        request: LanguageModelRequest,
 803        cx: &AsyncApp,
 804    ) -> BoxFuture<
 805        'static,
 806        Result<
 807            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 808        >,
 809    > {
 810        let thread_id = request.thread_id.clone();
 811        let prompt_id = request.prompt_id.clone();
 812        let intent = request.intent;
 813        let mode = request.mode;
 814        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
 815        match self.model.provider {
 816            zed_llm_client::LanguageModelProvider::Anthropic => {
 817                let request = into_anthropic(
 818                    request,
 819                    self.model.id.to_string(),
 820                    1.0,
 821                    self.model.max_output_tokens as u32,
 822                    if self.model.id.0.ends_with("-thinking") {
 823                        AnthropicModelMode::Thinking {
 824                            budget_tokens: Some(4_096),
 825                        }
 826                    } else {
 827                        AnthropicModelMode::Default
 828                    },
 829                );
 830                let client = self.client.clone();
 831                let llm_api_token = self.llm_api_token.clone();
 832                let future = self.request_limiter.stream(async move {
 833                    let PerformLlmCompletionResponse {
 834                        response,
 835                        usage,
 836                        includes_status_messages,
 837                        tool_use_limit_reached,
 838                    } = Self::perform_llm_completion(
 839                        client.clone(),
 840                        llm_api_token,
 841                        app_version,
 842                        CompletionBody {
 843                            thread_id,
 844                            prompt_id,
 845                            intent,
 846                            mode,
 847                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
 848                            model: request.model.clone(),
 849                            provider_request: serde_json::to_value(&request)?,
 850                        },
 851                    )
 852                    .await
 853                    .map_err(|err| match err.downcast::<ApiError>() {
 854                        Ok(api_err) => {
 855                            if api_err.status == StatusCode::BAD_REQUEST {
 856                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
 857                                    return anyhow!(
 858                                        LanguageModelKnownError::ContextWindowLimitExceeded {
 859                                            tokens
 860                                        }
 861                                    );
 862                                }
 863                            }
 864                            anyhow!(api_err)
 865                        }
 866                        Err(err) => anyhow!(err),
 867                    })?;
 868
 869                    let mut mapper = AnthropicEventMapper::new();
 870                    Ok(map_cloud_completion_events(
 871                        Box::pin(
 872                            response_lines(response, includes_status_messages)
 873                                .chain(usage_updated_event(usage))
 874                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 875                        ),
 876                        move |event| mapper.map_event(event),
 877                    ))
 878                });
 879                async move { Ok(future.await?.boxed()) }.boxed()
 880            }
 881            zed_llm_client::LanguageModelProvider::OpenAi => {
 882                let client = self.client.clone();
 883                let model = match open_ai::Model::from_id(&self.model.id.0) {
 884                    Ok(model) => model,
 885                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 886                };
 887                let request = into_open_ai(request, &model, None);
 888                let llm_api_token = self.llm_api_token.clone();
 889                let future = self.request_limiter.stream(async move {
 890                    let PerformLlmCompletionResponse {
 891                        response,
 892                        usage,
 893                        includes_status_messages,
 894                        tool_use_limit_reached,
 895                    } = Self::perform_llm_completion(
 896                        client.clone(),
 897                        llm_api_token,
 898                        app_version,
 899                        CompletionBody {
 900                            thread_id,
 901                            prompt_id,
 902                            intent,
 903                            mode,
 904                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
 905                            model: request.model.clone(),
 906                            provider_request: serde_json::to_value(&request)?,
 907                        },
 908                    )
 909                    .await?;
 910
 911                    let mut mapper = OpenAiEventMapper::new();
 912                    Ok(map_cloud_completion_events(
 913                        Box::pin(
 914                            response_lines(response, includes_status_messages)
 915                                .chain(usage_updated_event(usage))
 916                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 917                        ),
 918                        move |event| mapper.map_event(event),
 919                    ))
 920                });
 921                async move { Ok(future.await?.boxed()) }.boxed()
 922            }
 923            zed_llm_client::LanguageModelProvider::Google => {
 924                let client = self.client.clone();
 925                let request = into_google(request, self.model.id.to_string());
 926                let llm_api_token = self.llm_api_token.clone();
 927                let future = self.request_limiter.stream(async move {
 928                    let PerformLlmCompletionResponse {
 929                        response,
 930                        usage,
 931                        includes_status_messages,
 932                        tool_use_limit_reached,
 933                    } = Self::perform_llm_completion(
 934                        client.clone(),
 935                        llm_api_token,
 936                        app_version,
 937                        CompletionBody {
 938                            thread_id,
 939                            prompt_id,
 940                            intent,
 941                            mode,
 942                            provider: zed_llm_client::LanguageModelProvider::Google,
 943                            model: request.model.model_id.clone(),
 944                            provider_request: serde_json::to_value(&request)?,
 945                        },
 946                    )
 947                    .await?;
 948
 949                    let mut mapper = GoogleEventMapper::new();
 950                    Ok(map_cloud_completion_events(
 951                        Box::pin(
 952                            response_lines(response, includes_status_messages)
 953                                .chain(usage_updated_event(usage))
 954                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 955                        ),
 956                        move |event| mapper.map_event(event),
 957                    ))
 958                });
 959                async move { Ok(future.await?.boxed()) }.boxed()
 960            }
 961        }
 962    }
 963}
 964
 965#[derive(Serialize, Deserialize)]
 966#[serde(rename_all = "snake_case")]
 967pub enum CloudCompletionEvent<T> {
 968    Status(CompletionRequestStatus),
 969    Event(T),
 970}
 971
 972fn map_cloud_completion_events<T, F>(
 973    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
 974    mut map_callback: F,
 975) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 976where
 977    T: DeserializeOwned + 'static,
 978    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 979        + Send
 980        + 'static,
 981{
 982    stream
 983        .flat_map(move |event| {
 984            futures::stream::iter(match event {
 985                Err(error) => {
 986                    vec![Err(LanguageModelCompletionError::Other(error))]
 987                }
 988                Ok(CloudCompletionEvent::Status(event)) => {
 989                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
 990                }
 991                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
 992            })
 993        })
 994        .boxed()
 995}
 996
 997fn usage_updated_event<T>(
 998    usage: Option<RequestUsage>,
 999) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1000    futures::stream::iter(usage.map(|usage| {
1001        Ok(CloudCompletionEvent::Status(
1002            CompletionRequestStatus::UsageUpdated {
1003                amount: usage.amount as usize,
1004                limit: usage.limit,
1005            },
1006        ))
1007    }))
1008}
1009
1010fn tool_use_limit_reached_event<T>(
1011    tool_use_limit_reached: bool,
1012) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1013    futures::stream::iter(tool_use_limit_reached.then(|| {
1014        Ok(CloudCompletionEvent::Status(
1015            CompletionRequestStatus::ToolUseLimitReached,
1016        ))
1017    }))
1018}
1019
1020fn response_lines<T: DeserializeOwned>(
1021    response: Response<AsyncBody>,
1022    includes_status_messages: bool,
1023) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1024    futures::stream::try_unfold(
1025        (String::new(), BufReader::new(response.into_body())),
1026        move |(mut line, mut body)| async move {
1027            match body.read_line(&mut line).await {
1028                Ok(0) => Ok(None),
1029                Ok(_) => {
1030                    let event = if includes_status_messages {
1031                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
1032                    } else {
1033                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1034                    };
1035
1036                    line.clear();
1037                    Ok(Some((event, (line, body))))
1038                }
1039                Err(e) => Err(e.into()),
1040            }
1041        },
1042    )
1043}
1044
1045struct ConfigurationView {
1046    state: gpui::Entity<State>,
1047}
1048
1049impl ConfigurationView {
1050    fn authenticate(&mut self, cx: &mut Context<Self>) {
1051        self.state.update(cx, |state, cx| {
1052            state.authenticate(cx).detach_and_log_err(cx);
1053        });
1054        cx.notify();
1055    }
1056}
1057
1058impl Render for ConfigurationView {
1059    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1060        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
1061
1062        let is_connected = !self.state.read(cx).is_signed_out();
1063        let user_store = self.state.read(cx).user_store.read(cx);
1064        let plan = user_store.current_plan();
1065        let subscription_period = user_store.subscription_period();
1066        let eligible_for_trial = user_store.trial_started_at().is_none();
1067        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);
1068
1069        let is_pro = plan == Some(proto::Plan::ZedPro);
1070        let subscription_text = match (plan, subscription_period) {
1071            (Some(proto::Plan::ZedPro), Some(_)) => {
1072                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
1073            }
1074            (Some(proto::Plan::ZedProTrial), Some(_)) => {
1075                "You have access to Zed's hosted LLMs through your Zed Pro trial."
1076            }
1077            (Some(proto::Plan::Free), Some(_)) => {
1078                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
1079            }
1080            _ => {
1081                if eligible_for_trial {
1082                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
1083                } else {
1084                    "Subscribe for access to Zed's hosted LLMs."
1085                }
1086            }
1087        };
1088        let manage_subscription_buttons = if is_pro {
1089            h_flex().child(
1090                Button::new("manage_settings", "Manage Subscription")
1091                    .style(ButtonStyle::Tinted(TintColor::Accent))
1092                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
1093            )
1094        } else {
1095            h_flex()
1096                .gap_2()
1097                .child(
1098                    Button::new("learn_more", "Learn more")
1099                        .style(ButtonStyle::Subtle)
1100                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
1101                )
1102                .child(
1103                    Button::new("upgrade", "Upgrade")
1104                        .style(ButtonStyle::Subtle)
1105                        .color(Color::Accent)
1106                        .on_click(
1107                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1108                        ),
1109                )
1110        };
1111
1112        if is_connected {
1113            v_flex()
1114                .gap_3()
1115                .w_full()
1116                .children(render_accept_terms(
1117                    self.state.clone(),
1118                    LanguageModelProviderTosView::Configuration,
1119                    cx,
1120                ))
1121                .when(has_accepted_terms, |this| {
1122                    this.child(subscription_text)
1123                        .child(manage_subscription_buttons)
1124                })
1125        } else {
1126            v_flex()
1127                .gap_2()
1128                .child(Label::new("Use Zed AI to access hosted language models."))
1129                .child(
1130                    Button::new("sign_in", "Sign In")
1131                        .icon_color(Color::Muted)
1132                        .icon(IconName::Github)
1133                        .icon_position(IconPosition::Start)
1134                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
1135                )
1136        }
1137    }
1138}