cloud.rs

   1use anthropic::AnthropicModelMode;
   2use anyhow::{Context as _, Result, anyhow};
   3use chrono::{DateTime, Utc};
   4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
   5use futures::{
   6    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
   7};
   8use google_ai::GoogleModelMode;
   9use gpui::{
  10    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
  11};
  12use http_client::http::{HeaderMap, HeaderValue};
  13use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
  14use language_model::{
  15    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  16    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  17    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  18    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
  19    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
  20    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  21};
  22use proto::Plan;
  23use release_channel::AppVersion;
  24use schemars::JsonSchema;
  25use serde::{Deserialize, Serialize, de::DeserializeOwned};
  26use settings::SettingsStore;
  27use smol::io::{AsyncReadExt, BufReader};
  28use std::pin::Pin;
  29use std::str::FromStr as _;
  30use std::sync::Arc;
  31use std::time::Duration;
  32use thiserror::Error;
  33use ui::{TintColor, prelude::*};
  34use util::{ResultExt as _, maybe};
  35use zed_llm_client::{
  36    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
  37    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
  38    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
  39    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
  40    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  41};
  42
  43use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
  44use crate::provider::google::{GoogleEventMapper, into_google};
  45use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
  46
  47const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
  48const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  49
/// User settings for the Zed-hosted (`zed.dev`) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Additional user-configured models to expose through this provider.
    pub available_models: Vec<AvailableModel>,
}
  54
/// The upstream provider that actually serves a user-configured [`AvailableModel`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
  62
/// A user-configured model made available through the Zed cloud provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
  89
/// Whether a model runs normally or in an extended-reasoning ("thinking") mode.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
 100
 101impl From<ModelMode> for AnthropicModelMode {
 102    fn from(value: ModelMode) -> Self {
 103        match value {
 104            ModelMode::Default => AnthropicModelMode::Default,
 105            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
 106        }
 107    }
 108}
 109
/// Language model provider backed by Zed's cloud LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Owned task that mirrors client connection status into `state`;
    // dropping the provider cancels it.
    _maintain_client_status: Task<()>,
}
 115
/// Observable provider state: connection status, LLM API token, and the
/// model list fetched from the server.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // `Some` while a terms-of-service acceptance request is in flight.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    // All available models, including client-synthesized "-thinking" variants.
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    // Background task that fetches models once the client connects.
    _fetch_models_task: Task<()>,
    // Re-notifies observers when global settings change.
    _settings_subscription: Subscription,
    // Re-fetches models whenever the LLM token is refreshed.
    _llm_token_subscription: Subscription,
}
 130
impl State {
    /// Creates the provider state and spawns the background tasks that keep
    /// it up to date (model fetching and token-refresh handling).
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Poll (100ms intervals) until the client reports a
                    // connection before hitting the /models endpoint.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    this.update(cx, |this, cx| {
                        this.update_models(response, cx);
                    })
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // A refreshed LLM token may change which models are available, so
            // refresh the token and re-fetch the model list.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Starts an interactive sign-in via the Zed client; notifies observers
    /// on completion.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service.
    /// Defaults to `false` when there is no signed-in user.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Kicks off terms-of-service acceptance; the stored task doubles as an
    /// "in progress" flag for the UI (see `accept_terms_of_service_task`).
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    /// Replaces the model list from a fresh server response and recomputes
    /// the default/fast/recommended selections.
    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(zed_llm_client::LanguageModel {
                    id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        // Resolve the server-designated ids against the expanded model list.
        self.default_model = models
            .iter()
            .find(|model| model.id == response.default_model)
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| model.id == response.default_fast_model)
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    /// Fetches the model list from the Zed LLM `/models` endpoint using a
    /// bearer token acquired from `llm_api_token`.
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            // Include the response body in the error to aid debugging.
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}
 297
impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a task that mirrors the client's
    /// connection status into the shared [`State`] entity.
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        // Hold only a weak handle inside the task so the task itself does not
        // keep the state entity alive; exit once the entity is dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    /// Wraps server-provided model metadata in a [`CloudLanguageModel`]
    /// handle sharing this provider's client and token.
    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            // Allow up to 4 concurrent requests per model handle.
            request_limiter: RateLimiter::new(4),
        })
    }
}
 342
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes [`State`] so the UI can observe provider changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 350
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    /// The server-designated default model, once the model list has loaded.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    /// The server-designated fast default model, once the list has loaded.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    /// Authenticated means: signed in AND terms of service accepted.
    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
    }

    // No-op: completes immediately. NOTE(review): interactive sign-in appears
    // to go through `State::authenticate` / the Zed client instead — confirm.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    /// Renders the terms-of-service prompt, or `None` once accepted.
    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        let state = self.state.read(cx);
        if state.has_accepted_terms_of_service(cx) {
            return None;
        }
        Some(
            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
                let state = self.state.clone();
                move |_window, cx| {
                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
                }
            })
            .into_any_element(),
        )
    }

    // No locally-stored credentials to clear; completes immediately.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
 440
/// Builds the "accept the Terms of Service" UI, laid out differently per
/// host view (`view_kind`). The accept button is disabled while acceptance
/// is already in progress.
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    // Link out to the hosted terms-of-service page.
    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // The accept button is styled compactly in the empty-state variant and as
    // a full-width accent button everywhere else.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        // Compact single-row layout for the empty thread state.
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        // Stacked layout; button alignment depends on the host view.
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    // Unreached in practice here (handled by the `if` branch
                    // above); emit an empty element to satisfy the match.
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
 517
/// A language model served through Zed's cloud LLM proxy.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    // Server-provided metadata: capabilities, limits, upstream provider.
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Caps concurrent in-flight requests for this model handle.
    request_limiter: RateLimiter,
}
 525
/// A successful `/completions` response plus metadata extracted from its
/// headers before the body stream is consumed.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages (usage arrives in-band instead).
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}
 532
impl CloudLanguageModel {
    /// Sends a completion request to the Zed LLM `/completions` endpoint.
    ///
    /// Retries exactly once with a freshly-refreshed token if the server
    /// signals an expired LLM token; maps subscription-limit (403) and
    /// billing (402) responses to typed errors, and everything else to
    /// [`ApiError`].
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            // The version header is optional: only send it when known.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives
                // in-band rather than via headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Expired-token errors get a single retry with a fresh token.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // NOTE: `MODEL_REQUESTS_RESOURCE_HEADER_VALUE` is a const, so
                // this `if let` pattern-matches against that constant value —
                // it is not a fresh binding.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        // Translate the wire plan into the proto plan carried
                        // by the typed error.
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, headers, and body verbatim.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 638
/// A non-success HTTP response from the cloud LLM service, preserving the
/// status, body, and response headers for downstream error mapping.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 646
 647impl From<ApiError> for LanguageModelCompletionError {
 648    fn from(error: ApiError) -> Self {
 649        let retry_after = None;
 650        LanguageModelCompletionError::from_http_status(
 651            PROVIDER_NAME,
 652            error.status,
 653            error.body,
 654            retry_after,
 655        )
 656    }
 657}
 658
 659impl LanguageModel for CloudLanguageModel {
    /// This model's unique identifier.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
 663
    /// The human-readable display name from the server metadata.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }
 667
    /// Always the Zed cloud provider id.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }
 671
    /// Always the Zed cloud provider name.
    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }
 675
    /// The id of the upstream provider (Anthropic/OpenAI/Google) that
    /// actually serves this model behind the Zed proxy.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }
 684
    /// The display name of the upstream provider behind the Zed proxy.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }
 693
    /// Whether this model supports tool use, per server metadata.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }
 697
    /// Whether this model accepts image input, per server metadata.
    fn supports_images(&self) -> bool {
        self.model.supports_images
    }
 701
 702    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 703        match choice {
 704            LanguageModelToolChoice::Auto
 705            | LanguageModelToolChoice::Any
 706            | LanguageModelToolChoice::None => true,
 707        }
 708    }
 709
    /// Whether burn mode (the server calls it "max mode") is available.
    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }
 713
    /// Telemetry identifier, namespaced under "zed.dev/".
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
 717
 718    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 719        match self.model.provider {
 720            zed_llm_client::LanguageModelProvider::Anthropic
 721            | zed_llm_client::LanguageModelProvider::OpenAi => {
 722                LanguageModelToolSchemaFormat::JsonSchema
 723            }
 724            zed_llm_client::LanguageModelProvider::Google => {
 725                LanguageModelToolSchemaFormat::JsonSchemaSubset
 726            }
 727        }
 728    }
 729
    /// Context-window size in tokens, per server metadata.
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }
 733
 734    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
 735        self.model
 736            .max_token_count_in_max_mode
 737            .filter(|_| self.model.supports_max_mode)
 738            .map(|max_token_count| max_token_count as u64)
 739    }
 740
 741    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 742        match &self.model.provider {
 743            zed_llm_client::LanguageModelProvider::Anthropic => {
 744                Some(LanguageModelCacheConfiguration {
 745                    min_total_token: 2_048,
 746                    should_speculate: true,
 747                    max_cache_anchors: 4,
 748                })
 749            }
 750            zed_llm_client::LanguageModelProvider::OpenAi
 751            | zed_llm_client::LanguageModelProvider::Google => None,
 752        }
 753    }
 754
    /// Counts the tokens in `request` for this model.
    ///
    /// Anthropic and OpenAI are counted client-side; Google requires a round
    /// trip to the Zed LLM `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                // Local counting needs a concrete OpenAI model; fail the
                // future if the server-provided id is unknown here.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                // The count request always uses the default (non-thinking) mode.
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        // Preserve status/headers/body for error mapping.
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 822
 823    fn stream_completion(
 824        &self,
 825        request: LanguageModelRequest,
 826        cx: &AsyncApp,
 827    ) -> BoxFuture<
 828        'static,
 829        Result<
 830            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 831            LanguageModelCompletionError,
 832        >,
 833    > {
 834        let thread_id = request.thread_id.clone();
 835        let prompt_id = request.prompt_id.clone();
 836        let intent = request.intent;
 837        let mode = request.mode;
 838        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
 839        let thinking_allowed = request.thinking_allowed;
 840        match self.model.provider {
 841            zed_llm_client::LanguageModelProvider::Anthropic => {
 842                let request = into_anthropic(
 843                    request,
 844                    self.model.id.to_string(),
 845                    1.0,
 846                    self.model.max_output_tokens as u64,
 847                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
 848                        AnthropicModelMode::Thinking {
 849                            budget_tokens: Some(4_096),
 850                        }
 851                    } else {
 852                        AnthropicModelMode::Default
 853                    },
 854                );
 855                let client = self.client.clone();
 856                let llm_api_token = self.llm_api_token.clone();
 857                let future = self.request_limiter.stream(async move {
 858                    let PerformLlmCompletionResponse {
 859                        response,
 860                        usage,
 861                        includes_status_messages,
 862                        tool_use_limit_reached,
 863                    } = Self::perform_llm_completion(
 864                        client.clone(),
 865                        llm_api_token,
 866                        app_version,
 867                        CompletionBody {
 868                            thread_id,
 869                            prompt_id,
 870                            intent,
 871                            mode,
 872                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
 873                            model: request.model.clone(),
 874                            provider_request: serde_json::to_value(&request)
 875                                .map_err(|e| anyhow!(e))?,
 876                        },
 877                    )
 878                    .await
 879                    .map_err(|err| match err.downcast::<ApiError>() {
 880                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 881                        Err(err) => anyhow!(err),
 882                    })?;
 883
 884                    let mut mapper = AnthropicEventMapper::new();
 885                    Ok(map_cloud_completion_events(
 886                        Box::pin(
 887                            response_lines(response, includes_status_messages)
 888                                .chain(usage_updated_event(usage))
 889                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 890                        ),
 891                        move |event| mapper.map_event(event),
 892                    ))
 893                });
 894                async move { Ok(future.await?.boxed()) }.boxed()
 895            }
 896            zed_llm_client::LanguageModelProvider::OpenAi => {
 897                let client = self.client.clone();
 898                let model = match open_ai::Model::from_id(&self.model.id.0) {
 899                    Ok(model) => model,
 900                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
 901                };
 902                let request = into_open_ai(
 903                    request,
 904                    model.id(),
 905                    model.supports_parallel_tool_calls(),
 906                    None,
 907                );
 908                let llm_api_token = self.llm_api_token.clone();
 909                let future = self.request_limiter.stream(async move {
 910                    let PerformLlmCompletionResponse {
 911                        response,
 912                        usage,
 913                        includes_status_messages,
 914                        tool_use_limit_reached,
 915                    } = Self::perform_llm_completion(
 916                        client.clone(),
 917                        llm_api_token,
 918                        app_version,
 919                        CompletionBody {
 920                            thread_id,
 921                            prompt_id,
 922                            intent,
 923                            mode,
 924                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
 925                            model: request.model.clone(),
 926                            provider_request: serde_json::to_value(&request)
 927                                .map_err(|e| anyhow!(e))?,
 928                        },
 929                    )
 930                    .await?;
 931
 932                    let mut mapper = OpenAiEventMapper::new();
 933                    Ok(map_cloud_completion_events(
 934                        Box::pin(
 935                            response_lines(response, includes_status_messages)
 936                                .chain(usage_updated_event(usage))
 937                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 938                        ),
 939                        move |event| mapper.map_event(event),
 940                    ))
 941                });
 942                async move { Ok(future.await?.boxed()) }.boxed()
 943            }
 944            zed_llm_client::LanguageModelProvider::Google => {
 945                let client = self.client.clone();
 946                let request =
 947                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 948                let llm_api_token = self.llm_api_token.clone();
 949                let future = self.request_limiter.stream(async move {
 950                    let PerformLlmCompletionResponse {
 951                        response,
 952                        usage,
 953                        includes_status_messages,
 954                        tool_use_limit_reached,
 955                    } = Self::perform_llm_completion(
 956                        client.clone(),
 957                        llm_api_token,
 958                        app_version,
 959                        CompletionBody {
 960                            thread_id,
 961                            prompt_id,
 962                            intent,
 963                            mode,
 964                            provider: zed_llm_client::LanguageModelProvider::Google,
 965                            model: request.model.model_id.clone(),
 966                            provider_request: serde_json::to_value(&request)
 967                                .map_err(|e| anyhow!(e))?,
 968                        },
 969                    )
 970                    .await?;
 971
 972                    let mut mapper = GoogleEventMapper::new();
 973                    Ok(map_cloud_completion_events(
 974                        Box::pin(
 975                            response_lines(response, includes_status_messages)
 976                                .chain(usage_updated_event(usage))
 977                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 978                        ),
 979                        move |event| mapper.map_event(event),
 980                    ))
 981                });
 982                async move { Ok(future.await?.boxed()) }.boxed()
 983            }
 984        }
 985    }
 986}
 987
/// One deserialized line of the streaming completion response from the Zed
/// LLM proxy.
///
/// When the server includes status messages (see `response_lines`), each line
/// is either an out-of-band status update or a provider-native event `T`;
/// otherwise every line is parsed directly as a bare `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    /// Out-of-band status from the proxy (e.g. usage updates, limits).
    Status(CompletionRequestStatus),
    /// A provider-specific completion event payload.
    Event(T),
}
 994
 995fn map_cloud_completion_events<T, F>(
 996    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
 997    mut map_callback: F,
 998) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 999where
1000    T: DeserializeOwned + 'static,
1001    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1002        + Send
1003        + 'static,
1004{
1005    stream
1006        .flat_map(move |event| {
1007            futures::stream::iter(match event {
1008                Err(error) => {
1009                    vec![Err(LanguageModelCompletionError::from(error))]
1010                }
1011                Ok(CloudCompletionEvent::Status(event)) => {
1012                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
1013                }
1014                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
1015            })
1016        })
1017        .boxed()
1018}
1019
1020fn usage_updated_event<T>(
1021    usage: Option<ModelRequestUsage>,
1022) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1023    futures::stream::iter(usage.map(|usage| {
1024        Ok(CloudCompletionEvent::Status(
1025            CompletionRequestStatus::UsageUpdated {
1026                amount: usage.amount as usize,
1027                limit: usage.limit,
1028            },
1029        ))
1030    }))
1031}
1032
1033fn tool_use_limit_reached_event<T>(
1034    tool_use_limit_reached: bool,
1035) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1036    futures::stream::iter(tool_use_limit_reached.then(|| {
1037        Ok(CloudCompletionEvent::Status(
1038            CompletionRequestStatus::ToolUseLimitReached,
1039        ))
1040    }))
1041}
1042
1043fn response_lines<T: DeserializeOwned>(
1044    response: Response<AsyncBody>,
1045    includes_status_messages: bool,
1046) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1047    futures::stream::try_unfold(
1048        (String::new(), BufReader::new(response.into_body())),
1049        move |(mut line, mut body)| async move {
1050            match body.read_line(&mut line).await {
1051                Ok(0) => Ok(None),
1052                Ok(_) => {
1053                    let event = if includes_status_messages {
1054                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
1055                    } else {
1056                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1057                    };
1058
1059                    line.clear();
1060                    Ok(Some((event, (line, body))))
1061                }
1062                Err(e) => Err(e.into()),
1063            }
1064        },
1065    )
1066}
1067
/// Stateless element rendering the Zed AI configuration panel: sign-in,
/// terms-of-service acceptance, and subscription status/management.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is signed in (not signed out).
    is_connected: bool,
    // Current plan, if any (Free / ZedProTrial / ZedPro).
    plan: Option<proto::Plan>,
    // Start/end of the current subscription period, when subscribed.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // Whether a free-trial offer should be shown to users without a plan.
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    // True while the terms-of-service acceptance request is in flight.
    accept_terms_of_service_in_progress: bool,
    // Invoked when the user accepts the terms of service.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Invoked when the user clicks "Sign In".
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1079
impl RenderOnce for ZedAiConfiguration {
    /// Renders one of three states: signed out (sign-in prompt), signed in
    /// without accepted terms (terms-of-service prompt), or signed in with
    /// accepted terms (subscription status plus management buttons).
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_pro = self.plan == Some(proto::Plan::ZedPro);
        // Describe the current subscription; the fallback arm covers "no
        // plan" and "plan without a subscription period".
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        // Pro users get a single "Manage Subscription" button; everyone else
        // gets "Learn more" plus an upgrade/trial call-to-action.
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
                )
                .child(
                    Button::new(
                        "upgrade",
                        // Offer a trial only to plan-less, trial-eligible users.
                        if self.plan.is_none() && self.eligible_for_trial {
                            "Start Trial"
                        } else {
                            "Upgrade"
                        },
                    )
                    .style(ButtonStyle::Subtle)
                    .color(Color::Accent)
                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                )
        };

        if self.is_connected {
            v_flex()
                .gap_3()
                .w_full()
                // Terms-of-service prompt takes precedence over the
                // subscription UI until accepted.
                .when(!self.has_accepted_terms_of_service, |this| {
                    this.child(render_accept_terms(
                        LanguageModelProviderTosView::Configuration,
                        self.accept_terms_of_service_in_progress,
                        {
                            let callback = self.accept_terms_of_service_callback.clone();
                            move |window, cx| (callback)(window, cx)
                        },
                    ))
                })
                .when(self.has_accepted_terms_of_service, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            // Signed out: show a sign-in prompt.
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                )
        }
    }
}
1167
/// Stateful configuration view that snapshots provider/account state each
/// frame and renders it through the stateless `ZedAiConfiguration` element.
struct ConfigurationView {
    state: Entity<State>,
    // Shared callbacks handed to `ZedAiConfiguration` on every render.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1173
1174impl ConfigurationView {
1175    fn new(state: Entity<State>) -> Self {
1176        let accept_terms_of_service_callback = Arc::new({
1177            let state = state.clone();
1178            move |_window: &mut Window, cx: &mut App| {
1179                state.update(cx, |state, cx| {
1180                    state.accept_terms_of_service(cx);
1181                });
1182            }
1183        });
1184
1185        let sign_in_callback = Arc::new({
1186            let state = state.clone();
1187            move |_window: &mut Window, cx: &mut App| {
1188                state.update(cx, |state, cx| {
1189                    state.authenticate(cx).detach_and_log_err(cx);
1190                });
1191            }
1192        });
1193
1194        Self {
1195            state,
1196            accept_terms_of_service_callback,
1197            sign_in_callback,
1198        }
1199    }
1200}
1201
1202impl Render for ConfigurationView {
1203    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1204        let state = self.state.read(cx);
1205        let user_store = state.user_store.read(cx);
1206
1207        ZedAiConfiguration {
1208            is_connected: !state.is_signed_out(),
1209            plan: user_store.current_plan(),
1210            subscription_period: user_store.subscription_period(),
1211            eligible_for_trial: user_store.trial_started_at().is_none(),
1212            has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1213            accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1214            accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1215            sign_in_callback: self.sign_in_callback.clone(),
1216        }
1217    }
1218}
1219
1220impl Component for ZedAiConfiguration {
1221    fn scope() -> ComponentScope {
1222        ComponentScope::Agent
1223    }
1224
1225    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1226        fn configuration(
1227            is_connected: bool,
1228            plan: Option<proto::Plan>,
1229            eligible_for_trial: bool,
1230            has_accepted_terms_of_service: bool,
1231        ) -> AnyElement {
1232            ZedAiConfiguration {
1233                is_connected,
1234                plan,
1235                subscription_period: plan
1236                    .is_some()
1237                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1238                eligible_for_trial,
1239                has_accepted_terms_of_service,
1240                accept_terms_of_service_in_progress: false,
1241                accept_terms_of_service_callback: Arc::new(|_, _| {}),
1242                sign_in_callback: Arc::new(|_, _| {}),
1243            }
1244            .into_any_element()
1245        }
1246
1247        Some(
1248            v_flex()
1249                .p_4()
1250                .gap_4()
1251                .children(vec![
1252                    single_example("Not connected", configuration(false, None, false, true)),
1253                    single_example(
1254                        "Accept Terms of Service",
1255                        configuration(true, None, true, false),
1256                    ),
1257                    single_example(
1258                        "No Plan - Not eligible for trial",
1259                        configuration(true, None, false, true),
1260                    ),
1261                    single_example(
1262                        "No Plan - Eligible for trial",
1263                        configuration(true, None, true, true),
1264                    ),
1265                    single_example(
1266                        "Free Plan",
1267                        configuration(true, Some(proto::Plan::Free), true, true),
1268                    ),
1269                    single_example(
1270                        "Zed Pro Trial Plan",
1271                        configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1272                    ),
1273                    single_example(
1274                        "Zed Pro Plan",
1275                        configuration(true, Some(proto::Plan::ZedPro), true, true),
1276                    ),
1277                ])
1278                .into_any_element(),
1279        )
1280    }
1281}