cloud.rs

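//! The Zed cloud language model provider.
//!
//! Fetches the list of hosted models from the zed.dev LLM service, manages
//! the LLM API token, and proxies completion and token-count requests to the
//! upstream Anthropic, OpenAI, and Google models through that service.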
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
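
// For reference, a minimal sketch of a custom model entry as it would appear
// in the user's settings JSON, given the serde attributes on `AvailableModel`
// and `ModelMode` above (field values are illustrative, and the `"zed.dev"`
// nesting key is an assumption about the surrounding settings schema):
//
// "zed.dev": {
//   "available_models": [
//     {
//       "provider": "anthropic",
//       "name": "claude-3-5-sonnet-20240620",
//       "display_name": "Claude 3.5 Sonnet (custom)",
//       "max_tokens": 200000,
//       "mode": { "type": "thinking", "budget_tokens": 4096 }
//     }
//   ]
// }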

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Wait until the client is connected before fetching models.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

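// A single completion request against the zed.dev LLM service. The request
// carries the client version and a bearer LLM token; when the server signals
// an expired token, the token is refreshed and the request retried once.
// Plan-limit (403) and payment-required (402) responses are surfaced as typed
// errors rather than generic `ApiError`s.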
impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // Servers that stream status messages report usage in-stream;
                // otherwise usage is read from the response headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // The LLM token expired mid-session: refresh it once and retry.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // `MODEL_REQUESTS_RESOURCE_HEADER_VALUE` is a constant, so this
                // `if let` pattern-matches against its value rather than
                // introducing a new binding.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        // No retry-after duration is extracted from the response headers here.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

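// `LanguageModel` is implemented by translating requests with the per-provider
// conversions (`into_anthropic`, `into_open_ai`, `into_google`), sending them
// through `perform_llm_completion`, and mapping the streamed provider events
// back into `LanguageModelCompletionEvent`s.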
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                // Token counts for Google models are fetched from the
                // service's /count_tokens endpoint rather than computed
                // locally.
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    // Thinking variants are synthesized client-side with a
                    // `-thinking` id suffix (see `State::new`).
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
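
// Each line of the completion response stream deserializes into one
// `CloudCompletionEvent`. With serde's default externally tagged enum
// representation, a line is roughly either `{"status": { ... }}` or
// `{"event": { ... }}`; when the server does not support status messages,
// each line is instead a bare provider event `T` (see `response_lines`).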

fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}