cloud.rs

   1use anthropic::AnthropicModelMode;
   2use anyhow::{Context as _, Result, anyhow};
   3use chrono::{DateTime, Utc};
   4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
   5use futures::{
   6    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
   7};
   8use google_ai::GoogleModelMode;
   9use gpui::{
  10    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
  11};
  12use http_client::http::{HeaderMap, HeaderValue};
  13use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
  14use language_model::{
  15    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  16    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  17    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  18    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
  19    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
  20    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
  21};
  22use proto::Plan;
  23use release_channel::AppVersion;
  24use schemars::JsonSchema;
  25use serde::{Deserialize, Serialize, de::DeserializeOwned};
  26use settings::SettingsStore;
  27use smol::io::{AsyncReadExt, BufReader};
  28use std::pin::Pin;
  29use std::str::FromStr as _;
  30use std::sync::Arc;
  31use std::time::Duration;
  32use thiserror::Error;
  33use ui::{TintColor, prelude::*};
  34use util::{ResultExt as _, maybe};
  35use zed_llm_client::{
  36    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
  37    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
  38    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
  39    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
  40    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  41};
  42
  43use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
  44use crate::provider::google::{GoogleEventMapper, into_google};
  45use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
  46
  47const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
  48const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
  49
/// User-configurable settings for the zed.dev language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models declared in settings (see [`AvailableModel`]).
    pub available_models: Vec<AvailableModel>,
}
  54
/// Upstream provider a settings-declared model belongs to.
/// Serialized in lowercase (e.g. `"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
  62
/// A custom language model declared in the user's settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
  89
/// Reasoning mode of a model, as declared in settings.
/// Serialized with a lowercase `"type"` tag (`"default"` / `"thinking"`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion behavior.
    #[default]
    Default,
    /// Extended-thinking behavior.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
 100
impl From<ModelMode> for AnthropicModelMode {
    /// 1:1 mapping from the settings-level mode onto the Anthropic-specific
    /// mode, carrying the optional thinking token budget through.
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}
 109
/// Language model provider backed by Zed's cloud LLM service (zed.dev).
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Observable provider state (connection status, fetched model catalog).
    state: gpui::Entity<State>,
    /// Keeps the client-status mirroring task alive for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
 115
/// Observable state backing [`CloudLanguageModelProvider`]: connection
/// status, the fetched model catalog, and the background tasks and
/// subscriptions that keep them current.
pub struct State {
    client: Arc<Client>,
    /// Bearer token for requests to the Zed LLM API.
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    /// Latest connection status mirrored from the client.
    status: client::Status,
    /// `Some` while a Terms-of-Service acceptance request is in flight.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    /// All fetched models, including synthesized "-thinking" variants.
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    /// Fetches the model list once the client is connected.
    _fetch_models_task: Task<()>,
    /// Re-notifies observers whenever settings change.
    _settings_subscription: Subscription,
    /// Refreshes `llm_api_token` when a refresh event is broadcast.
    _llm_token_subscription: Subscription,
}
 130
 131impl State {
 132    fn new(
 133        client: Arc<Client>,
 134        user_store: Entity<UserStore>,
 135        status: client::Status,
 136        cx: &mut Context<Self>,
 137    ) -> Self {
 138        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 139
 140        Self {
 141            client: client.clone(),
 142            llm_api_token: LlmApiToken::default(),
 143            user_store,
 144            status,
 145            accept_terms_of_service_task: None,
 146            models: Vec::new(),
 147            default_model: None,
 148            default_fast_model: None,
 149            recommended_models: Vec::new(),
 150            _fetch_models_task: cx.spawn(async move |this, cx| {
 151                maybe!(async move {
 152                    let (client, llm_api_token) = this
 153                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
 154
 155                    loop {
 156                        let status = this.read_with(cx, |this, _cx| this.status)?;
 157                        if matches!(status, client::Status::Connected { .. }) {
 158                            break;
 159                        }
 160
 161                        cx.background_executor()
 162                            .timer(Duration::from_millis(100))
 163                            .await;
 164                    }
 165
 166                    let response = Self::fetch_models(client, llm_api_token).await?;
 167                    cx.update(|cx| {
 168                        this.update(cx, |this, cx| {
 169                            let mut models = Vec::new();
 170
 171                            for model in response.models {
 172                                models.push(Arc::new(model.clone()));
 173
 174                                // Right now we represent thinking variants of models as separate models on the client,
 175                                // so we need to insert variants for any model that supports thinking.
 176                                if model.supports_thinking {
 177                                    models.push(Arc::new(zed_llm_client::LanguageModel {
 178                                        id: zed_llm_client::LanguageModelId(
 179                                            format!("{}-thinking", model.id).into(),
 180                                        ),
 181                                        display_name: format!("{} Thinking", model.display_name),
 182                                        ..model
 183                                    }));
 184                                }
 185                            }
 186
 187                            this.default_model = models
 188                                .iter()
 189                                .find(|model| model.id == response.default_model)
 190                                .cloned();
 191                            this.default_fast_model = models
 192                                .iter()
 193                                .find(|model| model.id == response.default_fast_model)
 194                                .cloned();
 195                            this.recommended_models = response
 196                                .recommended_models
 197                                .iter()
 198                                .filter_map(|id| models.iter().find(|model| &model.id == id))
 199                                .cloned()
 200                                .collect();
 201                            this.models = models;
 202                            cx.notify();
 203                        })
 204                    })??;
 205
 206                    anyhow::Ok(())
 207                })
 208                .await
 209                .context("failed to fetch Zed models")
 210                .log_err();
 211            }),
 212            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 213                cx.notify();
 214            }),
 215            _llm_token_subscription: cx.subscribe(
 216                &refresh_llm_token_listener,
 217                |this, _listener, _event, cx| {
 218                    let client = this.client.clone();
 219                    let llm_api_token = this.llm_api_token.clone();
 220                    cx.spawn(async move |_this, _cx| {
 221                        llm_api_token.refresh(&client).await?;
 222                        anyhow::Ok(())
 223                    })
 224                    .detach_and_log_err(cx);
 225                },
 226            ),
 227        }
 228    }
 229
 230    fn is_signed_out(&self) -> bool {
 231        self.status.is_signed_out()
 232    }
 233
 234    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 235        let client = self.client.clone();
 236        cx.spawn(async move |state, cx| {
 237            client
 238                .authenticate_and_connect(true, &cx)
 239                .await
 240                .into_response()?;
 241            state.update(cx, |_, cx| cx.notify())
 242        })
 243    }
 244
 245    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
 246        self.user_store
 247            .read(cx)
 248            .current_user_has_accepted_terms()
 249            .unwrap_or(false)
 250    }
 251
 252    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
 253        let user_store = self.user_store.clone();
 254        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
 255            let _ = user_store
 256                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
 257                .await;
 258            this.update(cx, |this, cx| {
 259                this.accept_terms_of_service_task = None;
 260                cx.notify()
 261            })
 262        }));
 263    }
 264
 265    async fn fetch_models(
 266        client: Arc<Client>,
 267        llm_api_token: LlmApiToken,
 268    ) -> Result<ListModelsResponse> {
 269        let http_client = &client.http_client();
 270        let token = llm_api_token.acquire(&client).await?;
 271
 272        let request = http_client::Request::builder()
 273            .method(Method::GET)
 274            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 275            .header("Authorization", format!("Bearer {token}"))
 276            .body(AsyncBody::empty())?;
 277        let mut response = http_client
 278            .send(request)
 279            .await
 280            .context("failed to send list models request")?;
 281
 282        if response.status().is_success() {
 283            let mut body = String::new();
 284            response.body_mut().read_to_string(&mut body).await?;
 285            return Ok(serde_json::from_str(&body)?);
 286        } else {
 287            let mut body = String::new();
 288            response.body_mut().read_to_string(&mut body).await?;
 289            anyhow::bail!(
 290                "error listing models.\nStatus: {:?}\nBody: {body}",
 291                response.status(),
 292            );
 293        }
 294    }
 295}
 296
 297impl CloudLanguageModelProvider {
 298    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
 299        let mut status_rx = client.status();
 300        let status = *status_rx.borrow();
 301
 302        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 303
 304        let state_ref = state.downgrade();
 305        let maintain_client_status = cx.spawn(async move |cx| {
 306            while let Some(status) = status_rx.next().await {
 307                if let Some(this) = state_ref.upgrade() {
 308                    _ = this.update(cx, |this, cx| {
 309                        if this.status != status {
 310                            this.status = status;
 311                            cx.notify();
 312                        }
 313                    });
 314                } else {
 315                    break;
 316                }
 317            }
 318        });
 319
 320        Self {
 321            client,
 322            state: state.clone(),
 323            _maintain_client_status: maintain_client_status,
 324        }
 325    }
 326
 327    fn create_language_model(
 328        &self,
 329        model: Arc<zed_llm_client::LanguageModel>,
 330        llm_api_token: LlmApiToken,
 331    ) -> Arc<dyn LanguageModel> {
 332        Arc::new(CloudLanguageModel {
 333            id: LanguageModelId(SharedString::from(model.id.0.clone())),
 334            model,
 335            llm_api_token: llm_api_token.clone(),
 336            client: self.client.clone(),
 337            request_limiter: RateLimiter::new(4),
 338        })
 339    }
 340}
 341
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the internal [`State`] entity so observers can track changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 349
 350impl LanguageModelProvider for CloudLanguageModelProvider {
 351    fn id(&self) -> LanguageModelProviderId {
 352        PROVIDER_ID
 353    }
 354
 355    fn name(&self) -> LanguageModelProviderName {
 356        PROVIDER_NAME
 357    }
 358
 359    fn icon(&self) -> IconName {
 360        IconName::AiZed
 361    }
 362
 363    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 364        let default_model = self.state.read(cx).default_model.clone()?;
 365        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 366        Some(self.create_language_model(default_model, llm_api_token))
 367    }
 368
 369    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 370        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
 371        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 372        Some(self.create_language_model(default_fast_model, llm_api_token))
 373    }
 374
 375    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 376        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 377        self.state
 378            .read(cx)
 379            .recommended_models
 380            .iter()
 381            .cloned()
 382            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 383            .collect()
 384    }
 385
 386    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 387        let llm_api_token = self.state.read(cx).llm_api_token.clone();
 388        self.state
 389            .read(cx)
 390            .models
 391            .iter()
 392            .cloned()
 393            .map(|model| self.create_language_model(model, llm_api_token.clone()))
 394            .collect()
 395    }
 396
 397    fn is_authenticated(&self, cx: &App) -> bool {
 398        let state = self.state.read(cx);
 399        !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
 400    }
 401
 402    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 403        Task::ready(Ok(()))
 404    }
 405
 406    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
 407        cx.new(|_| ConfigurationView::new(self.state.clone()))
 408            .into()
 409    }
 410
 411    fn must_accept_terms(&self, cx: &App) -> bool {
 412        !self.state.read(cx).has_accepted_terms_of_service(cx)
 413    }
 414
 415    fn render_accept_terms(
 416        &self,
 417        view: LanguageModelProviderTosView,
 418        cx: &mut App,
 419    ) -> Option<AnyElement> {
 420        let state = self.state.read(cx);
 421        if state.has_accepted_terms_of_service(cx) {
 422            return None;
 423        }
 424        Some(
 425            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
 426                let state = self.state.clone();
 427                move |_window, cx| {
 428                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
 429                }
 430            })
 431            .into_any_element(),
 432        )
 433    }
 434
 435    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
 436        Task::ready(Ok(()))
 437    }
 438}
 439
/// Builds the Terms-of-Service acceptance UI. Layout adapts to where it is
/// shown (`view_kind`); the accept button is disabled while an acceptance
/// request is already in flight (`accept_terms_of_service_in_progress`).
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    // Link that opens the ToS in the user's browser.
    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // Accept button; styled compactly in the empty state, full-width elsewhere.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    // Empty-state: one compact row. All other variants: stacked layout with
    // the button placement chosen per view kind.
    if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    // Unreachable in practice (handled by the branch above),
                    // but the match must be exhaustive.
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
 516
/// A single cloud-hosted model; completions and token counts are proxied
/// through the Zed LLM API.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    /// Server-side description of the model (provider, capabilities, limits).
    model: Arc<zed_llm_client::LanguageModel>,
    /// Bearer token for the Zed LLM API.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Rate limiter for requests issued through this model
    /// (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
 524
/// Successful result of `perform_llm_completion`, bundling the raw HTTP
/// response with metadata decoded from its headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    /// Usage parsed from response headers; `None` when the server will
    /// deliver status messages in the stream instead.
    usage: Option<ModelRequestUsage>,
    /// Set when the server signalled the tool-use limit was reached.
    tool_use_limit_reached: bool,
    /// Whether the server indicated it supports in-stream status messages.
    includes_status_messages: bool,
}
 531
impl CloudLanguageModel {
    /// Sends a completion request to the Zed LLM API's `/completions`
    /// endpoint, transparently refreshing an expired token once and
    /// translating known failure statuses into typed errors
    /// (plan request limits, payment required).
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        // Guard so an expired token triggers at most one refresh-and-retry.
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            // Advertise the app version when known.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives in
                // the stream rather than in the response headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Expired token: refresh once and retry the request.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // `MODEL_REQUESTS_RESOURCE_HEADER_VALUE` is a const, so this
                // `if let` pattern compares the header value against the
                // constant — it does NOT bind a new variable.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        // Map the LLM-service plan enum onto the proto plan.
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, headers, and body verbatim.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 637
/// Error for a failed Zed LLM API request, retaining the HTTP status, body,
/// and response headers for diagnostics.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 645
 646impl From<ApiError> for LanguageModelCompletionError {
 647    fn from(error: ApiError) -> Self {
 648        let retry_after = None;
 649        LanguageModelCompletionError::from_http_status(
 650            PROVIDER_NAME,
 651            error.status,
 652            error.body,
 653            retry_after,
 654        )
 655    }
 656}
 657
 658impl LanguageModel for CloudLanguageModel {
    /// Unique identifier of this model instance.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
 662
    /// Display name as reported by the server.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }
 666
    /// The Zed cloud provider id (requests are proxied through zed.dev).
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }
 670
    /// The Zed cloud provider name.
    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }
 674
    /// Id of the upstream provider the cloud service forwards requests to.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }
 683
    /// Name of the upstream provider the cloud service forwards requests to.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }
 692
    /// Tool-calling support, as reported by the server.
    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }
 696
    /// Image-input support, as reported by the server.
    fn supports_images(&self) -> bool {
        self.model.supports_images
    }
 700
    /// All tool-choice modes are supported by cloud models.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        // Deliberately exhaustive: adding a new `LanguageModelToolChoice`
        // variant forces an explicit decision here instead of a silent default.
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }
 708
    /// "Burn mode" is backed by the server-reported `supports_max_mode` flag.
    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }
 712
    /// Identifier used for telemetry, namespaced under `zed.dev/`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
 716
 717    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 718        match self.model.provider {
 719            zed_llm_client::LanguageModelProvider::Anthropic
 720            | zed_llm_client::LanguageModelProvider::OpenAi => {
 721                LanguageModelToolSchemaFormat::JsonSchema
 722            }
 723            zed_llm_client::LanguageModelProvider::Google => {
 724                LanguageModelToolSchemaFormat::JsonSchemaSubset
 725            }
 726        }
 727    }
 728
    /// Context-window size in tokens, as reported by the server.
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }
 732
 733    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 734        match &self.model.provider {
 735            zed_llm_client::LanguageModelProvider::Anthropic => {
 736                Some(LanguageModelCacheConfiguration {
 737                    min_total_token: 2_048,
 738                    should_speculate: true,
 739                    max_cache_anchors: 4,
 740                })
 741            }
 742            zed_llm_client::LanguageModelProvider::OpenAi
 743            | zed_llm_client::LanguageModelProvider::Google => None,
 744        }
 745    }
 746
    /// Counts the tokens `request` would consume on this model.
    ///
    /// Anthropic and OpenAI counts are computed client-side; Google counts
    /// are obtained from the Zed LLM API's `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                // The server-side model id must map onto a known OpenAI model
                // to pick the right tokenizer; otherwise fail the future.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    // Wrap the Google request so the LLM service can forward it.
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
 814
 815    fn stream_completion(
 816        &self,
 817        request: LanguageModelRequest,
 818        cx: &AsyncApp,
 819    ) -> BoxFuture<
 820        'static,
 821        Result<
 822            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 823            LanguageModelCompletionError,
 824        >,
 825    > {
 826        let thread_id = request.thread_id.clone();
 827        let prompt_id = request.prompt_id.clone();
 828        let intent = request.intent;
 829        let mode = request.mode;
 830        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
 831        match self.model.provider {
 832            zed_llm_client::LanguageModelProvider::Anthropic => {
 833                let request = into_anthropic(
 834                    request,
 835                    self.model.id.to_string(),
 836                    1.0,
 837                    self.model.max_output_tokens as u64,
 838                    if self.model.id.0.ends_with("-thinking") {
 839                        AnthropicModelMode::Thinking {
 840                            budget_tokens: Some(4_096),
 841                        }
 842                    } else {
 843                        AnthropicModelMode::Default
 844                    },
 845                );
 846                let client = self.client.clone();
 847                let llm_api_token = self.llm_api_token.clone();
 848                let future = self.request_limiter.stream(async move {
 849                    let PerformLlmCompletionResponse {
 850                        response,
 851                        usage,
 852                        includes_status_messages,
 853                        tool_use_limit_reached,
 854                    } = Self::perform_llm_completion(
 855                        client.clone(),
 856                        llm_api_token,
 857                        app_version,
 858                        CompletionBody {
 859                            thread_id,
 860                            prompt_id,
 861                            intent,
 862                            mode,
 863                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
 864                            model: request.model.clone(),
 865                            provider_request: serde_json::to_value(&request)
 866                                .map_err(|e| anyhow!(e))?,
 867                        },
 868                    )
 869                    .await
 870                    .map_err(|err| match err.downcast::<ApiError>() {
 871                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 872                        Err(err) => anyhow!(err),
 873                    })?;
 874
 875                    let mut mapper = AnthropicEventMapper::new();
 876                    Ok(map_cloud_completion_events(
 877                        Box::pin(
 878                            response_lines(response, includes_status_messages)
 879                                .chain(usage_updated_event(usage))
 880                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 881                        ),
 882                        move |event| mapper.map_event(event),
 883                    ))
 884                });
 885                async move { Ok(future.await?.boxed()) }.boxed()
 886            }
 887            zed_llm_client::LanguageModelProvider::OpenAi => {
 888                let client = self.client.clone();
 889                let model = match open_ai::Model::from_id(&self.model.id.0) {
 890                    Ok(model) => model,
 891                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
 892                };
 893                let request = into_open_ai(
 894                    request,
 895                    model.id(),
 896                    model.supports_parallel_tool_calls(),
 897                    None,
 898                );
 899                let llm_api_token = self.llm_api_token.clone();
 900                let future = self.request_limiter.stream(async move {
 901                    let PerformLlmCompletionResponse {
 902                        response,
 903                        usage,
 904                        includes_status_messages,
 905                        tool_use_limit_reached,
 906                    } = Self::perform_llm_completion(
 907                        client.clone(),
 908                        llm_api_token,
 909                        app_version,
 910                        CompletionBody {
 911                            thread_id,
 912                            prompt_id,
 913                            intent,
 914                            mode,
 915                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
 916                            model: request.model.clone(),
 917                            provider_request: serde_json::to_value(&request)
 918                                .map_err(|e| anyhow!(e))?,
 919                        },
 920                    )
 921                    .await?;
 922
 923                    let mut mapper = OpenAiEventMapper::new();
 924                    Ok(map_cloud_completion_events(
 925                        Box::pin(
 926                            response_lines(response, includes_status_messages)
 927                                .chain(usage_updated_event(usage))
 928                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 929                        ),
 930                        move |event| mapper.map_event(event),
 931                    ))
 932                });
 933                async move { Ok(future.await?.boxed()) }.boxed()
 934            }
 935            zed_llm_client::LanguageModelProvider::Google => {
 936                let client = self.client.clone();
 937                let request =
 938                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 939                let llm_api_token = self.llm_api_token.clone();
 940                let future = self.request_limiter.stream(async move {
 941                    let PerformLlmCompletionResponse {
 942                        response,
 943                        usage,
 944                        includes_status_messages,
 945                        tool_use_limit_reached,
 946                    } = Self::perform_llm_completion(
 947                        client.clone(),
 948                        llm_api_token,
 949                        app_version,
 950                        CompletionBody {
 951                            thread_id,
 952                            prompt_id,
 953                            intent,
 954                            mode,
 955                            provider: zed_llm_client::LanguageModelProvider::Google,
 956                            model: request.model.model_id.clone(),
 957                            provider_request: serde_json::to_value(&request)
 958                                .map_err(|e| anyhow!(e))?,
 959                        },
 960                    )
 961                    .await?;
 962
 963                    let mut mapper = GoogleEventMapper::new();
 964                    Ok(map_cloud_completion_events(
 965                        Box::pin(
 966                            response_lines(response, includes_status_messages)
 967                                .chain(usage_updated_event(usage))
 968                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
 969                        ),
 970                        move |event| mapper.map_event(event),
 971                    ))
 972                });
 973                async move { Ok(future.await?.boxed()) }.boxed()
 974            }
 975        }
 976    }
 977}
 978
/// One line of the cloud completion response stream.
///
/// When the server supports status messages, each response line is either a
/// `Status` update about the in-flight request or a provider-specific
/// `Event` payload of type `T`; otherwise every line is a bare `T` and gets
/// wrapped in `Event` by `response_lines`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    /// Out-of-band status update for the in-flight completion request.
    Status(CompletionRequestStatus),
    /// Provider-specific completion event payload.
    Event(T),
}
 985
 986fn map_cloud_completion_events<T, F>(
 987    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
 988    mut map_callback: F,
 989) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 990where
 991    T: DeserializeOwned + 'static,
 992    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 993        + Send
 994        + 'static,
 995{
 996    stream
 997        .flat_map(move |event| {
 998            futures::stream::iter(match event {
 999                Err(error) => {
1000                    vec![Err(LanguageModelCompletionError::from(error))]
1001                }
1002                Ok(CloudCompletionEvent::Status(event)) => {
1003                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
1004                }
1005                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
1006            })
1007        })
1008        .boxed()
1009}
1010
1011fn usage_updated_event<T>(
1012    usage: Option<ModelRequestUsage>,
1013) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1014    futures::stream::iter(usage.map(|usage| {
1015        Ok(CloudCompletionEvent::Status(
1016            CompletionRequestStatus::UsageUpdated {
1017                amount: usage.amount as usize,
1018                limit: usage.limit,
1019            },
1020        ))
1021    }))
1022}
1023
1024fn tool_use_limit_reached_event<T>(
1025    tool_use_limit_reached: bool,
1026) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1027    futures::stream::iter(tool_use_limit_reached.then(|| {
1028        Ok(CloudCompletionEvent::Status(
1029            CompletionRequestStatus::ToolUseLimitReached,
1030        ))
1031    }))
1032}
1033
/// Turns a streaming HTTP response body into a stream of newline-delimited
/// JSON events.
///
/// When `includes_status_messages` is true, each line is deserialized as a
/// tagged `CloudCompletionEvent<T>`; otherwise each line is a bare provider
/// event `T` and is wrapped in `CloudCompletionEvent::Event`.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        // Unfold state: a reusable line buffer plus the buffered body reader.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // Zero bytes read means EOF: terminate the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    // Clear (rather than reallocate) the buffer so its
                    // capacity is reused for the next line.
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1058
/// Props for the Zed AI provider configuration panel.
#[derive(IntoElement, RegisterComponent)]
struct ZedAIConfiguration {
    /// Whether the user is currently signed in.
    is_connected: bool,
    /// The user's current plan, if any.
    plan: Option<proto::Plan>,
    /// Start/end of the active subscription period, if one exists.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the user may still start a free trial.
    eligible_for_trial: bool,
    /// Whether the terms of service have been accepted.
    has_accepted_terms_of_service: bool,
    /// Whether a terms-of-service acceptance request is in flight.
    accept_terms_of_service_in_progress: bool,
    /// Invoked when the user accepts the terms of service.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    /// Invoked when the user clicks "Sign In".
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1070
1071impl RenderOnce for ZedAIConfiguration {
1072    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1073        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
1074
1075        let is_pro = self.plan == Some(proto::Plan::ZedPro);
1076        let subscription_text = match (self.plan, self.subscription_period) {
1077            (Some(proto::Plan::ZedPro), Some(_)) => {
1078                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
1079            }
1080            (Some(proto::Plan::ZedProTrial), Some(_)) => {
1081                "You have access to Zed's hosted LLMs through your Zed Pro trial."
1082            }
1083            (Some(proto::Plan::Free), Some(_)) => {
1084                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
1085            }
1086            _ => {
1087                if self.eligible_for_trial {
1088                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
1089                } else {
1090                    "Subscribe for access to Zed's hosted LLMs."
1091                }
1092            }
1093        };
1094        let manage_subscription_buttons = if is_pro {
1095            h_flex().child(
1096                Button::new("manage_settings", "Manage Subscription")
1097                    .style(ButtonStyle::Tinted(TintColor::Accent))
1098                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1099            )
1100        } else {
1101            h_flex()
1102                .gap_2()
1103                .child(
1104                    Button::new("learn_more", "Learn more")
1105                        .style(ButtonStyle::Subtle)
1106                        .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
1107                )
1108                .child(
1109                    Button::new(
1110                        "upgrade",
1111                        if self.plan.is_none() && self.eligible_for_trial {
1112                            "Start Trial"
1113                        } else {
1114                            "Upgrade"
1115                        },
1116                    )
1117                    .style(ButtonStyle::Subtle)
1118                    .color(Color::Accent)
1119                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1120                )
1121        };
1122
1123        if self.is_connected {
1124            v_flex()
1125                .gap_3()
1126                .w_full()
1127                .when(!self.has_accepted_terms_of_service, |this| {
1128                    this.child(render_accept_terms(
1129                        LanguageModelProviderTosView::Configuration,
1130                        self.accept_terms_of_service_in_progress,
1131                        {
1132                            let callback = self.accept_terms_of_service_callback.clone();
1133                            move |window, cx| (callback)(window, cx)
1134                        },
1135                    ))
1136                })
1137                .when(self.has_accepted_terms_of_service, |this| {
1138                    this.child(subscription_text)
1139                        .child(manage_subscription_buttons)
1140                })
1141        } else {
1142            v_flex()
1143                .gap_2()
1144                .child(Label::new("Use Zed AI to access hosted language models."))
1145                .child(
1146                    Button::new("sign_in", "Sign In")
1147                        .icon_color(Color::Muted)
1148                        .icon(IconName::Github)
1149                        .icon_position(IconPosition::Start)
1150                        .on_click({
1151                            let callback = self.sign_in_callback.clone();
1152                            move |_, window, cx| (callback)(window, cx)
1153                        }),
1154                )
1155        }
1156    }
1157}
1158
/// Entity that renders the Zed AI provider configuration panel.
struct ConfigurationView {
    /// Shared provider state (auth status, plan, terms-of-service state).
    state: Entity<State>,
    /// Callback wired to accept the terms of service via `state`.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    /// Callback wired to start authentication via `state`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1164
1165impl ConfigurationView {
1166    fn new(state: Entity<State>) -> Self {
1167        let accept_terms_of_service_callback = Arc::new({
1168            let state = state.clone();
1169            move |_window: &mut Window, cx: &mut App| {
1170                state.update(cx, |state, cx| {
1171                    state.accept_terms_of_service(cx);
1172                });
1173            }
1174        });
1175
1176        let sign_in_callback = Arc::new({
1177            let state = state.clone();
1178            move |_window: &mut Window, cx: &mut App| {
1179                state.update(cx, |state, cx| {
1180                    state.authenticate(cx).detach_and_log_err(cx);
1181                });
1182            }
1183        });
1184
1185        Self {
1186            state,
1187            accept_terms_of_service_callback,
1188            sign_in_callback,
1189        }
1190    }
1191}
1192
1193impl Render for ConfigurationView {
1194    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1195        let state = self.state.read(cx);
1196        let user_store = state.user_store.read(cx);
1197
1198        ZedAIConfiguration {
1199            is_connected: !state.is_signed_out(),
1200            plan: user_store.current_plan(),
1201            subscription_period: user_store.subscription_period(),
1202            eligible_for_trial: user_store.trial_started_at().is_none(),
1203            has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1204            accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1205            accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1206            sign_in_callback: self.sign_in_callback.clone(),
1207        }
1208    }
1209}
1210
1211impl Component for ZedAIConfiguration {
1212    fn scope() -> ComponentScope {
1213        ComponentScope::Agent
1214    }
1215
1216    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1217        fn configuration(
1218            is_connected: bool,
1219            plan: Option<proto::Plan>,
1220            eligible_for_trial: bool,
1221            has_accepted_terms_of_service: bool,
1222        ) -> AnyElement {
1223            ZedAIConfiguration {
1224                is_connected,
1225                plan,
1226                subscription_period: plan
1227                    .is_some()
1228                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1229                eligible_for_trial,
1230                has_accepted_terms_of_service,
1231                accept_terms_of_service_in_progress: false,
1232                accept_terms_of_service_callback: Arc::new(|_, _| {}),
1233                sign_in_callback: Arc::new(|_, _| {}),
1234            }
1235            .into_any_element()
1236        }
1237
1238        Some(
1239            v_flex()
1240                .p_4()
1241                .gap_4()
1242                .children(vec![
1243                    single_example("Not connected", configuration(false, None, false, true)),
1244                    single_example(
1245                        "Accept Terms of Service",
1246                        configuration(true, None, true, false),
1247                    ),
1248                    single_example(
1249                        "No Plan - Not eligible for trial",
1250                        configuration(true, None, false, true),
1251                    ),
1252                    single_example(
1253                        "No Plan - Eligible for trial",
1254                        configuration(true, None, true, true),
1255                    ),
1256                    single_example(
1257                        "Free Plan",
1258                        configuration(true, Some(proto::Plan::Free), true, true),
1259                    ),
1260                    single_example(
1261                        "Zed Pro Trial Plan",
1262                        configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1263                    ),
1264                    single_example(
1265                        "Zed Pro Plan",
1266                        configuration(true, Some(proto::Plan::ZedPro), true, true),
1267                    ),
1268                ])
1269                .into_any_element(),
1270        )
1271    }
1272}